Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
ac57d5a4
Unverified
Commit
ac57d5a4
authored
Jun 20, 2023
by
José Morales
Committed by
GitHub
Jun 20, 2023
Browse files
[dask] hold ports until training (#5890)
parent
07e3cf47
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
111 deletions
+57
-111
python-package/lightgbm/compat.py
python-package/lightgbm/compat.py
+7
-1
python-package/lightgbm/dask.py
python-package/lightgbm/dask.py
+45
-70
tests/python_package_test/test_dask.py
tests/python_package_test/test_dask.py
+5
-40
No files found.
python-package/lightgbm/compat.py
View file @
ac57d5a4
...
@@ -144,7 +144,7 @@ try:
...
@@ -144,7 +144,7 @@ try:
from
dask.bag
import
from_delayed
as
dask_bag_from_delayed
from
dask.bag
import
from_delayed
as
dask_bag_from_delayed
from
dask.dataframe
import
DataFrame
as
dask_DataFrame
from
dask.dataframe
import
DataFrame
as
dask_DataFrame
from
dask.dataframe
import
Series
as
dask_Series
from
dask.dataframe
import
Series
as
dask_Series
from
dask.distributed
import
Client
,
default_client
,
wait
from
dask.distributed
import
Client
,
Future
,
default_client
,
wait
DASK_INSTALLED
=
True
DASK_INSTALLED
=
True
except
ImportError
:
except
ImportError
:
DASK_INSTALLED
=
False
DASK_INSTALLED
=
False
...
@@ -161,6 +161,12 @@ except ImportError:
...
@@ -161,6 +161,12 @@ except ImportError:
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
pass
pass
class
Future
:
# type: ignore
"""Dummy class for dask.distributed.Future."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
pass
class
dask_Array
:
# type: ignore
class
dask_Array
:
# type: ignore
"""Dummy class for dask.array.Array."""
"""Dummy class for dask.array.Array."""
...
...
python-package/lightgbm/dask.py
View file @
ac57d5a4
...
@@ -6,6 +6,7 @@ dask.Array and dask.DataFrame collections.
...
@@ -6,6 +6,7 @@ dask.Array and dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost.
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
"""
import
operator
import
socket
import
socket
from
collections
import
defaultdict
from
collections
import
defaultdict
from
copy
import
deepcopy
from
copy
import
deepcopy
...
@@ -18,7 +19,7 @@ import numpy as np
...
@@ -18,7 +19,7 @@ import numpy as np
import
scipy.sparse
as
ss
import
scipy.sparse
as
ss
from
.basic
import
LightGBMError
,
_choose_param_value
,
_ConfigAliases
,
_log_info
,
_log_warning
from
.basic
import
LightGBMError
,
_choose_param_value
,
_ConfigAliases
,
_log_info
,
_log_warning
from
.compat
import
(
DASK_INSTALLED
,
PANDAS_INSTALLED
,
SKLEARN_INSTALLED
,
Client
,
LGBMNotFittedError
,
concat
,
from
.compat
import
(
DASK_INSTALLED
,
PANDAS_INSTALLED
,
SKLEARN_INSTALLED
,
Client
,
Future
,
LGBMNotFittedError
,
concat
,
dask_Array
,
dask_array_from_delayed
,
dask_bag_from_delayed
,
dask_DataFrame
,
dask_Series
,
dask_Array
,
dask_array_from_delayed
,
dask_bag_from_delayed
,
dask_DataFrame
,
dask_Series
,
default_client
,
delayed
,
pd_DataFrame
,
pd_Series
,
wait
)
default_client
,
delayed
,
pd_DataFrame
,
pd_Series
,
wait
)
from
.sklearn
import
(
LGBMClassifier
,
LGBMModel
,
LGBMRanker
,
LGBMRegressor
,
_LGBM_ScikitCustomObjectiveFunction
,
from
.sklearn
import
(
LGBMClassifier
,
LGBMModel
,
LGBMRanker
,
LGBMRegressor
,
_LGBM_ScikitCustomObjectiveFunction
,
...
@@ -38,18 +39,21 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
...
@@ -38,18 +39,21 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype
=
Union
[
Type
[
np
.
float32
],
Type
[
np
.
float64
],
Type
[
np
.
int32
],
Type
[
np
.
int64
]]
_PredictionDtype
=
Union
[
Type
[
np
.
float32
],
Type
[
np
.
float64
],
Type
[
np
.
int32
],
Type
[
np
.
int64
]]
class
_HostWorkers
:
class
_RemoteSocket
:
def
acquire
(
self
)
->
int
:
self
.
socket
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
socket
.
setsockopt
(
socket
.
SOL_SOCKET
,
socket
.
SO_REUSEADDR
,
1
)
self
.
socket
.
bind
((
''
,
0
))
return
self
.
socket
.
getsockname
()[
1
]
def
__init__
(
self
,
default
:
str
,
all_workers
:
List
[
str
]):
def
release
(
self
)
->
None
:
self
.
default
=
default
self
.
socket
.
close
()
self
.
all_workers
=
all_workers
def
__eq__
(
self
,
other
:
object
)
->
bool
:
return
(
def
_acquire_port
()
->
Tuple
[
_RemoteSocket
,
int
]:
isinstance
(
other
,
type
(
self
))
s
=
_RemoteSocket
()
and
self
.
default
==
other
.
default
port
=
s
.
acquire
()
and
self
.
all_workers
==
other
.
all_workers
return
s
,
port
)
class
_DatasetNames
(
Enum
):
class
_DatasetNames
(
Enum
):
...
@@ -83,73 +87,40 @@ def _get_dask_client(client: Optional[Client]) -> Client:
...
@@ -83,73 +87,40 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return
client
return
client
def
_find_n_open_ports
(
n
:
int
)
->
List
[
int
]:
"""Find n random open ports on localhost.
Returns
-------
ports : list of int
n random open ports on localhost.
"""
sockets
=
[]
for
_
in
range
(
n
):
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
s
.
bind
((
''
,
0
))
sockets
.
append
(
s
)
ports
=
[]
for
s
in
sockets
:
ports
.
append
(
s
.
getsockname
()[
1
])
s
.
close
()
return
ports
def
_group_workers_by_host
(
worker_addresses
:
Iterable
[
str
])
->
Dict
[
str
,
_HostWorkers
]:
"""Group all worker addresses by hostname.
Returns
-------
host_to_workers : dict
mapping from hostname to all its workers.
"""
host_to_workers
:
Dict
[
str
,
_HostWorkers
]
=
{}
for
address
in
worker_addresses
:
hostname
=
urlparse
(
address
).
hostname
if
not
hostname
:
raise
ValueError
(
f
"Could not parse host name from worker address '
{
address
}
'"
)
if
hostname
not
in
host_to_workers
:
host_to_workers
[
hostname
]
=
_HostWorkers
(
default
=
address
,
all_workers
=
[
address
])
else
:
host_to_workers
[
hostname
].
all_workers
.
append
(
address
)
return
host_to_workers
def
_assign_open_ports_to_workers
(
def
_assign_open_ports_to_workers
(
client
:
Client
,
client
:
Client
,
host_to_
workers
:
Dic
t
[
str
,
_HostWorkers
]
workers
:
Lis
t
[
str
]
,
)
->
Dict
[
str
,
int
]:
)
->
Tuple
[
Dict
[
str
,
Future
],
Dict
[
str
,
int
]
]
:
"""Assign an open port to each worker.
"""Assign an open port to each worker.
Returns
Returns
-------
-------
worker_to_socket_future: dict
mapping from worker address to a future pointing to the remote socket.
worker_to_port: dict
worker_to_port: dict
mapping from worker address to an open port.
mapping from worker address to an open port
in the worker's host
.
"""
"""
host_ports_futures
=
{}
# Acquire port in worker
for
hostname
,
workers
in
host_to_workers
.
items
():
worker_to_future
=
{}
n_workers_in_host
=
len
(
workers
.
all_workers
)
for
worker
in
workers
:
host_ports_futures
[
hostname
]
=
client
.
submit
(
worker_to_future
[
worker
]
=
client
.
submit
(
_find_n_open_ports
,
_acquire_port
,
n
=
n_workers_in_host
,
workers
=
[
worker
],
workers
=
[
workers
.
default
],
pure
=
False
,
allow_other_workers
=
False
,
allow_other_workers
=
False
,
pure
=
False
,
)
)
found_ports
=
client
.
gather
(
host_ports_futures
)
worker_to_port
=
{}
# schedule futures to retrieve each element of the tuple
for
hostname
,
workers
in
host_to_workers
.
items
():
worker_to_socket_future
=
{}
for
worker
,
port
in
zip
(
workers
.
all_workers
,
found_ports
[
hostname
]):
worker_to_port_future
=
{}
worker_to_port
[
worker
]
=
port
for
worker
,
socket_future
in
worker_to_future
.
items
():
return
worker_to_port
worker_to_socket_future
[
worker
]
=
client
.
submit
(
operator
.
itemgetter
(
0
),
socket_future
)
worker_to_port_future
[
worker
]
=
client
.
submit
(
operator
.
itemgetter
(
1
),
socket_future
)
# retrieve ports
worker_to_port
=
client
.
gather
(
worker_to_port_future
)
return
worker_to_socket_future
,
worker_to_port
def
_concat
(
seq
:
List
[
_DaskPart
])
->
_DaskPart
:
def
_concat
(
seq
:
List
[
_DaskPart
])
->
_DaskPart
:
...
@@ -190,6 +161,7 @@ def _train_part(
...
@@ -190,6 +161,7 @@ def _train_part(
num_machines
:
int
,
num_machines
:
int
,
return_model
:
bool
,
return_model
:
bool
,
time_out
:
int
,
time_out
:
int
,
remote_socket
:
_RemoteSocket
,
**
kwargs
:
Any
**
kwargs
:
Any
)
->
Optional
[
LGBMModel
]:
)
->
Optional
[
LGBMModel
]:
network_params
=
{
network_params
=
{
...
@@ -320,6 +292,8 @@ def _train_part(
...
@@ -320,6 +292,8 @@ def _train_part(
kwargs
[
'eval_class_weight'
]
=
[
eval_class_weight
[
i
]
for
i
in
eval_component_idx
]
kwargs
[
'eval_class_weight'
]
=
[
eval_class_weight
[
i
]
for
i
in
eval_component_idx
]
model
=
model_factory
(
**
params
)
model
=
model_factory
(
**
params
)
if
remote_socket
is
not
None
:
remote_socket
.
release
()
try
:
try
:
if
is_ranker
:
if
is_ranker
:
model
.
fit
(
model
.
fit
(
...
@@ -777,6 +751,7 @@ def _train(
...
@@ -777,6 +751,7 @@ def _train(
machines
=
params
.
pop
(
"machines"
)
machines
=
params
.
pop
(
"machines"
)
# figure out network params
# figure out network params
worker_to_socket_future
:
Dict
[
str
,
Future
]
=
{}
worker_addresses
=
worker_map
.
keys
()
worker_addresses
=
worker_map
.
keys
()
if
machines
is
not
None
:
if
machines
is
not
None
:
_log_info
(
"Using passed-in 'machines' parameter"
)
_log_info
(
"Using passed-in 'machines' parameter"
)
...
@@ -802,8 +777,7 @@ def _train(
...
@@ -802,8 +777,7 @@ def _train(
}
}
else
:
else
:
_log_info
(
"Finding random open ports for workers"
)
_log_info
(
"Finding random open ports for workers"
)
host_to_workers
=
_group_workers_by_host
(
worker_map
.
keys
())
worker_to_socket_future
,
worker_address_to_port
=
_assign_open_ports_to_workers
(
client
,
list
(
worker_map
.
keys
()))
worker_address_to_port
=
_assign_open_ports_to_workers
(
client
,
host_to_workers
)
machines
=
','
.
join
([
machines
=
','
.
join
([
f
'
{
urlparse
(
worker_address
).
hostname
}
:
{
port
}
'
f
'
{
urlparse
(
worker_address
).
hostname
}
:
{
port
}
'
...
@@ -831,6 +805,7 @@ def _train(
...
@@ -831,6 +805,7 @@ def _train(
local_listen_port
=
worker_address_to_port
[
worker
],
local_listen_port
=
worker_address_to_port
[
worker
],
num_machines
=
num_machines
,
num_machines
=
num_machines
,
time_out
=
params
.
get
(
'time_out'
,
120
),
time_out
=
params
.
get
(
'time_out'
,
120
),
remote_socket
=
worker_to_socket_future
.
get
(
worker
,
None
),
return_model
=
(
worker
==
master_worker
),
return_model
=
(
worker
==
master_worker
),
workers
=
[
worker
],
workers
=
[
worker
],
allow_other_workers
=
False
,
allow_other_workers
=
False
,
...
...
tests/python_package_test/test_dask.py
View file @
ac57d5a4
...
@@ -519,26 +519,6 @@ def test_classifier_custom_objective(output, task, cluster):
...
@@ -519,26 +519,6 @@ def test_classifier_custom_objective(output, task, cluster):
assert_eq
(
p1_proba
,
p1_proba_local
)
assert_eq
(
p1_proba
,
p1_proba_local
)
def
test_group_workers_by_host
():
hosts
=
[
f
'0.0.0.
{
i
}
'
for
i
in
range
(
2
)]
workers
=
[
f
'tcp://
{
host
}
:
{
p
}
'
for
p
in
range
(
2
)
for
host
in
hosts
]
expected
=
{
host
:
lgb
.
dask
.
_HostWorkers
(
default
=
f
'tcp://
{
host
}
:0'
,
all_workers
=
[
f
'tcp://
{
host
}
:0'
,
f
'tcp://
{
host
}
:1'
]
)
for
host
in
hosts
}
host_to_workers
=
lgb
.
dask
.
_group_workers_by_host
(
workers
)
assert
host_to_workers
==
expected
def
test_group_workers_by_host_unparseable_host_names
():
workers_without_protocol
=
[
'0.0.0.1:80'
,
'0.0.0.2:80'
]
with
pytest
.
raises
(
ValueError
,
match
=
"Could not parse host name from worker address '0.0.0.1:80'"
):
lgb
.
dask
.
_group_workers_by_host
(
workers_without_protocol
)
def
test_machines_to_worker_map_unparseable_host_names
():
def
test_machines_to_worker_map_unparseable_host_names
():
workers
=
{
'0.0.0.1:80'
:
{},
'0.0.0.2:80'
:
{}}
workers
=
{
'0.0.0.1:80'
:
{},
'0.0.0.2:80'
:
{}}
machines
=
"0.0.0.1:80,0.0.0.2:80"
machines
=
"0.0.0.1:80,0.0.0.2:80"
...
@@ -546,23 +526,6 @@ def test_machines_to_worker_map_unparseable_host_names():
...
@@ -546,23 +526,6 @@ def test_machines_to_worker_map_unparseable_host_names():
lgb
.
dask
.
_machines_to_worker_map
(
machines
=
machines
,
worker_addresses
=
workers
.
keys
())
lgb
.
dask
.
_machines_to_worker_map
(
machines
=
machines
,
worker_addresses
=
workers
.
keys
())
def
test_assign_open_ports_to_workers
(
cluster
):
with
Client
(
cluster
)
as
client
:
workers
=
client
.
scheduler_info
()[
'workers'
].
keys
()
n_workers
=
len
(
workers
)
host_to_workers
=
lgb
.
dask
.
_group_workers_by_host
(
workers
)
for
_
in
range
(
25
):
worker_address_to_port
=
lgb
.
dask
.
_assign_open_ports_to_workers
(
client
,
host_to_workers
)
found_ports
=
worker_address_to_port
.
values
()
assert
len
(
found_ports
)
==
n_workers
# check that found ports are different for same address (LocalCluster)
assert
len
(
set
(
found_ports
))
==
len
(
found_ports
)
# check that the ports are indeed open
for
port
in
found_ports
:
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
s
.
bind
((
''
,
port
))
def
test_training_does_not_fail_on_port_conflicts
(
cluster
):
def
test_training_does_not_fail_on_port_conflicts
(
cluster
):
with
Client
(
cluster
)
as
client
:
with
Client
(
cluster
)
as
client
:
_
,
_
,
_
,
_
,
dX
,
dy
,
dw
,
_
=
_create_data
(
'binary-classification'
,
output
=
'array'
)
_
,
_
,
_
,
_
,
dX
,
dy
,
dw
,
_
=
_create_data
(
'binary-classification'
,
output
=
'array'
)
...
@@ -1588,15 +1551,17 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
...
@@ -1588,15 +1551,17 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert
'machines'
not
in
params
assert
'machines'
not
in
params
# model 2 - machines given
# model 2 - machines given
workers
=
list
(
client
.
scheduler_info
()[
'workers'
])
workers_hostname
=
_get_workers_hostname
(
cluster
)
workers_hostname
=
_get_workers_hostname
(
cluster
)
n_workers
=
len
(
client
.
scheduler_info
()[
'workers'
])
remote_sockets
,
open_ports
=
lgb
.
dask
.
_assign_open_ports_to_workers
(
client
,
workers
)
open_ports
=
lgb
.
dask
.
_find_n_open_ports
(
n_workers
)
for
s
in
remote_sockets
.
values
():
s
.
release
()
dask_model2
=
dask_model_factory
(
dask_model2
=
dask_model_factory
(
n_estimators
=
5
,
n_estimators
=
5
,
num_leaves
=
5
,
num_leaves
=
5
,
machines
=
","
.
join
([
machines
=
","
.
join
([
f
"
{
workers_hostname
}
:
{
port
}
"
f
"
{
workers_hostname
}
:
{
port
}
"
for
port
in
open_ports
for
port
in
open_ports
.
values
()
]),
]),
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment