Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
cd204a4a
Unverified
Commit
cd204a4a
authored
Aug 21, 2020
by
Jinjing Zhou
Committed by
GitHub
Aug 21, 2020
Browse files
[Distributed] Use barrier instead of sleep in DistDataloader (#2086)
* use barrier instead of sleep * lint
parent
6d212983
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
11 deletions
+9
-11
python/dgl/distributed/dist_dataloader.py
python/dgl/distributed/dist_dataloader.py
+9
-11
No files found.
python/dgl/distributed/dist_dataloader.py
View file @
cd204a4a
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
"""Multiprocess dataloader for distributed training"""
"""Multiprocess dataloader for distributed training"""
import
multiprocessing
as
mp
import
multiprocessing
as
mp
from
queue
import
Queue
from
queue
import
Queue
import
time
import
traceback
import
traceback
from
.dist_context
import
get_sampler_pool
from
.dist_context
import
get_sampler_pool
...
@@ -25,18 +24,16 @@ def call_collate_fn(name, next_data):
...
@@ -25,18 +24,16 @@ def call_collate_fn(name, next_data):
DGL_GLOBAL_COLLATE_FNS
=
{}
DGL_GLOBAL_COLLATE_FNS
=
{}
DGL_GLOBAL_MP_QUEUES
=
{}
DGL_GLOBAL_MP_QUEUES
=
{}
def
init_fn
(
name
,
collate_fn
,
queue
):
def
init_fn
(
barrier
,
name
,
collate_fn
,
queue
):
"""Initialize setting collate function and mp.Queue in the subprocess"""
"""Initialize setting collate function and mp.Queue in the subprocess"""
global
DGL_GLOBAL_COLLATE_FNS
global
DGL_GLOBAL_COLLATE_FNS
global
DGL_GLOBAL_MP_QUEUES
global
DGL_GLOBAL_MP_QUEUES
DGL_GLOBAL_MP_QUEUES
[
name
]
=
queue
DGL_GLOBAL_MP_QUEUES
[
name
]
=
queue
DGL_GLOBAL_COLLATE_FNS
[
name
]
=
collate_fn
DGL_GLOBAL_COLLATE_FNS
[
name
]
=
collate_fn
# sleep here is to ensure this function is executed in all worker processes
barrier
.
wait
()
# probably need better solution in the future
time
.
sleep
(
1
)
return
1
return
1
def
cleanup_fn
(
name
):
def
cleanup_fn
(
barrier
,
name
):
"""Clean up the data of a dataloader in the worker process"""
"""Clean up the data of a dataloader in the worker process"""
global
DGL_GLOBAL_COLLATE_FNS
global
DGL_GLOBAL_COLLATE_FNS
global
DGL_GLOBAL_MP_QUEUES
global
DGL_GLOBAL_MP_QUEUES
...
@@ -44,7 +41,7 @@ def cleanup_fn(name):
...
@@ -44,7 +41,7 @@ def cleanup_fn(name):
del
DGL_GLOBAL_COLLATE_FNS
[
name
]
del
DGL_GLOBAL_COLLATE_FNS
[
name
]
# sleep here is to ensure this function is executed in all worker processes
# sleep here is to ensure this function is executed in all worker processes
# probably need better solution in the future
# probably need better solution in the future
time
.
sleep
(
1
)
barrier
.
wait
(
)
return
1
return
1
...
@@ -52,7 +49,7 @@ def enable_mp_debug():
...
@@ -52,7 +49,7 @@ def enable_mp_debug():
"""Print multiprocessing debug information. This is only
"""Print multiprocessing debug information. This is only
for debug usage"""
for debug usage"""
import
logging
import
logging
logger
=
m
ultiprocessing
.
log_to_stderr
()
logger
=
m
p
.
log_to_stderr
()
logger
.
setLevel
(
logging
.
DEBUG
)
logger
.
setLevel
(
logging
.
DEBUG
)
DATALOADER_ID
=
0
DATALOADER_ID
=
0
...
@@ -122,6 +119,7 @@ class DistDataLoader:
...
@@ -122,6 +119,7 @@ class DistDataLoader:
self
.
current_pos
=
0
self
.
current_pos
=
0
if
self
.
pool
is
not
None
:
if
self
.
pool
is
not
None
:
self
.
m
=
mp
.
Manager
()
self
.
m
=
mp
.
Manager
()
self
.
barrier
=
self
.
m
.
Barrier
(
self
.
num_workers
)
self
.
queue
=
self
.
m
.
Queue
(
maxsize
=
queue_size
)
self
.
queue
=
self
.
m
.
Queue
(
maxsize
=
queue_size
)
else
:
else
:
self
.
queue
=
Queue
(
maxsize
=
queue_size
)
self
.
queue
=
Queue
(
maxsize
=
queue_size
)
...
@@ -145,7 +143,7 @@ class DistDataLoader:
...
@@ -145,7 +143,7 @@ class DistDataLoader:
results
=
[]
results
=
[]
for
_
in
range
(
self
.
num_workers
):
for
_
in
range
(
self
.
num_workers
):
results
.
append
(
self
.
pool
.
apply_async
(
results
.
append
(
self
.
pool
.
apply_async
(
init_fn
,
args
=
(
self
.
name
,
self
.
collate_fn
,
self
.
queue
)))
init_fn
,
args
=
(
self
.
barrier
,
self
.
name
,
self
.
collate_fn
,
self
.
queue
)))
for
res
in
results
:
for
res
in
results
:
res
.
get
()
res
.
get
()
...
@@ -153,7 +151,7 @@ class DistDataLoader:
...
@@ -153,7 +151,7 @@ class DistDataLoader:
if
self
.
pool
is
not
None
:
if
self
.
pool
is
not
None
:
results
=
[]
results
=
[]
for
_
in
range
(
self
.
num_workers
):
for
_
in
range
(
self
.
num_workers
):
results
.
append
(
self
.
pool
.
apply_async
(
cleanup_fn
,
args
=
(
self
.
name
,)))
results
.
append
(
self
.
pool
.
apply_async
(
cleanup_fn
,
args
=
(
self
.
barrier
,
self
.
name
,)))
for
res
in
results
:
for
res
in
results
:
res
.
get
()
res
.
get
()
...
@@ -162,7 +160,7 @@ class DistDataLoader:
...
@@ -162,7 +160,7 @@ class DistDataLoader:
for
_
in
range
(
num_reqs
):
for
_
in
range
(
num_reqs
):
self
.
_request_next_batch
()
self
.
_request_next_batch
()
if
self
.
recv_idxs
<
self
.
expected_idxs
:
if
self
.
recv_idxs
<
self
.
expected_idxs
:
result
=
self
.
queue
.
get
(
timeout
=
9999
)
result
=
self
.
queue
.
get
(
timeout
=
1800
)
self
.
recv_idxs
+=
1
self
.
recv_idxs
+=
1
self
.
num_pending
-=
1
self
.
num_pending
-=
1
return
result
return
result
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment