Unverified Commit 1ad46fd0 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] support distributed sampling in NodeDataLoader (#2012)

* fix.

* doesn't support EdgeDataLoader

* add test.

* fix.

* fix test.

* catch errors in the test.

* fix.

* fix error message

* add print.

* join the server processes.

* update test.

* fix test.
parent fe121c63
......@@ -2,8 +2,18 @@
import inspect
from torch.utils.data import DataLoader
from ..dataloader import NodeCollator, EdgeCollator
class NodeDataLoader(DataLoader):
from ...distributed import DistGraph
from ...distributed import DistDataLoader
def _remove_kwargs_dist(kwargs):
if 'num_workers' in kwargs:
del kwargs['num_workers']
if 'pin_memory' in kwargs:
del kwargs['pin_memory']
print('Distributed DataLoader does not support pin_memory')
return kwargs
class NodeDataLoader:
"""PyTorch dataloader for batch-iterating over a set of nodes, generating the list
of blocks as computation dependency of the said minibatch.
......@@ -42,8 +52,22 @@ class NodeDataLoader(DataLoader):
else:
dataloader_kwargs[k] = v
self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)
super().__init__(
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
if isinstance(g, DistGraph):
_remove_kwargs_dist(dataloader_kwargs)
self.dataloader = DistDataLoader(self.collator.dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs)
else:
self.dataloader = DataLoader(self.collator.dataset,
collate_fn=self.collator.collate,
**dataloader_kwargs)
def __next__(self):
return self.dataloader.__next()
def __iter__(self):
return self.dataloader.__iter__()
class EdgeDataLoader(DataLoader):
"""PyTorch dataloader for batch-iterating over a set of edges, generating the list
......@@ -196,5 +220,9 @@ class EdgeDataLoader(DataLoader):
else:
dataloader_kwargs[k] = v
self.collator = EdgeCollator(g, eids, block_sampler, **collator_kwargs)
assert not isinstance(g, DistGraph), \
'EdgeDataLoader does not support DistGraph for now. ' \
+ 'Please use DistDataLoader directly.'
super().__init__(
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
......@@ -68,7 +68,8 @@ class DistTensor:
+ 'Please provide a partition policy explicitly.'
part_policy = policy
assert part_policy is not None, \
'Cannot determine the partition policy. Please provide it.'
'Cannot find a right partition policy. Currently, DistTensor only ' \
+ 'supports partition policy associated with nodes or edges.'
self._part_policy = part_policy
......
......@@ -252,7 +252,10 @@ def test_standalone():
dgl.distributed.initialize("kv_ip_config.txt")
dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
try:
check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
except Exception as e:
print(e)
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
......
......@@ -27,7 +27,11 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt", 1)
dist_graph = DistGraph("test_sampling", gpb=gpb)
try:
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
except Exception as e:
print(e)
sampled_graph = None
dgl.distributed.exit_client()
return sampled_graph
......@@ -37,7 +41,11 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
_, _, _, gpb, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt", 1)
dist_graph = DistGraph("test_find_edges", gpb=gpb)
try:
u, v = find_edges(dist_graph, eids)
except Exception as e:
print(e)
u, v = None, None
dgl.distributed.exit_client()
return u, v
......@@ -201,7 +209,11 @@ def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("test_in_subgraph", gpb=gpb)
try:
sampled_graph = dgl.distributed.in_subgraph(dist_graph, nodes)
except Exception as e:
print(e)
sampled_graph = None
dgl.distributed.exit_client()
return sampled_graph
......
......@@ -45,7 +45,7 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients):
g.start()
def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
def start_dist_dataloader(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
import dgl
import torch as th
dgl.distributed.initialize("mp_ip_config.txt", 1, num_workers=num_workers)
......@@ -88,6 +88,8 @@ def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
else:
assert np.max(max_nid) == num_nodes_to_sample - 1
del dataloader
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
......@@ -106,7 +108,10 @@ def test_standalone(tmpdir):
num_hops=num_hops, part_method='metis', reshuffle=False)
os.environ['DGL_DIST_MODE'] = 'standalone'
start_client(0, tmpdir, False, 2, True)
try:
start_dist_dataloader(0, tmpdir, False, 2, True)
except Exception as e:
print(e)
dgl.distributed.exit_client() # this is needed since there's two test here in one process
......@@ -140,12 +145,105 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last):
time.sleep(3)
os.environ['DGL_DIST_MODE'] = 'distributed'
start_client(0, tmpdir, num_server > 1, num_workers, drop_last)
ptrainer = ctx.Process(target=start_dist_dataloader, args=(
0, tmpdir, num_server > 1, num_workers, drop_last))
ptrainer.start()
time.sleep(1)
for p in pserver_list:
p.join()
ptrainer.join()
def start_node_dataloader(rank, tmpdir, disable_shared_mem, num_workers):
import dgl
import torch as th
dgl.distributed.initialize("mp_ip_config.txt", 1, num_workers=num_workers)
gpb = None
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
num_nodes_to_sample = 202
batch_size = 32
train_nid = th.arange(num_nodes_to_sample)
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')
# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])
# We need to test creating DistDataLoader multiple times.
for i in range(2):
# Create DataLoader for constructing blocks
dataloader = dgl.dataloading.NodeDataLoader(
dist_graph,
train_nid,
sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)
groundtruth_g = CitationGraphDataset("cora")[0]
max_nid = []
for epoch in range(2):
for idx, (_, _, blocks) in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
block = blocks[-1]
o_src, o_dst = block.edges()
src_nodes_id = block.srcdata[dgl.NID][o_src]
dst_nodes_id = block.dstdata[dgl.NID][o_dst]
has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
assert np.all(F.asnumpy(has_edges))
max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
# assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
del dataloader
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("dataloader_type", ["node"])
def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
ip_config = open("mp_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close()
g = CitationGraphDataset("cora")[0]
print(g.idtype)
num_parts = num_server
num_hops = 1
partition_graph(g, 'test_sampling', num_parts, tmpdir,
num_hops=num_hops, part_method='metis', reshuffle=False)
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(
i, tmpdir, num_server > 1, num_workers+1))
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
os.environ['DGL_DIST_MODE'] = 'distributed'
ptrainer_list = []
if dataloader_type == 'node':
p = ctx.Process(target=start_node_dataloader, args=(
0, tmpdir, num_server > 1, num_workers))
p.start()
time.sleep(1)
ptrainer_list.append(p)
for p in pserver_list:
p.join()
for p in ptrainer_list:
p.join()
if __name__ == "__main__":
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
test_dataloader(Path(tmpdirname), 3, 4, 'node')
test_standalone(Path(tmpdirname))
test_dist_dataloader(Path(tmpdirname), 3, 0, True)
test_dist_dataloader(Path(tmpdirname), 3, 4, True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment