Commit 95519f36 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python] add network config api (#1019)

* add network

* update doc
parent 36f4c13e
...@@ -1245,6 +1245,7 @@ class Booster(object): ...@@ -1245,6 +1245,7 @@ class Booster(object):
Whether to print messages during construction. Whether to print messages during construction.
""" """
self.handle = None self.handle = None
self.network = False
self.__need_reload_eval_info = True self.__need_reload_eval_info = True
self.__train_data_name = "training" self.__train_data_name = "training"
self.__attr = {} self.__attr = {}
...@@ -1288,6 +1289,20 @@ class Booster(object): ...@@ -1288,6 +1289,20 @@ class Booster(object):
self.__is_predicted_cur_iter = [False] self.__is_predicted_cur_iter = [False]
self.__get_eval_info() self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical self.pandas_categorical = train_set.pandas_categorical
"""set network if necessary"""
if "machines" in params:
machines = params["machines"]
if isinstance(machines, string_type):
num_machines = len(machines.split(','))
elif isinstance(machines, (list, set)):
num_machines = len(machines)
machines = ','.join(machines)
else:
raise ValueError("Invalid machines in params.")
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
elif model_file is not None: elif model_file is not None:
"""Prediction task""" """Prediction task"""
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
...@@ -1308,6 +1323,8 @@ class Booster(object): ...@@ -1308,6 +1323,8 @@ class Booster(object):
raise TypeError('Need at least one training dataset or model file to create booster instance') raise TypeError('Need at least one training dataset or model file to create booster instance')
def __del__(self): def __del__(self):
if self.network:
self.free_network()
if self.handle is not None: if self.handle is not None:
_safe_call(_LIB.LGBM_BoosterFree(self.handle)) _safe_call(_LIB.LGBM_BoosterFree(self.handle))
...@@ -1351,6 +1368,32 @@ class Booster(object): ...@@ -1351,6 +1368,32 @@ class Booster(object):
self.__inner_predict_buffer = [] self.__inner_predict_buffer = []
self.__is_predicted_cur_iter = [] self.__is_predicted_cur_iter = []
def set_network(self, machines, local_listen_port=12400,
listen_time_out=120, num_machines=1):
"""Set the network configuration.
Parameters
----------
machines: list, set or string
Names of machines.
local_listen_port: int, optional (default=12400)
TCP listen port for local machines.
listen_time_out: int, optional (default=120)
Socket time-out in minutes.
num_machines: int, optional (default=1)
The number of machines for parallel learning application.
"""
_safe_call(_LIB.LGBM_NetworkInit(c_str(machines),
ctypes.c_int(local_listen_port),
ctypes.c_int(listen_time_out),
ctypes.c_int(num_machines)))
self.network = True
def free_network(self):
"""Free Network."""
_safe_call(_LIB.LGBM_NetworkFree())
self.network = False
def set_train_data_name(self, name): def set_train_data_name(self, name):
"""Set the name to the training Dataset. """Set the name to the training Dataset.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment