OpenDAS / dgl / Commits / 8bc91414

Unverified commit 8bc91414, authored Aug 01, 2021 by Jinjing Zhou, committed by GitHub on Aug 01, 2021.
[Distributed] Fix distributed training hang with multiple samplers (#3169)
Rewrite the multiprocessing worker pool
parent 2ec0493d

Changes: 2 changed files, with 170 additions and 89 deletions.

  python/dgl/distributed/dist_context.py     (+151, -17)
  python/dgl/distributed/dist_dataloader.py  (+19, -72)
python/dgl/distributed/dist_context.py (view file @ 8bc91414)
"""Initialize the distributed services"""
# pylint: disable=line-too-long
import
multiprocessing
as
mp
import
traceback
...
...
@@ -6,6 +7,8 @@ import atexit
import
time
import
os
import
sys
import
queue
from
enum
import
Enum
from
.
import
rpc
from
.constants
import
MAX_QUEUE_SIZE
...
...
@@ -18,15 +21,18 @@ SAMPLER_POOL = None
NUM_SAMPLER_WORKERS
=
0
INITIALIZED
=
False
def
set_initialized
(
value
=
True
):
"""Set the initialized state of rpc"""
global
INITIALIZED
INITIALIZED
=
value
def
get_sampler_pool
():
"""Return the sampler pool and num_workers"""
return
SAMPLER_POOL
,
NUM_SAMPLER_WORKERS
def
_init_rpc
(
ip_config
,
num_servers
,
max_queue_size
,
net_type
,
role
,
num_threads
):
''' This init function is called in the worker processes.
'''
...
...
@@ -41,6 +47,131 @@ def _init_rpc(ip_config, num_servers, max_queue_size, net_type, role, num_thread
traceback
.
print_exc
()
raise
e
class
MpCommand
(
Enum
):
"""Enum class for multiprocessing command"""
INIT_RPC
=
0
# Not used in the task queue
SET_COLLATE_FN
=
1
CALL_BARRIER
=
2
DELETE_COLLATE_FN
=
3
CALL_COLLATE_FN
=
4
CALL_FN_ALL_WORKERS
=
5
FINALIZE_POOL
=
6
def
init_process
(
rpc_config
,
mp_contexts
):
"""Work loop in the worker"""
try
:
_init_rpc
(
*
rpc_config
)
keep_polling
=
True
data_queue
,
task_queue
,
barrier
=
mp_contexts
collate_fn_dict
=
{}
while
keep_polling
:
try
:
# Follow https://github.com/pytorch/pytorch/blob/d57ce8cf8989c0b737e636d8d7abe16c1f08f70b/torch/utils/data/_utils/worker.py#L260
command
,
args
=
task_queue
.
get
(
timeout
=
5
)
except
queue
.
Empty
:
continue
if
command
==
MpCommand
.
SET_COLLATE_FN
:
dataloader_name
,
func
=
args
collate_fn_dict
[
dataloader_name
]
=
func
elif
command
==
MpCommand
.
CALL_BARRIER
:
barrier
.
wait
()
elif
command
==
MpCommand
.
DELETE_COLLATE_FN
:
dataloader_name
,
=
args
del
collate_fn_dict
[
dataloader_name
]
elif
command
==
MpCommand
.
CALL_COLLATE_FN
:
dataloader_name
,
collate_args
=
args
data_queue
.
put
(
(
dataloader_name
,
collate_fn_dict
[
dataloader_name
](
collate_args
)))
elif
command
==
MpCommand
.
CALL_FN_ALL_WORKERS
:
func
,
func_args
=
args
func
(
func_args
)
elif
command
==
MpCommand
.
FINALIZE_POOL
:
_exit
()
keep_polling
=
False
else
:
raise
Exception
(
"Unknown command"
)
except
Exception
as
e
:
traceback
.
print_exc
()
raise
e
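The worker loop above is the core of the fix: instead of relying on `mp.Pool`'s shared task dispatch, each worker owns a dedicated task queue, polls it with a timeout so shutdown commands are never missed, dispatches on an `MpCommand`, and pushes collate results to a shared data queue keyed by dataloader name. Here is a minimal, self-contained sketch of the same command-loop pattern, decoupled from DGL's RPC setup; `Cmd`, `worker_loop`, and the doubling task are illustrative stand-ins:

    import multiprocessing as mp
    import queue
    from enum import Enum

    class Cmd(Enum):
        CALL = 0
        STOP = 1

    def worker_loop(task_queue, result_queue):
        # Poll with a timeout so the worker can re-check for shutdown
        # instead of blocking forever on an empty queue.
        while True:
            try:
                command, args = task_queue.get(timeout=5)
            except queue.Empty:
                continue
            if command == Cmd.CALL:
                result_queue.put(args * 2)  # stand-in for the collate call
            elif command == Cmd.STOP:
                break

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        tq, rq = ctx.Queue(), ctx.Queue()
        proc = ctx.Process(target=worker_loop, args=(tq, rq))
        proc.start()
        tq.put((Cmd.CALL, 21))
        print(rq.get(timeout=10))  # prints 42
        tq.put((Cmd.STOP, None))
        proc.join()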
class CustomPool:
    """Customized worker pool"""
    def __init__(self, num_workers, rpc_config):
        """
        Customized worker pool init function
        """
        ctx = mp.get_context("spawn")
        self.num_workers = num_workers
        self.queue_size = num_workers * 4
        self.result_queue = ctx.Queue(self.queue_size)
        self.task_queues = []
        self.process_list = []
        self.current_proc_id = 0
        self.cache_result_dict = {}
        self.barrier = ctx.Barrier(num_workers)
        for _ in range(num_workers):
            task_queue = ctx.Queue(self.queue_size)
            self.task_queues.append(task_queue)
            proc = ctx.Process(target=init_process,
                               args=(rpc_config,
                                     (self.result_queue, task_queue, self.barrier)))
            proc.daemon = True
            proc.start()
            self.process_list.append(proc)

    def set_collate_fn(self, func, dataloader_name):
        """Set collate function in subprocess"""
        for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.SET_COLLATE_FN, (dataloader_name, func)))

    def submit_task(self, dataloader_name, args):
        """Submit task to workers"""
        # Round robin
        self.task_queues[self.current_proc_id].put(
            (MpCommand.CALL_COLLATE_FN, (dataloader_name, args)))
        self.current_proc_id = (self.current_proc_id + 1) % self.num_workers

    def submit_task_to_all_workers(self, func, args):
        """Submit task to all workers"""
        for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.CALL_FN_ALL_WORKERS, (func, args)))

    def get_result(self, dataloader_name, timeout=1800):
        """Get result from result queue"""
        result_dataloader_name, result = self.result_queue.get(timeout=timeout)
        assert result_dataloader_name == dataloader_name
        return result

    def delete_collate_fn(self, dataloader_name):
        """Delete collate function"""
        for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.DELETE_COLLATE_FN, (dataloader_name,)))

    def call_barrier(self):
        """Call barrier at all workers"""
        for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.CALL_BARRIER, tuple()))

    def close(self):
        """Close worker pool"""
        for i in range(self.num_workers):
            self.task_queues[i].put((MpCommand.FINALIZE_POOL, tuple()), block=False)
        time.sleep(0.5)  # Fix for early python version

    def join(self):
        """Join the close process of worker pool"""
        for i in range(self.num_workers):
            self.process_list[i].join()
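`CustomPool` fans `CALL_COLLATE_FN` tasks out round-robin across per-worker task queues, broadcasts `SET_COLLATE_FN`/`DELETE_COLLATE_FN`/`CALL_BARRIER` to every worker, and funnels all results back through the single shared `result_queue`. A hypothetical usage sketch; it assumes DGL servers are already running and `ip_config.txt` points at them (which is what `initialize` below arranges), and `my_collate` and `batch_indices` are placeholders for user code:

    # rpc_config mirrors the tuple that initialize() passes below.
    rpc_config = ('ip_config.txt', 1, MAX_QUEUE_SIZE, 'socket', 'sampler', 1)
    pool = CustomPool(4, rpc_config)             # spawns 4 sampler processes
    pool.set_collate_fn(my_collate, 'loader-0')  # broadcast to all workers
    pool.submit_task('loader-0', batch_indices)  # round-robin to one worker
    result = pool.get_result('loader-0')         # blocks on the shared result queue
    pool.delete_collate_fn('loader-0')
    pool.close()                                 # sends FINALIZE_POOL to every worker
    pool.join()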
def initialize(ip_config, num_servers=1, num_workers=0,
               max_queue_size=MAX_QUEUE_SIZE, net_type='socket',
               num_worker_threads=1):
...
@@ -84,15 +215,15 @@ def initialize(ip_config, num_servers=1, num_workers=0,
    if os.environ.get('DGL_ROLE', 'client') == 'server':
        from .dist_graph import DistGraphServer
        assert os.environ.get('DGL_SERVER_ID') is not None, \
            'Please define DGL_SERVER_ID to run DistGraph server'
        assert os.environ.get('DGL_IP_CONFIG') is not None, \
            'Please define DGL_IP_CONFIG to run DistGraph server'
        assert os.environ.get('DGL_NUM_SERVER') is not None, \
            'Please define DGL_NUM_SERVER to run DistGraph server'
        assert os.environ.get('DGL_NUM_CLIENT') is not None, \
            'Please define DGL_NUM_CLIENT to run DistGraph server'
        assert os.environ.get('DGL_CONF_PATH') is not None, \
            'Please define DGL_CONF_PATH to run DistGraph server'
        formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
        formats = [f.strip() for f in formats]
        serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
...
@@ -114,46 +245,47 @@
        num_servers = 1
    rpc.reset()
    ctx = mp.get_context("spawn")
    global SAMPLER_POOL
    global NUM_SAMPLER_WORKERS
    is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone'
    if num_workers > 0 and not is_standalone:
-        SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
-                                initargs=(ip_config, num_servers, max_queue_size,
-                                          net_type, 'sampler', num_worker_threads))
+        SAMPLER_POOL = CustomPool(num_workers, (ip_config, num_servers, max_queue_size,
+                                                net_type, 'sampler', num_worker_threads))
    else:
        SAMPLER_POOL = None
    NUM_SAMPLER_WORKERS = num_workers
    if not is_standalone:
        assert num_servers is not None and num_servers > 0, \
            'The number of servers per machine must be specified with a positive number.'
        connect_to_server(ip_config, num_servers, max_queue_size, net_type)
    init_role('default')
    init_kvstore(ip_config, num_servers, 'default')
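For reference, a typical trainer-side call of the public entry point above, which now constructs a `CustomPool` instead of an `mp.Pool` (a sketch; the ip_config path is illustrative):

    import dgl

    # Spawn 4 sampler workers backed by the CustomPool above.
    dgl.distributed.initialize('ip_config.txt', num_servers=1, num_workers=4)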
def finalize_client():
    """Release resources of this client."""
    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
        rpc.finalize_sender()
        rpc.finalize_receiver()
    global INITIALIZED
    INITIALIZED = False

def _exit():
    exit_client()
    time.sleep(1)

def finalize_worker():
    """Finalize workers
    Python's multiprocessing pool will not call atexit functions when closed
    """
    global SAMPLER_POOL
    if SAMPLER_POOL is not None:
-        for _ in range(NUM_SAMPLER_WORKERS):
-            SAMPLER_POOL.apply_async(_exit)
-            time.sleep(0.1)  # This is necessary but I don't know why
        SAMPLER_POOL.close()

def join_finalize_worker():
    """Join the worker close process"""
    global SAMPLER_POOL
...
@@ -161,11 +293,13 @@ def join_finalize_worker():
        SAMPLER_POOL.join()
        SAMPLER_POOL = None

def is_initialized():
    """Is RPC initialized?
    """
    return INITIALIZED

def exit_client():
    """Trainer exits
...
@@ -177,8 +311,8 @@ def exit_client():
    needs to call `exit_client` before calling `initialize` again.
    """
    # Only client with rank_0 will send shutdown request to servers.
    finalize_worker()  # finalizing workers should happen earlier than the barrier, and is non-blocking
    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
        rpc.client_barrier()
        shutdown_servers()
        finalize_client()
...
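Taken together, a trainer's lifecycle with this module looks roughly like the sketch below; note that `exit_client` finalizes the sampler workers before the client barrier, which is exactly the ordering the comment above calls out:

    import dgl

    dgl.distributed.initialize('ip_config.txt', num_workers=4)
    # ... build a DistGraph and DistDataLoader, run the training loop ...
    dgl.distributed.exit_client()  # finalize workers, then barrier, then server shutdown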
python/dgl/distributed/dist_dataloader.py (view file @ 8bc91414)
# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
import multiprocessing as mp
from queue import Queue
import traceback

from .dist_context import get_sampler_pool
from .. import backend as F

__all__ = ["DistDataLoader"]

def call_collate_fn(name, next_data):
    """Call collate function"""
    try:
        result = DGL_GLOBAL_COLLATE_FNS[name](next_data)
        DGL_GLOBAL_MP_QUEUES[name].put(result)
    except Exception as e:
        traceback.print_exc()
        print(e)
        raise e
    return 1

DGL_GLOBAL_COLLATE_FNS = {}
DGL_GLOBAL_MP_QUEUES = {}

def init_fn(barrier, name, collate_fn, queue):
    """Initialize setting collate function and mp.Queue in the subprocess"""
    global DGL_GLOBAL_COLLATE_FNS
    global DGL_GLOBAL_MP_QUEUES
    DGL_GLOBAL_MP_QUEUES[name] = queue
    DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
    barrier.wait()
    return 1

def cleanup_fn(barrier, name):
    """Clean up the data of a dataloader in the worker process"""
    global DGL_GLOBAL_COLLATE_FNS
    global DGL_GLOBAL_MP_QUEUES
    del DGL_GLOBAL_MP_QUEUES[name]
    del DGL_GLOBAL_COLLATE_FNS[name]
    # The barrier here is to ensure this function is executed in all worker
    # processes; probably needs a better solution in the future.
    barrier.wait()
    return 1

def enable_mp_debug():
    """Print multiprocessing debug information. This is only
    for debug usage"""
    import logging
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)

DATALOADER_ID = 0

class DistDataLoader:
    """DGL customized multiprocessing dataloader.
...
@@ -112,16 +66,12 @@
        self.pool, self.num_workers = get_sampler_pool()
        if queue_size is None:
            queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
-        self.queue_size = queue_size
+        self.queue_size = queue_size  # prefetch size
        self.batch_size = batch_size
        self.num_pending = 0
        self.collate_fn = collate_fn
        self.current_pos = 0
-        if self.pool is not None:
-            m = mp.Manager()
-            self.queue = m.Queue(maxsize=queue_size)
-        else:
-            self.queue = Queue(maxsize=queue_size)
+        self.queue = []  # Only used when pool is None
        self.drop_last = drop_last
        self.recv_idxs = 0
        self.shuffle = shuffle
...
@@ -140,34 +90,24 @@
        DATALOADER_ID += 1
        if self.pool is not None:
-            results = []
-            barrier = m.Barrier(self.num_workers)
-            for _ in range(self.num_workers):
-                results.append(self.pool.apply_async(
-                    init_fn, args=(barrier, self.name, self.collate_fn, self.queue)))
-            for res in results:
-                res.get()
+            self.pool.set_collate_fn(self.collate_fn, self.name)

    def __del__(self):
        # When the process exits, the process pool may have been closed. We should try
        # and get the process pool again and see if we need to clean up the process pool.
        self.pool, self.num_workers = get_sampler_pool()
        if self.pool is not None:
-            results = []
-            # Here we need to create the manager and barrier again.
-            m = mp.Manager()
-            barrier = m.Barrier(self.num_workers)
-            for _ in range(self.num_workers):
-                results.append(self.pool.apply_async(cleanup_fn, args=(barrier, self.name,)))
-            for res in results:
-                res.get()
+            self.pool.delete_collate_fn(self.name)

    def __next__(self):
-        num_reqs = self.queue_size - self.num_pending
+        if self.pool is None:
+            num_reqs = 1
+        else:
+            num_reqs = self.queue_size - self.num_pending
        for _ in range(num_reqs):
            self._request_next_batch()
        if self.recv_idxs < self.expected_idxs:
-            result = self.queue.get(timeout=1800)
+            result = self._get_data_from_result_queue()
            self.recv_idxs += 1
            self.num_pending -= 1
            return result
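The `__next__` rewrite changes the prefetch policy: with a pool, each call tops the pipeline up to `queue_size` outstanding requests and then consumes one result; without a pool it degenerates to one synchronous request per call. A small sketch of the bookkeeping with illustrative numbers (`num_workers=4`, so the default `queue_size` is 16):

    queue_size, num_pending = 16, 0
    for step in range(3):
        num_reqs = queue_size - num_pending  # how many batches to request now
        num_pending += num_reqs              # one _request_next_batch() per request
        num_pending -= 1                     # one result consumed by this __next__
        print(step, num_reqs, num_pending)   # -> 0 16 15 / 1 1 15 / 2 1 15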
...
@@ -175,6 +115,13 @@
            assert self.num_pending == 0
            raise StopIteration

+    def _get_data_from_result_queue(self, timeout=1800):
+        if self.pool is None:
+            ret = self.queue.pop(0)
+        else:
+            ret = self.pool.get_result(self.name, timeout=timeout)
+        return ret

    def __iter__(self):
        if self.shuffle:
            self.data_idx = F.rand_shuffle(self.data_idx)
...
@@ -188,10 +135,10 @@
        if next_data is None:
            return
        elif self.pool is not None:
-            self.pool.apply_async(call_collate_fn, args=(self.name, next_data,))
+            self.pool.submit_task(self.name, next_data)
        else:
            result = self.collate_fn(next_data)
-            self.queue.put(result)
+            self.queue.append(result)
        self.num_pending += 1

    def _next_data(self):
...
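End to end, the dataloader is used as before; only the plumbing underneath changed. A hedged usage sketch, assuming a running DGL distributed setup; `train_nids` and `sample_blocks` (a user-supplied collate function) are placeholders:

    import dgl

    dgl.distributed.initialize('ip_config.txt', num_workers=4)
    dataloader = dgl.distributed.DistDataLoader(
        dataset=train_nids, batch_size=1000,        # train_nids: placeholder node IDs
        collate_fn=sample_blocks,                   # sample_blocks: placeholder collate
        shuffle=True, drop_last=False)
    for blocks in dataloader:
        pass  # forward/backward pass goes here
    dgl.distributed.exit_client()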