Unverified Commit bcf92f6c authored by xiang song(charlie.song), committed by GitHub

[Bugfix] Distributed training can not work with dgl.dataloading.negative_sampler (#3215)



* Fix dist negative data loader bug

* upd

* Fix
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 4d9beb3b
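The failing pattern this commit addresses is an EdgeDataLoader over a DistGraph combined with a negative sampler. A minimal sketch of that pattern (graph name, partition config path, and edge-ID range are placeholders, and a running distributed setup is assumed):

import dgl
import torch as th

dgl.distributed.initialize('ip_config.txt')
g = dgl.distributed.DistGraph('graph_name', part_config='data/graph_name.json')
train_eids = th.arange(1000)  # edge IDs to train on (placeholder)

sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
dataloader = dgl.dataloading.EdgeDataLoader(
    g, train_eids, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
    batch_size=32, shuffle=True, drop_last=False)

for input_nodes, pos_graph, neg_graph, blocks in dataloader:
    pass  # pos_graph: sampled positive edges; neg_graph: 5 negatives per positive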
@@ -707,6 +707,65 @@ class DistGraph:
        """
        return self._canonical_etypes
    def to_canonical_etype(self, etype):
        """Convert an edge type to the corresponding canonical edge type in the graph.

        A canonical edge type is a string triplet ``(str, str, str)`` for
        source node type, edge type and destination node type.

        The function expects that the given edge type name can uniquely identify
        a canonical edge type. DGL will raise an error if this is not the case.

        Parameters
        ----------
        etype : str or (str, str, str)
            If :attr:`etype` is an edge type (str), it returns the corresponding
            canonical edge type in the graph. If :attr:`etype` is already a
            canonical edge type, it is returned unchanged.

        Returns
        -------
        (str, str, str)
            The canonical edge type corresponding to the edge type.

        Examples
        --------
        The following example uses PyTorch backend.

        >>> import dgl
        >>> import torch
        >>> g = DistGraph("test")
        >>> g.canonical_etypes
        [('user', 'follows', 'user'),
         ('user', 'follows', 'game'),
         ('user', 'plays', 'game')]

        >>> g.to_canonical_etype('plays')
        ('user', 'plays', 'game')
        >>> g.to_canonical_etype(('user', 'plays', 'game'))
        ('user', 'plays', 'game')

        See Also
        --------
        canonical_etypes
        """
        if etype is None:
            if len(self.etypes) != 1:
                raise DGLError('Edge type name must be specified '
                               'if there is more than one edge type.')
            etype = self.etypes[0]
        if isinstance(etype, tuple):
            return etype
        else:
            ret = self._etype2canonical.get(etype, None)
            if ret is None:
                raise DGLError('Edge type "{}" does not exist.'.format(etype))
            if len(ret) != 3:
                raise DGLError('Edge type "{}" is ambiguous. Please use a canonical '
                               'edge type in the form of (srctype, etype, dsttype).'
                               .format(etype))
            return ret
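The ambiguity check above implies that _etype2canonical maps an edge-type name to its canonical triplet, with something other than a 3-tuple stored for names shared by several canonical types. An illustration of that resolution logic, using hypothetical data matching the docstring example:

etype2canonical = {
    'plays': ('user', 'plays', 'game'),
    'follows': (),  # ambiguous: two canonical types share this name above
}

def to_canonical(etype):
    if isinstance(etype, tuple):
        return etype  # already canonical
    ret = etype2canonical.get(etype, None)
    if ret is None:
        raise ValueError('Edge type "{}" does not exist.'.format(etype))
    if len(ret) != 3:
        raise ValueError('Edge type "{}" is ambiguous.'.format(etype))
    return ret

assert to_canonical('plays') == ('user', 'plays', 'game')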
    def get_ntype_id(self, ntype):
        """Return the ID of the given node type.
@@ -1016,6 +1075,9 @@ class DistGraph:
        gpb = self.get_partition_book()
        if len(gpb.etypes) > 1:
            # if etype is a canonical edge type (str, str, str), extract the edge type name
            if isinstance(etype, tuple) and len(etype) == 3:
                etype = etype[1]
            edges = gpb.map_to_homo_eid(edges, etype)
        src, dst = dist_find_edges(self, edges)
        if len(gpb.ntypes) > 1:
@@ -1071,10 +1133,13 @@ class DistGraph:
                subg[self._etype2canonical[etype]] = self.find_edges(edges[etype], etype)
            num_nodes = {ntype: self.number_of_nodes(ntype) for ntype in self.ntypes}
            subg = dgl_heterograph(subg, num_nodes_dict=num_nodes)
            for etype in edges:
                subg.edges[etype].data[EID] = edges[etype]
        else:
            assert len(self.etypes) == 1
            subg = self.find_edges(edges)
            subg = dgl_graph(subg, num_nodes=self.number_of_nodes())
            subg.edata[EID] = edges
        if relabel_nodes:
            subg = compact_graphs(subg)
...
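Storing EID on the subgraph is what lets downstream consumers (for example, excluding sampled positive edges during negative sampling) map subgraph edges back to edge IDs in the parent graph. A local, non-distributed sketch of the same convention, assuming a recent DGL where edge_subgraph stores dgl.EID by default:

import dgl
import torch as th

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
eids = th.tensor([1, 3])
sg = dgl.edge_subgraph(g, eids, relabel_nodes=False)
assert th.equal(sg.edata[dgl.EID], eids)  # original edge IDs preserved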
@@ -124,6 +124,95 @@ def test_standalone(tmpdir):
        print(e)
    dgl.distributed.exit_client()  # needed since there are two tests in one process
def start_dist_neg_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, groundtruth_g):
    import dgl
    import torch as th
    dgl.distributed.initialize("mp_ip_config.txt")
    gpb = None
    disable_shared_mem = num_server > 1
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    num_edges_to_sample = 202
    batch_size = 32
    dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')
    assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
    assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
    if len(dist_graph.etypes) == 1:
        train_eid = th.arange(num_edges_to_sample)
    else:
        train_eid = {dist_graph.etypes[0]: th.arange(num_edges_to_sample)}

    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)

    num_negs = 5
    sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])
    negative_sampler = dgl.dataloading.negative_sampler.Uniform(num_negs)
    dataloader = dgl.dataloading.EdgeDataLoader(dist_graph,
                                                train_eid,
                                                sampler,
                                                batch_size=batch_size,
                                                negative_sampler=negative_sampler,
                                                shuffle=True,
                                                drop_last=False,
                                                num_workers=num_workers)
    for _ in range(2):
        for _, (_, pos_graph, neg_graph, blocks) in zip(range(0, num_edges_to_sample, batch_size), dataloader):
            block = blocks[-1]
            for src_type, etype, dst_type in block.canonical_etypes:
                o_src, o_dst = block.edges(etype=etype)
                src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
                dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
                # map partition-local node IDs back to the original graph's IDs
                src_nodes_id = orig_nid[src_type][src_nodes_id]
                dst_nodes_id = orig_nid[dst_type][dst_nodes_id]
                has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id, etype=etype)
                assert np.all(F.asnumpy(has_edges))
                assert np.all(F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])
                              == F.asnumpy(pos_graph.nodes[dst_type].data[dgl.NID]))
                assert np.all(F.asnumpy(block.dstnodes[dst_type].data[dgl.NID])
                              == F.asnumpy(neg_graph.nodes[dst_type].data[dgl.NID]))
            assert pos_graph.num_edges() * num_negs == neg_graph.num_edges()
    del dataloader
    dgl.distributed.exit_client()  # needed since there are two tests in one process
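The final assertion in the loop encodes the Uniform(k) contract: k corrupted edges are drawn per positive edge, so neg_graph has k times as many edges as pos_graph. The same contract can be checked directly on a small local graph (shapes only; the sampled endpoints are random):

import dgl
import torch as th

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
neg_src, neg_dst = neg_sampler(g, th.tensor([0, 1]))  # 2 positive edges
assert len(neg_src) == len(neg_dst) == 2 * 5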
def check_neg_dataloader(g, tmpdir, num_server, num_workers):
    ip_config = open("mp_ip_config.txt", "w")
    for _ in range(num_server):
        ip_config.write('{}\n'.format(get_local_usable_addr()))
    ip_config.close()

    num_parts = num_server
    num_hops = 1
    orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                         num_hops=num_hops, part_method='metis',
                                         reshuffle=True, return_mapping=True)
    if not isinstance(orig_nid, dict):
        orig_nid = {g.ntypes[0]: orig_nid}
    if not isinstance(orig_eid, dict):
        orig_eid = {g.etypes[0]: orig_eid}

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, num_workers + 1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)
    time.sleep(3)

    os.environ['DGL_DIST_MODE'] = 'distributed'
    os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
    ptrainer_list = []
    p = ctx.Process(target=start_dist_neg_dataloader, args=(
        0, tmpdir, num_server, num_workers, orig_nid, g))
    p.start()
    time.sleep(1)
    ptrainer_list.append(p)

    for p in pserver_list:
        p.join()
    for p in ptrainer_list:
        p.join()
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@@ -343,12 +432,24 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
    g = create_random_hetero()
    check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type)
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support")
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
def test_neg_dataloader(tmpdir, num_server, num_workers):
    g = CitationGraphDataset("cora")[0]
    check_neg_dataloader(g, tmpdir, num_server, num_workers)
    g = create_random_hetero()
    check_neg_dataloader(g, tmpdir, num_server, num_workers)
if __name__ == "__main__": if __name__ == "__main__":
import tempfile import tempfile
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
test_standalone(Path(tmpdirname)) test_standalone(Path(tmpdirname))
test_dataloader(Path(tmpdirname), 3, 4, 'node') test_dataloader(Path(tmpdirname), 3, 4, 'node')
test_dataloader(Path(tmpdirname), 3, 4, 'edge') test_dataloader(Path(tmpdirname), 3, 4, 'edge')
test_neg_dataloader(Path(tmpdirname), 3, 4)
test_dist_dataloader(Path(tmpdirname), 3, 0, True, True) test_dist_dataloader(Path(tmpdirname), 3, 0, True, True)
test_dist_dataloader(Path(tmpdirname), 3, 4, True, True) test_dist_dataloader(Path(tmpdirname), 3, 4, True, True)
test_dist_dataloader(Path(tmpdirname), 3, 0, True, False) test_dist_dataloader(Path(tmpdirname), 3, 0, True, False)
......