Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5d744197
Commit
5d744197
authored
Oct 11, 2018
by
SfinxCZ
Committed by
Guolin Ke
Oct 11, 2018
Browse files
Fixed incorrect order in initialization of booster for distributed training. (#1741)
parent
17165f93
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
16 deletions
+16
-16
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+16
-16
No files found.
python-package/lightgbm/basic.py
View file @
5d744197
...
@@ -1481,6 +1481,22 @@ class Booster(object):
...
@@ -1481,6 +1481,22 @@ class Booster(object):
raise
TypeError
(
'Training data should be Dataset instance, met {}'
raise
TypeError
(
'Training data should be Dataset instance, met {}'
.
format
(
type
(
train_set
).
__name__
))
.
format
(
type
(
train_set
).
__name__
))
params_str
=
param_dict_to_str
(
params
)
params_str
=
param_dict_to_str
(
params
)
# set network if necessary
for
alias
in
[
"machines"
,
"workers"
,
"nodes"
]:
if
alias
in
params
:
machines
=
params
[
alias
]
if
isinstance
(
machines
,
string_type
):
num_machines
=
len
(
machines
.
split
(
','
))
elif
isinstance
(
machines
,
(
list
,
set
)):
num_machines
=
len
(
machines
)
machines
=
','
.
join
(
machines
)
else
:
raise
ValueError
(
"Invalid machines in params."
)
self
.
set_network
(
machines
,
local_listen_port
=
params
.
get
(
"local_listen_port"
,
12400
),
listen_time_out
=
params
.
get
(
"listen_time_out"
,
120
),
num_machines
=
params
.
get
(
"num_machines"
,
num_machines
))
break
# construct booster object
# construct booster object
self
.
handle
=
ctypes
.
c_void_p
()
self
.
handle
=
ctypes
.
c_void_p
()
_safe_call
(
_LIB
.
LGBM_BoosterCreate
(
_safe_call
(
_LIB
.
LGBM_BoosterCreate
(
...
@@ -1507,22 +1523,6 @@ class Booster(object):
...
@@ -1507,22 +1523,6 @@ class Booster(object):
self
.
__is_predicted_cur_iter
=
[
False
]
self
.
__is_predicted_cur_iter
=
[
False
]
self
.
__get_eval_info
()
self
.
__get_eval_info
()
self
.
pandas_categorical
=
train_set
.
pandas_categorical
self
.
pandas_categorical
=
train_set
.
pandas_categorical
# set network if necessary
for
alias
in
[
"machines"
,
"workers"
,
"nodes"
]:
if
alias
in
params
:
machines
=
params
[
alias
]
if
isinstance
(
machines
,
string_type
):
num_machines
=
len
(
machines
.
split
(
','
))
elif
isinstance
(
machines
,
(
list
,
set
)):
num_machines
=
len
(
machines
)
machines
=
','
.
join
(
machines
)
else
:
raise
ValueError
(
"Invalid machines in params."
)
self
.
set_network
(
machines
,
local_listen_port
=
params
.
get
(
"local_listen_port"
,
12400
),
listen_time_out
=
params
.
get
(
"listen_time_out"
,
120
),
num_machines
=
params
.
get
(
"num_machines"
,
num_machines
))
break
elif
model_file
is
not
None
:
elif
model_file
is
not
None
:
# Prediction task
# Prediction task
out_num_iterations
=
ctypes
.
c_int
(
0
)
out_num_iterations
=
ctypes
.
c_int
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment