Commit 4df7b21d authored by Nikita Titov's avatar Nikita Titov Committed by Qiwei Ye
Browse files

updated params aliases in Python (#1549)

parent 61191ed2
......@@ -1370,19 +1370,21 @@ class Booster(object):
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
# set network if necessary
if "machines" in params:
machines = params["machines"]
if isinstance(machines, string_type):
num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
for alias in ["machines", "workers", "nodes"]:
if alias in params:
machines = params[alias]
if isinstance(machines, string_type):
num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
break
elif model_file is not None:
# Prediction task
out_num_iterations = ctypes.c_int(0)
......@@ -1521,7 +1523,7 @@ class Booster(object):
params : dict
New parameters for Booster.
"""
if 'metric' in params:
if any(metric_alias in params for metric_alias in ('metric', 'metrics', 'metric_types')):
self.__need_reload_eval_info = True
params_str = param_dict_to_str(params)
if params_str:
......
......@@ -127,7 +127,9 @@ def reset_parameter(**kwargs):
"""internal function"""
new_parameters = {}
for key, value in kwargs.items():
if key in ['num_class', 'boosting_type', 'metric']:
if key in ['num_class', 'num_classes',
'boosting', 'boost', 'boosting_type',
'metric', 'metrics', 'metric_types']:
raise RuntimeError("cannot reset {} during training".format(repr(key)))
if isinstance(value, list):
if len(value) != env.end_iteration - env.begin_iteration:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment