Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
cd204a4a
Unverified
Commit
cd204a4a
authored
Aug 21, 2020
by
Jinjing Zhou
Committed by
GitHub
Aug 21, 2020
Browse files
[Distributed] Use barrier instead of sleep in DistDataloader (#2086)
* use barrier instead of sleep * lint
parent
6d212983
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
11 deletions
+9
-11
python/dgl/distributed/dist_dataloader.py
python/dgl/distributed/dist_dataloader.py
+9
-11
No files found.
python/dgl/distributed/dist_dataloader.py
View file @
cd204a4a
...
...
@@ -2,7 +2,6 @@
"""Multiprocess dataloader for distributed training"""
import
multiprocessing
as
mp
from
queue
import
Queue
import
time
import
traceback
from
.dist_context
import
get_sampler_pool
...
...
@@ -25,18 +24,16 @@ def call_collate_fn(name, next_data):
# Per-process registries keyed by dataloader name: the collate function used
# by this worker and the mp.Queue it pushes collated batches into.
DGL_GLOBAL_COLLATE_FNS = {}
DGL_GLOBAL_MP_QUEUES = {}


def init_fn(barrier, name, collate_fn, queue):
    """Initialize setting collate function and mp.Queue in the subprocess.

    Parameters
    ----------
    barrier : multiprocessing barrier
        Synchronizes the workers so that every subprocess has registered its
        state before any of them proceeds (replaces the old fixed
        ``time.sleep(1)``).
    name : hashable
        Identifier of the dataloader that owns this worker.
    collate_fn : callable
        Function that collates sampled items into a batch.
    queue : queue-like
        Queue used to hand collated batches back to the trainer process.

    Returns
    -------
    int
        Always 1, signalling successful initialization.
    """
    global DGL_GLOBAL_COLLATE_FNS
    global DGL_GLOBAL_MP_QUEUES
    DGL_GLOBAL_MP_QUEUES[name] = queue
    DGL_GLOBAL_COLLATE_FNS[name] = collate_fn
    # Wait until this function has run in every worker process.
    barrier.wait()
    return 1
def cleanup_fn(barrier, name):
    """Clean up the data of a dataloader in the worker process.

    Parameters
    ----------
    barrier : multiprocessing barrier
        Synchronizes the workers so the per-dataloader state is removed in
        every subprocess before any of them proceeds (replaces the old fixed
        ``time.sleep(1)``).
    name : hashable
        Identifier of the dataloader whose per-process state is removed.

    Returns
    -------
    int
        Always 1, signalling successful cleanup.
    """
    global DGL_GLOBAL_COLLATE_FNS
    global DGL_GLOBAL_MP_QUEUES
    # NOTE(review): the diff view elides the lines between the globals and
    # the collate-fn deletion; presumably the queue entry is removed there —
    # confirm against the full file.
    del DGL_GLOBAL_MP_QUEUES[name]
    del DGL_GLOBAL_COLLATE_FNS[name]
    # Wait until this function has run in every worker process.
    barrier.wait()
    return 1
...
...
@@ -52,7 +49,7 @@ def enable_mp_debug():
def enable_mp_debug():
    """Print multiprocessing debug information. This is only
    for debug usage."""
    import logging
    # Use the module alias ``mp``: the bare name ``multiprocessing`` is not
    # bound in this module (the pre-commit code raised NameError here).
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)


# Module-level dataloader id counter (how it is advanced is not shown in this
# diff view).
DATALOADER_ID = 0
...
...
@@ -122,6 +119,7 @@ class DistDataLoader:
self
.
current_pos
=
0
if
self
.
pool
is
not
None
:
self
.
m
=
mp
.
Manager
()
self
.
barrier
=
self
.
m
.
Barrier
(
self
.
num_workers
)
self
.
queue
=
self
.
m
.
Queue
(
maxsize
=
queue_size
)
else
:
self
.
queue
=
Queue
(
maxsize
=
queue_size
)
...
...
@@ -145,7 +143,7 @@ class DistDataLoader:
results
=
[]
for
_
in
range
(
self
.
num_workers
):
results
.
append
(
self
.
pool
.
apply_async
(
init_fn
,
args
=
(
self
.
name
,
self
.
collate_fn
,
self
.
queue
)))
init_fn
,
args
=
(
self
.
barrier
,
self
.
name
,
self
.
collate_fn
,
self
.
queue
)))
for
res
in
results
:
res
.
get
()
...
...
@@ -153,7 +151,7 @@ class DistDataLoader:
if
self
.
pool
is
not
None
:
results
=
[]
for
_
in
range
(
self
.
num_workers
):
results
.
append
(
self
.
pool
.
apply_async
(
cleanup_fn
,
args
=
(
self
.
name
,)))
results
.
append
(
self
.
pool
.
apply_async
(
cleanup_fn
,
args
=
(
self
.
barrier
,
self
.
name
,)))
for
res
in
results
:
res
.
get
()
...
...
@@ -162,7 +160,7 @@ class DistDataLoader:
for
_
in
range
(
num_reqs
):
self
.
_request_next_batch
()
if
self
.
recv_idxs
<
self
.
expected_idxs
:
result
=
self
.
queue
.
get
(
timeout
=
9999
)
result
=
self
.
queue
.
get
(
timeout
=
1800
)
self
.
recv_idxs
+=
1
self
.
num_pending
-=
1
return
result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment