# coding: utf-8
"""Helper script for generating config file and parameters list.

This script generates LightGBM/src/io/config_auto.cpp file
with list of all parameters, aliases table and other routines
along with parameters description in LightGBM/docs/Parameters.rst file
from the information in LightGBM/include/LightGBM/config.h file.
"""
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple


def get_parameter_infos(
    config_hpp: Path
) -> Tuple[List[Tuple[str, int]], List[List[Dict[str, List]]]]:
    """Parse config header file.

    Lines are only interpreted while inside a ``#pragma region ... Parameters``
    section. Inside a section, ``// key = value`` comments are accumulated
    for the next member declaration, and the first non-comment, non-empty line
    is treated as that declaration (``<type> <name> [= <default>];``).

    Parameters
    ----------
    config_hpp : pathlib.Path
        Path to the config header file.

    Returns
    -------
    infos : tuple
        Tuple with names and content of sections.
        The first item is a list of ``(section name, nesting level)`` pairs,
        the second one is a list (parallel to the first) with one list of
        parameter dictionaries per section. Every dictionary maps a key
        (e.g. ``"name"``, ``"inner_type"``, ``"default"``, ``"desc"``)
        to the list of values collected for this key.
    """
    is_inparameter = False
    cur_key = None
    key_lvl = 0
    cur_info: Dict[str, List] = {}
    keys: List[Tuple[str, int]] = []
    member_infos: List[List[Dict[str, List]]] = []
    # explicit encoding: do not depend on the locale of the machine running the script
    with open(config_hpp, encoding="utf-8") as config_hpp_file:
        for line in config_hpp_file:
            # these guards only wrap declarations, they carry no parameter information
            if line.strip() in {"#ifndef __NVCC__", "#endif  // __NVCC__"}:
                continue
            if "#pragma region Parameters" in line:
                # outermost region that wraps all sections
                is_inparameter = True
            elif "#pragma region" in line and "Parameters" in line:
                # start of a (possibly nested) section
                key_lvl += 1
                cur_key = line.split("region")[1].strip()
                keys.append((cur_key, key_lvl))
                member_infos.append([])
            elif '#pragma endregion' in line:
                key_lvl -= 1
                if cur_key is not None:
                    cur_key = None
                elif is_inparameter:
                    is_inparameter = False
            elif cur_key is not None:
                line = line.strip()
                if line.startswith("//"):
                    # annotation comment in the form "// key = value"
                    key, _, val = line[2:].partition("=")
                    key = key.strip()
                    val = val.strip()
                    if key == "desc":
                        cur_info.setdefault("desc", []).append(("l1", val))
                    elif key == "descl2":
                        # second-level description lines are stored together with the first-level ones
                        cur_info.setdefault("desc", []).append(("l2", val))
                    else:
                        cur_info.setdefault(key, []).append(val)
                elif line:
                    # member declaration: finishes the current parameter
                    has_eqsgn = False
                    tokens = line.split("=")
                    if len(tokens) == 2:
                        # "// default = ..." annotation has priority over the C++ initializer
                        if "default" not in cur_info:
                            # [:-1] drops the trailing ";"
                            cur_info["default"] = [tokens[1][:-1].strip()]
                        has_eqsgn = True
                    tokens = line.split()
                    cur_info["inner_type"] = [tokens[0].strip()]
                    if "name" not in cur_info:
                        if has_eqsgn:
                            cur_info["name"] = [tokens[1].strip()]
                        else:
                            # without initializer the name token ends with ";"
                            cur_info["name"] = [tokens[1][:-1].strip()]
                    member_infos[-1].append(cur_info)
                    cur_info = {}

    return keys, member_infos


def get_names(
    infos: List[List[Dict[str, List]]]
) -> List[str]:
    """Get names of all parameters.

    Parameters
    ----------
    infos : list
        Content of the config header file.
        One list of parameter dictionaries per section.

    Returns
    -------
    names : list
        Names of all parameters, in the order of their appearance.
    """
    return [param["name"][0] for section in infos for param in section]


def get_alias(
    infos: List[List[Dict[str, List]]]
) -> List[Tuple[str, str]]:
    """Get aliases of all parameters.

    Parameters
    ----------
    infos : list
        Content of the config header file.
        One list of parameter dictionaries per section.

    Returns
    -------
    pairs : list
        List of tuples (param alias, param name).
    """
    pairs: List[Tuple[str, str]] = []
    for section in infos:
        for param in section:
            if "alias" not in param:
                continue
            name = param["name"][0]
            # aliases are stored as one comma-separated string
            pairs.extend((alias.strip(), name) for alias in param["alias"][0].split(','))
    return pairs


def parse_check(
    check: str,
    reverse: bool = False
) -> Tuple[str, str]:
    """Parse the constraint.

    Parameters
    ----------
    check : str
        String representation of the constraint:
        a comparison sign (``<``, ``>``, ``<=`` or ``>=``) immediately followed by a number.
    reverse : bool, optional (default=False)
        Whether to reverse the sign of the constraint.

    Returns
    -------
    pair : tuple
        Parsed constraint in the form of tuple (value, sign).

    Raises
    ------
    ValueError
        If no number follows a sign of length 1 or 2.
    KeyError
        If ``reverse`` is True and the sign is not one of the four known signs.
    """
    # the sign takes 1 or 2 characters; pick the length after which the rest is a number
    try:
        idx = 1
        float(check[idx:])
    except ValueError:
        idx = 2
        float(check[idx:])
    sign = check[:idx]
    value = check[idx:]
    if reverse:
        reversed_sign = {'<': '>', '>': '<', '<=': '>=', '>=': '<='}
        sign = reversed_sign[sign]
    return value, sign


def set_one_var_from_string(
    name: str,
    param_type: str,
    checks: List[str]
) -> str:
    """Construct code for auto config file for one param value.

    Scalar parameters are read with the matching ``Get*`` routine and followed
    by one ``CHECK_*`` line per constraint. Vector parameters are read as
    a string and split by comma; constraints are not emitted for them.

    Parameters
    ----------
    name : str
        Name of the parameter.
    param_type : str
        Type of the parameter.
    checks : list
        Constraints of the parameter.

    Returns
    -------
    ret : str
        Lines of auto config file with getting and checks of one parameter value.
    """
    ret = ""
    univar_mapper = {"int": "GetInt", "double": "GetDouble", "bool": "GetBool", "std::string": "GetString"}
    if "vector" not in param_type:
        ret += f'  {univar_mapper[param_type]}(params, "{name}", &{name});\n'
        check_mapper = {"<": "LT", ">": "GT", "<=": "LE", ">=": "GE"}
        for check in checks:
            value, sign = parse_check(check)
            ret += f"  CHECK_{check_mapper[sign]}({name}, {value});\n"
        ret += "\n"
    else:
        ret += f'  if (GetString(params, "{name}", &tmp_str)) {{\n'
        # element type: text between "<" and the closing ">"
        type2 = param_type.split("<")[1][:-1]
        if type2 == "std::string":
            ret += f"    {name} = Common::Split(tmp_str.c_str(), ',');\n"
        else:
            ret += f"    {name} = Common::StringToArray<{type2}>(tmp_str, ',');\n"
        ret += "  }\n\n"
    return ret


def gen_parameter_description(
    sections: List[Tuple[str, int]],
    descriptions: List[List[Dict[str, List]]],
    params_rst: Path
) -> None:
    """Write descriptions of parameters to the documentation file.

    The text between the ``.. start params list`` and ``.. end params list``
    markers is replaced with the generated list; everything outside
    the markers is written back unchanged.

    Parameters
    ----------
    sections : list
        Names of parameters sections.
    descriptions : list
        Structured descriptions of parameters.
    params_rst : pathlib.Path
        Path to the file with parameters documentation.

    Raises
    ------
    ValueError
        If the start or the end marker is not found in the file.
        The file is left untouched in this case.
    """
    start_marker = '.. start params list\n\n'
    end_marker = '\n\n.. end params list'

    params_to_write = []
    lvl_mapper = {1: '-', 2: '~'}
    for (section_name, section_lvl), section_params in zip(sections, descriptions):
        heading_sign = lvl_mapper[section_lvl]
        params_to_write.append(f'{section_name}\n{heading_sign * len(section_name)}')
        for param_desc in section_params:
            name = param_desc['name'][0]
            default_raw = param_desc['default'][0]
            # drop surrounding quotes unless nothing would be left (i.e. empty string default)
            default = default_raw.strip('"') or default_raw
            # prefer explicit "type" annotation; otherwise derive short name from the C++ type
            param_type = param_desc.get('type', param_desc['inner_type'])[0].split(':')[-1].split('<')[-1].strip('>')
            options = param_desc.get('options', [])
            if options:
                opts = '``, ``'.join([x.strip() for x in options[0].split(',')])
                options_str = f', options: ``{opts}``'
            else:
                options_str = ''
            aliases = param_desc.get('alias', [])
            if aliases:
                aliases_joined = '``, ``'.join([x.strip() for x in aliases[0].split(',')])
                aliases_str = f', aliases: ``{aliases_joined}``'
            else:
                aliases_str = ''
            # "<" sorts before ">", so for a pair of constraints the upper bound comes first
            checks = sorted(param_desc.get('check', []))
            checks_len = len(checks)
            if checks_len > 1:
                number1, sign1 = parse_check(checks[0])
                number2, sign2 = parse_check(checks[1], reverse=True)
                checks_str = f', constraints: ``{number2} {sign2} {name} {sign1} {number1}``'
            elif checks_len == 1:
                number, sign = parse_check(checks[0])
                checks_str = f', constraints: ``{name} {sign} {number}``'
            else:
                checks_str = ''
            main_desc = f'-  ``{name}`` :raw-html:`<a id="{name}" title="Permalink to this parameter" href="#{name}">&#x1F517;&#xFE0E;</a>`, default = ``{default}``, type = {param_type}{options_str}{aliases_str}{checks_str}'
            params_to_write.append(main_desc)
            # "l1"/"l2" -> indentation by 3/6 spaces
            params_to_write.extend([f"{' ' * 3 * int(desc[0][-1])}-  {desc[1]}" for desc in param_desc['desc']])

    with open(params_rst, encoding="utf-8") as original_params_file:
        all_lines = original_params_file.read()
    before, start_sep, _ = all_lines.partition(start_marker)
    _, end_sep, after = all_lines.partition(end_marker)
    # without both markers the rewrite below would silently append to or truncate the file
    if not start_sep or not end_sep:
        raise ValueError(
            f"Cannot find '.. start params list' and/or '.. end params list' markers in {params_rst}"
        )

    with open(params_rst, "w", encoding="utf-8") as new_params_file:
        new_params_file.write(before)
        new_params_file.write(start_sep)
        new_params_file.write('\n\n'.join(params_to_write))
        new_params_file.write(end_sep)
        new_params_file.write(after)


def gen_parameter_code(
    config_hpp: Path,
    config_out_cpp: Path
) -> Tuple[List[Tuple[str, int]], List[List[Dict[str, List]]]]:
    """Generate auto config file.

    The generated C++ source defines ``Config::alias_table()``,
    ``Config::parameter_set()``, ``Config::GetMembersFromString()``,
    ``Config::SaveMembersToString()``, ``Config::parameter2aliases()``
    and ``Config::ParameterTypes()``.

    Parameters
    ----------
    config_hpp : pathlib.Path
        Path to the config header file.
    config_out_cpp : pathlib.Path
        Path to the auto config file.

    Returns
    -------
    infos : tuple
        Tuple with names and content of sections.
    """
    keys, infos = get_parameter_infos(config_hpp)
    names = get_names(infos)
    alias = get_alias(infos)
    names_with_aliases = defaultdict(list)
    str_to_write = r"""/*!
 * Copyright (c) 2018 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 *
 * \note
 * This file is auto generated by LightGBM\helpers\parameter_generator.py from LightGBM\include\LightGBM\config.h file.
 */
"""
    str_to_write += "#include<LightGBM/config.h>\nnamespace LightGBM {\n"
    # alias table
    str_to_write += "const std::unordered_map<std::string, std::string>& Config::alias_table() {\n"
    str_to_write += "  static std::unordered_map<std::string, std::string> aliases({\n"

    for alias_name, param_name in alias:
        str_to_write += f'  {{"{alias_name}", "{param_name}"}},\n'
        # remember aliases per parameter for parameter2aliases() below
        names_with_aliases[param_name].append(alias_name)
    str_to_write += "  });\n"
    str_to_write += "  return aliases;\n"
    str_to_write += "}\n\n"

    # names
    str_to_write += "const std::unordered_set<std::string>& Config::parameter_set() {\n"
    str_to_write += "  static std::unordered_set<std::string> params({\n"

    for name in names:
        str_to_write += f'  "{name}",\n'
    str_to_write += "  });\n"
    str_to_write += "  return params;\n"
    str_to_write += "}\n\n"
    # from strings
    str_to_write += "void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {\n"
    str_to_write += '  std::string tmp_str = "";\n'
    for x in infos:
        for y in x:
            # parameters marked as "[doc-only]" get no generated parsing code
            if "[doc-only]" in y:
                continue
            param_type = y["inner_type"][0]
            name = y["name"][0]
            checks = y.get("check", [])
            str_to_write += set_one_var_from_string(name, param_type, checks)
    # tails
    # strip() removes the blank line left after the last parameter
    str_to_write = f"{str_to_write.strip()}\n}}\n\n"
    str_to_write += "std::string Config::SaveMembersToString() const {\n"
    str_to_write += "  std::stringstream str_buf;\n"
    for x in infos:
        for y in x:
            if "[doc-only]" in y or "[no-save]" in y:
                continue
            param_type = y["inner_type"][0]
            name = y["name"][0]
            if "vector" in param_type:
                if "int8" in param_type:
                    # int8 vectors are cast to int before joining
                    str_to_write += f'  str_buf << "[{name}: " << Common::Join(Common::ArrayCast<int8_t, int>({name}), ",") << "]\\n";\n'
                else:
                    str_to_write += f'  str_buf << "[{name}: " << Common::Join({name}, ",") << "]\\n";\n'
            else:
                str_to_write += f'  str_buf << "[{name}: " << {name} << "]\\n";\n'
    # tails
    str_to_write += "  return str_buf.str();\n"
    str_to_write += "}\n\n"

    str_to_write += """const std::unordered_map<std::string, std::vector<std::string>>& Config::parameter2aliases() {
  static std::unordered_map<std::string, std::vector<std::string>> map({"""
    for name in names:
        str_to_write += '\n    {"' + name + '", '
        if names_with_aliases[name]:
            str_to_write += '{"' + '", "'.join(names_with_aliases[name]) + '"}},'
        else:
            str_to_write += '{}},'
    str_to_write += """
  });
  return map;
}

"""
    str_to_write += """const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
  static std::unordered_map<std::string, std::string> map({"""
    int_t_pat = re.compile(r'int\d+_t')
    # the following are stored as comma separated strings but are arrays in the wrappers
    overrides = {
        'categorical_feature': 'vector<int>',
        'ignore_column': 'vector<int>',
        'interaction_constraints': 'vector<vector<int>>',
    }
    for x in infos:
        for y in x:
            name = y["name"][0]
            if name == 'task':
                continue
            if name in overrides:
                param_type = overrides[name]
            else:
                # e.g. "int8_t" -> "int", "std::string" -> "string"
                param_type = int_t_pat.sub('int', y["inner_type"][0]).replace('std::', '')
            str_to_write += '\n    {"' + name + '", "' + param_type + '"},'
    str_to_write += """
  });
  return map;
}

"""

    str_to_write += "}  // namespace LightGBM\n"
    # explicit encoding: do not depend on the locale of the machine running the script
    with open(config_out_cpp, "w", encoding="utf-8") as config_out_cpp_file:
        config_out_cpp_file.write(str_to_write)

    return keys, infos

if __name__ == "__main__":
    # this script is located one level below the repository root
    repo_root = Path(__file__).absolute().parent.parent
    config_hpp = repo_root / 'include' / 'LightGBM' / 'config.h'
    config_out_cpp = repo_root / 'src' / 'io' / 'config_auto.cpp'
    params_rst = repo_root / 'docs' / 'Parameters.rst'
    # generate C++ code first and reuse parsed information for the documentation
    sections, descriptions = gen_parameter_code(config_hpp, config_out_cpp)
    gen_parameter_description(sections, descriptions, params_rst)