Commit 4df7b21d authored by Nikita Titov's avatar Nikita Titov Committed by Qiwei Ye
Browse files

updated params aliases in Python (#1549)

parent 61191ed2
...@@ -1370,8 +1370,9 @@ class Booster(object): ...@@ -1370,8 +1370,9 @@ class Booster(object):
self.__get_eval_info() self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical self.pandas_categorical = train_set.pandas_categorical
# set network if necessary # set network if necessary
if "machines" in params: for alias in ["machines", "workers", "nodes"]:
machines = params["machines"] if alias in params:
machines = params[alias]
if isinstance(machines, string_type): if isinstance(machines, string_type):
num_machines = len(machines.split(',')) num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)): elif isinstance(machines, (list, set)):
...@@ -1383,6 +1384,7 @@ class Booster(object): ...@@ -1383,6 +1384,7 @@ class Booster(object):
local_listen_port=params.get("local_listen_port", 12400), local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120), listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines)) num_machines=params.get("num_machines", num_machines))
break
elif model_file is not None: elif model_file is not None:
# Prediction task # Prediction task
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
...@@ -1521,7 +1523,7 @@ class Booster(object): ...@@ -1521,7 +1523,7 @@ class Booster(object):
params : dict params : dict
New parameters for Booster. New parameters for Booster.
""" """
if 'metric' in params: if any(metric_alias in params for metric_alias in ('metric', 'metrics', 'metric_types')):
self.__need_reload_eval_info = True self.__need_reload_eval_info = True
params_str = param_dict_to_str(params) params_str = param_dict_to_str(params)
if params_str: if params_str:
......
...@@ -127,7 +127,9 @@ def reset_parameter(**kwargs): ...@@ -127,7 +127,9 @@ def reset_parameter(**kwargs):
"""internal function""" """internal function"""
new_parameters = {} new_parameters = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in ['num_class', 'boosting_type', 'metric']: if key in ['num_class', 'num_classes',
'boosting', 'boost', 'boosting_type',
'metric', 'metrics', 'metric_types']:
raise RuntimeError("cannot reset {} during training".format(repr(key))) raise RuntimeError("cannot reset {} during training".format(repr(key)))
if isinstance(value, list): if isinstance(value, list):
if len(value) != env.end_iteration - env.begin_iteration: if len(value) != env.end_iteration - env.begin_iteration:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment