OpenDAS / dgl · Commits · 71cf7865

Unverified commit 71cf7865, authored Jul 28, 2020 by Jinjing Zhou, committed by GitHub on Jul 28, 2020.

Revert "[Distributed] DistDataloader (#1870)" (#1876)

This reverts commit 6557291f.

parent 6557291f
Changes: 8 changed files, with 20 additions and 335 deletions (+20 / -335)
python/dgl/distributed/__init__.py         +1   -2
python/dgl/distributed/dist_dataloader.py  +0   -152
python/dgl/distributed/dist_graph.py       +13  -32
python/dgl/distributed/graph_services.py   +2   -1
python/dgl/distributed/kvstore.py          +2   -0
python/dgl/distributed/rpc.py              +1   -0
python/dgl/distributed/rpc_client.py       +1   -29
tests/distributed/test_mp_dataloader.py    +0   -119
python/dgl/distributed/__init__.py

@@ -9,10 +9,9 @@ from .sparse_emb import SparseAdagrad, DistEmbedding
 from .rpc import *
 from .rpc_server import start_server
-from .rpc_client import connect_to_server, exit_client, init_rpc
+from .rpc_client import connect_to_server, exit_client
 from .kvstore import KVServer, KVClient
 from .server_state import ServerState
-from .dist_dataloader import DistDataLoader
 from .graph_services import sample_neighbors, in_subgraph
 if os.environ.get('DGL_ROLE', 'client') == 'server':
python/dgl/distributed/dist_dataloader.py (deleted, 100644 → 0)

# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
import multiprocessing as mp
import time
import traceback

from . import exit_client
from .rpc_client import get_sampler_pool
from .. import backend as F

__all__ = ["DistDataLoader"]


def call_collate_fn(next_data):
    """Call collate function"""
    try:
        result = DGL_GLOBAL_COLLATE_FN(next_data)
        DGL_GLOBAL_MP_QUEUE.put(result)
    except Exception as e:
        traceback.print_exc()
        print(e)
        raise e
    return 1


def init_fn(collate_fn, queue, sig_queue):
    """Initialize setting collate function and mp.Queue in the subprocess"""
    global DGL_GLOBAL_COLLATE_FN
    global DGL_GLOBAL_MP_QUEUE
    global DGL_SIG_QUEUE
    DGL_SIG_QUEUE = sig_queue
    DGL_GLOBAL_MP_QUEUE = queue
    DGL_GLOBAL_COLLATE_FN = collate_fn
    time.sleep(1)
    return 1


def _exit():
    exit_client()
    time.sleep(1)


class DistDataLoader:
    """DGL customized multiprocessing dataloader"""

    def __init__(self, dataset, batch_size, shuffle=False, num_workers=1,
                 collate_fn=None, drop_last=False, queue_size=None):
        """
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        queue_size (int): Size of multiprocessing queue
        """
        assert num_workers > 0
        if queue_size is None:
            queue_size = num_workers * 4
        self.queue_size = queue_size
        self.batch_size = batch_size
        self.queue_size = queue_size
        self.collate_fn = collate_fn
        self.current_pos = 0
        self.num_workers = num_workers
        self.m = mp.Manager()
        self.queue = self.m.Queue(maxsize=queue_size)
        self.sig_queue = self.m.Queue(maxsize=num_workers)
        self.drop_last = drop_last
        self.send_idxs = 0
        self.recv_idxs = 0
        self.started = False
        self.shuffle = shuffle
        self.pool, num_sampler_workers = get_sampler_pool()
        if self.pool is None:
            ctx = mp.get_context("spawn")
            self.pool = ctx.Pool(num_workers)
        else:
            assert num_sampler_workers == num_workers, \
                "Num workers should be the same"
        results = []
        for _ in range(num_workers):
            results.append(self.pool.apply_async(
                init_fn, args=(collate_fn, self.queue, self.sig_queue)))
            time.sleep(0.1)
        for res in results:
            res.get()
        self.dataset = F.tensor(dataset)
        self.expected_idxs = len(dataset) // self.batch_size
        if not self.drop_last and len(dataset) % self.batch_size != 0:
            self.expected_idxs += 1

    def __next__(self):
        if not self.started:
            for _ in range(self.queue_size):
                self._request_next_batch()
        self._request_next_batch()
        if self.recv_idxs < self.expected_idxs:
            result = self.queue.get(timeout=9999)
            self.recv_idxs += 1
            return result
        else:
            self.recv_idxs = 0
            self.current_pos = 0
            raise StopIteration

    def __iter__(self):
        if self.shuffle:
            self.dataset = F.rand_shuffle(self.dataset)
        return self

    def _request_next_batch(self):
        next_data = self._next_data()
        if next_data is None:
            return None
        else:
            async_result = self.pool.apply_async(
                call_collate_fn, args=(next_data,))
            self.send_idxs += 1
            return async_result

    def _next_data(self):
        if self.current_pos == len(self.dataset):
            return None
        end_pos = 0
        if self.current_pos + self.batch_size > len(self.dataset):
            if self.drop_last:
                return None
            else:
                end_pos = len(self.dataset)
        else:
            end_pos = self.current_pos + self.batch_size
        ret = self.dataset[self.current_pos:end_pos]
        self.current_pos = end_pos
        return ret

    def close(self):
        """Finalize the connection with server and close pool"""
        for _ in range(self.num_workers):
            self.pool.apply_async(_exit)
            time.sleep(0.1)
        self.pool.close()
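For readers skimming the revert: __next__ above keeps the worker pool saturated by pre-filling queue_size collate requests on first use, then issuing one new request per batch it returns. A minimal sketch of that pipelining pattern with plain multiprocessing (illustrative only, not DGL API; the real loader returns results through a Manager queue shared with the sampler pool, so ordering is handled differently):

import multiprocessing as mp

def collate(x):
    # Stand-in for call_collate_fn: turn one batch request into a result.
    return x * x

if __name__ == "__main__":
    data = list(range(10))
    queue_size = 4
    with mp.Pool(2) as pool:
        # Pre-fill: keep queue_size requests in flight (first __next__ call).
        pending = [pool.apply_async(collate, (x,)) for x in data[:queue_size]]
        results = []
        # Steady state: one new request per result consumed.
        for x in data[queue_size:]:
            pending.append(pool.apply_async(collate, (x,)))
            results.append(pending.pop(0).get())
        # Drain the remaining requests once the dataset is exhausted.
        while pending:
            results.append(pending.pop(0).get())
        print(results)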
python/dgl/distributed/dist_graph.py

@@ -17,6 +17,7 @@ from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
 from .graph_partition_book import NODE_PART_POLICY, EDGE_PART_POLICY
 from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
 from . import rpc
+from .rpc_client import connect_to_server
 from .server_state import ServerState
 from .rpc_server import start_server
 from .dist_tensor import DistTensor, _get_data_name

@@ -295,9 +296,6 @@ class DistGraph:
         The partition config file. It's used in the standalone mode.
     '''
     def __init__(self, ip_config, graph_name, gpb=None, conf_file=None):
-        self.ip_config = ip_config
-        self.graph_name = graph_name
-        self._gpb_input = gpb
         if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
             assert conf_file is not None, \
                 'When running in the standalone model, the partition config file is required'

@@ -315,7 +313,18 @@ class DistGraph:
                 self._client.add_data(_get_data_name(name, EDGE_PART_POLICY), edge_feats[name])
             rpc.set_num_client(1)
         else:
-            self._init()
+            connect_to_server(ip_config=ip_config)
+            self._client = KVClient(ip_config)
+            g = _get_graph_from_shared_mem(graph_name)
+            if g is not None:
+                self._g = as_heterograph(g)
+            else:
+                self._g = None
+            self._gpb = get_shared_mem_partition_book(graph_name, self._g)
+            if self._gpb is None:
+                self._gpb = gpb
+            self._client.barrier()
+            self._client.map_shared_data(self._gpb)
         self._ndata = NodeDataView(self)
         self._edata = EdgeDataView(self)

@@ -326,34 +335,6 @@ class DistGraph:
             self._num_nodes += int(part_md['num_nodes'])
             self._num_edges += int(part_md['num_edges'])

-    def _init(self):
-        ip_config, graph_name, gpb = self.ip_config, self.graph_name, self._gpb_input
-        self._client = KVClient(ip_config)
-        g = _get_graph_from_shared_mem(graph_name)
-        if g is not None:
-            self._g = as_heterograph(g)
-        else:
-            self._g = None
-        self._gpb = get_shared_mem_partition_book(graph_name, self._g)
-        if self._gpb is None:
-            self._gpb = gpb
-        self._client.map_shared_data(self._gpb)
-
-    def __getstate__(self):
-        return self.ip_config, self.graph_name, self._gpb_input
-
-    def __setstate__(self, state):
-        self.ip_config, self.graph_name, self._gpb_input = state
-        self._init()
-        self._ndata = NodeDataView(self)
-        self._edata = EdgeDataView(self)
-        self._num_nodes = 0
-        self._num_edges = 0
-        for part_md in self._gpb.metadata():
-            self._num_nodes += int(part_md['num_nodes'])
-            self._num_edges += int(part_md['num_edges'])
-
     @property
     def local_partition(self):
         ''' Return the local partition on the client
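The __getstate__/__setstate__ pair removed above is what made DistGraph picklable into sampler worker processes: serialize only the constructor arguments, then rebuild the RPC and shared-memory handles on unpickle via _init. A minimal sketch of that idiom (hypothetical Client class, not DGL code):

import pickle

class Client:
    """Toy stand-in for DistGraph: holds an unpicklable handle."""
    def __init__(self, address):
        self.address = address
        self._connect()

    def _connect(self):
        # Stand-in for KVClient / shared-memory setup; real handles don't pickle.
        self.conn = ("open", self.address)

    def __getstate__(self):
        # Ship only what is needed to rebuild on the other side.
        return self.address

    def __setstate__(self, state):
        self.address = state
        self._connect()  # re-establish the handle after unpickling

clone = pickle.loads(pickle.dumps(Client("127.0.0.1:30050")))
assert clone.conn == ("open", "127.0.0.1:30050")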
python/dgl/distributed/graph_services.py

@@ -132,7 +132,8 @@ def merge_graphs(res_list, num_nodes):
     src_tensor = res_list[0].global_src
     dst_tensor = res_list[0].global_dst
     eid_tensor = res_list[0].global_eids
-    g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
+    g = graph((src_tensor, dst_tensor),
+              restrict_format='coo', num_nodes=num_nodes)
     g.edata[EID] = eid_tensor
     return g
python/dgl/distributed/kvstore.py

@@ -1001,6 +1001,7 @@ class KVClient(object):
             Store the partition information
         """
         # Get shared data from server side
+        self.barrier()
         request = GetSharedDataRequest(GET_SHARED_MSG)
         rpc.send_request(self._main_server_id, request)
         response = rpc.recv_response()

@@ -1042,6 +1043,7 @@ class KVClient(object):
         response = rpc.recv_response()
         assert response.msg == SEND_META_TO_BACKUP_MSG
         self._data_name_list.add(name)
+        self.barrier()

     def data_name_list(self):
         """Get all the data name"""
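The restored self.barrier() calls make every KV client rendezvous before requesting shared data and again after registering a data name, so no client reads state that a peer has not finished publishing. The same rendezvous idea in miniature with multiprocessing.Barrier (illustration only, assuming three clients):

import multiprocessing as mp

def client(barrier, rank):
    print('client %d: local setup done' % rank)
    barrier.wait()  # block until every client reaches this point
    print('client %d: safe to read shared state' % rank)

if __name__ == "__main__":
    num_clients = 3
    barrier = mp.Barrier(num_clients)
    procs = [mp.Process(target=client, args=(barrier, r))
             for r in range(num_clients)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()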
python/dgl/distributed/rpc.py

@@ -1024,6 +1024,7 @@ class ShutDownRequest(Request):
         self.client_id = state

     def process_request(self, server_state):
+        assert self.client_id == 0
         finalize_server()
         return 'exit'
python/dgl/distributed/rpc_client.py

@@ -2,7 +2,6 @@
 import os
 import socket
-import multiprocessing as mp
 import atexit

 from . import rpc

@@ -181,14 +180,12 @@ def finalize_client():
     """Release resources of this client."""
     rpc.finalize_sender()
     rpc.finalize_receiver()
-    if SAMPLER_POOL is not None:
-        SAMPLER_POOL.close()
-        SAMPLER_POOL.join()
     global INITIALIZED
     INITIALIZED = False

 def shutdown_servers():
     """Issue commands to remote servers to shut them down.

     Raises
     ------
     ConnectionError : If anything wrong with the connection.

@@ -198,31 +195,6 @@ def shutdown_servers():
     for server_id in range(rpc.get_num_server()):
         rpc.send_request(server_id, req)

-SAMPLER_POOL = None
-NUM_SAMPLER_WORKERS = 0
-
-def _close():
-    """Finalize client and close servers when finished"""
-    rpc.finalize_sender()
-    rpc.finalize_receiver()
-
-def _init_rpc(ip_config, max_queue_size, net_type):
-    connect_to_server(ip_config, max_queue_size, net_type)
-
-def get_sampler_pool():
-    """Return the sampler pool and num_workers"""
-    return SAMPLER_POOL, NUM_SAMPLER_WORKERS
-
-def init_rpc(ip_config, num_workers, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
-    """Init rpc service"""
-    ctx = mp.get_context("spawn")
-    global SAMPLER_POOL
-    global NUM_SAMPLER_WORKERS
-    SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
-                            initargs=(ip_config, max_queue_size, net_type))
-    NUM_SAMPLER_WORKERS = num_workers
-    connect_to_server(ip_config, max_queue_size, net_type)
-
 def exit_client():
     """Register exit callback.
     """
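The deleted init_rpc relied on multiprocessing.Pool's initializer hook so that each sampler worker opened its own RPC connection before serving any task. A minimal sketch of that idiom (the _init/task names and the address are made up for illustration; the real code called connect_to_server in the initializer):

import multiprocessing as mp

CONN = None

def _init(addr):
    # Runs once in every worker when the pool starts, like _init_rpc above.
    global CONN
    CONN = 'connected to %s' % addr  # stand-in for connect_to_server(...)

def task(i):
    # Every task can assume the connection already exists.
    return 'task %d via %s' % (i, CONN)

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Pool(2, initializer=_init, initargs=("10.0.0.1:30050",)) as pool:
        print(pool.map(task, range(4)))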
tests/distributed/test_mp_dataloader.py (deleted, 100644 → 0)

import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import time
from utils import get_local_usable_addr
from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph, DistDataLoader
import pytest


class NeighborSampler(object):
    def __init__(self, g, fanouts, sample_neighbors):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors

    def sample_blocks(self, seeds):
        import torch as th
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks


def start_server(rank, tmpdir, disable_shared_mem, num_clients):
    import dgl
    print('server: #clients=' + str(num_clients))
    g = DistGraphServer(rank, "mp_ip_config.txt", num_clients,
                        tmpdir / 'test_sampling.json',
                        disable_shared_mem=disable_shared_mem)
    g.start()


def start_client(rank, tmpdir, disable_shared_mem, num_workers):
    import dgl
    import torch as th
    os.environ['DGL_DIST_MODE'] = 'distributed'
    dgl.distributed.init_rpc("mp_ip_config.txt", num_workers=4)
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    train_nid = th.arange(202)
    dist_graph = DistGraph("mp_ip_config.txt", "test_mp", gpb=gpb)

    # Create sampler
    sampler = NeighborSampler(dist_graph, [5, 10],
                              dgl.distributed.sample_neighbors)

    # Create PyTorch DataLoader for constructing blocks
    dataloader = DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=32,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=4)
    dist_graph._init()

    for epoch in range(3):
        for idx, blocks in enumerate(dataloader):
            print(blocks)
            print(blocks[1].edges())
            print(idx)
    dataloader.close()
    dgl.distributed.exit_client()


def main(tmpdir, num_server):
    ip_config = open("mp_ip_config.txt", "w")
    for _ in range(num_server):
        ip_config.write('{} 1\n'.format(get_local_usable_addr()))
    ip_config.close()

    g = CitationGraphDataset("cora")[0]
    g.readonly()
    print(g.idtype)
    num_parts = num_server
    num_hops = 1

    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    num_workers = 4
    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server,
                        args=(i, tmpdir, num_server > 1, num_workers + 1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    sampled_graph = start_client(0, tmpdir, num_server > 1, num_workers)
    for p in pserver_list:
        p.join()


# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_dist_dataloader(tmpdir):
    main(Path(tmpdir), 3)


if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        main(Path(tmpdirname), 3)