Commit 5d744197 authored by SfinxCZ's avatar SfinxCZ Committed by Guolin Ke
Browse files

Fixed incorrect order in initialization of booster for distributed training. (#1741)

parent 17165f93
......@@ -1481,6 +1481,22 @@ class Booster(object):
raise TypeError('Training data should be Dataset instance, met {}'
.format(type(train_set).__name__))
params_str = param_dict_to_str(params)
# set network if necessary
for alias in ["machines", "workers", "nodes"]:
if alias in params:
machines = params[alias]
if isinstance(machines, string_type):
num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
break
# construct booster object
self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_BoosterCreate(
......@@ -1507,22 +1523,6 @@ class Booster(object):
self.__is_predicted_cur_iter = [False]
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
# set network if necessary
for alias in ["machines", "workers", "nodes"]:
if alias in params:
machines = params[alias]
if isinstance(machines, string_type):
num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
break
elif model_file is not None:
# Prediction task
out_num_iterations = ctypes.c_int(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment