Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
95519f36
Commit
95519f36
authored
Oct 26, 2017
by
wxchan
Committed by
Guolin Ke
Oct 26, 2017
Browse files
[python] add network config api (#1019)
* add network * update doc
parent
36f4c13e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
0 deletions
+43
-0
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+43
-0
No files found.
python-package/lightgbm/basic.py
View file @
95519f36
...
...
@@ -1245,6 +1245,7 @@ class Booster(object):
Whether to print messages during construction.
"""
self
.
handle
=
None
self
.
network
=
False
self
.
__need_reload_eval_info
=
True
self
.
__train_data_name
=
"training"
self
.
__attr
=
{}
...
...
@@ -1288,6 +1289,20 @@ class Booster(object):
self
.
__is_predicted_cur_iter
=
[
False
]
self
.
__get_eval_info
()
self
.
pandas_categorical
=
train_set
.
pandas_categorical
"""set network if necessary"""
if
"machines"
in
params
:
machines
=
params
[
"machines"
]
if
isinstance
(
machines
,
string_type
):
num_machines
=
len
(
machines
.
split
(
','
))
elif
isinstance
(
machines
,
(
list
,
set
)):
num_machines
=
len
(
machines
)
machines
=
','
.
join
(
machines
)
else
:
raise
ValueError
(
"Invalid machines in params."
)
self
.
set_network
(
machines
,
local_listen_port
=
params
.
get
(
"local_listen_port"
,
12400
),
listen_time_out
=
params
.
get
(
"listen_time_out"
,
120
),
num_machines
=
params
.
get
(
"num_machines"
,
num_machines
))
elif
model_file
is
not
None
:
"""Prediction task"""
out_num_iterations
=
ctypes
.
c_int
(
0
)
...
...
@@ -1308,6 +1323,8 @@ class Booster(object):
raise
TypeError
(
'Need at least one training dataset or model file to create booster instance'
)
def
__del__
(
self
):
if
self
.
network
:
self
.
free_network
()
if
self
.
handle
is
not
None
:
_safe_call
(
_LIB
.
LGBM_BoosterFree
(
self
.
handle
))
...
...
@@ -1351,6 +1368,32 @@ class Booster(object):
self
.
__inner_predict_buffer
=
[]
self
.
__is_predicted_cur_iter
=
[]
def
set_network
(
self
,
machines
,
local_listen_port
=
12400
,
listen_time_out
=
120
,
num_machines
=
1
):
"""Set the network configuration.
Parameters
----------
machines: list, set or string
Names of machines.
local_listen_port: int, optional (default=12400)
TCP listen port for local machines.
listen_time_out: int, optional (default=120)
Socket time-out in minutes.
num_machines: int, optional (default=1)
The number of machines for parallel learning application.
"""
_safe_call
(
_LIB
.
LGBM_NetworkInit
(
c_str
(
machines
),
ctypes
.
c_int
(
local_listen_port
),
ctypes
.
c_int
(
listen_time_out
),
ctypes
.
c_int
(
num_machines
)))
self
.
network
=
True
def
free_network
(
self
):
"""Free Network."""
_safe_call
(
_LIB
.
LGBM_NetworkFree
())
self
.
network
=
False
def
set_train_data_name
(
self
,
name
):
"""Set the name to the training Dataset.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment