Commit cb4a83a2 (unverified), authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Enable distributed EdgeDataLoader (#3192)



* make heterogeneous find_edges

* add distributed EdgeDataLoader.

* fix.

* fix a bug.

* fix bugs.

* add tests on distributed heterogeneous graph sampling.

* fix.
Co-authored-by: Zheng <dzzhen@3c22fba32af5.ant.amazon.com>
parent c3921599
...@@ -776,35 +776,49 @@ class EdgeDataLoader: ...@@ -776,35 +776,49 @@ class EdgeDataLoader:
else: else:
dataloader_kwargs[k] = v dataloader_kwargs[k] = v
assert not isinstance(g, DistGraph), \ if isinstance(g, DistGraph):
'EdgeDataLoader does not support DistGraph for now. ' \ if device is None:
+ 'Please use DistDataLoader directly.' # for the distributed case default to the CPU
device = 'cpu'
if device is None: assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
# default to the same device the graph is on # Distributed DataLoader currently does not support heterogeneous graphs
device = th.device(g.device) # and does not copy features. Fallback to normal solution
self.collator = EdgeCollator(g, eids, block_sampler, **collator_kwargs)
# if the sampler supports it, tell it to output to the _remove_kwargs_dist(dataloader_kwargs)
# specified device self.dataloader = DistDataLoader(self.collator.dataset,
num_workers = dataloader_kwargs.get('num_workers', 0) collate_fn=self.collator.collate,
if callable(getattr(block_sampler, "set_output_context", None)) and num_workers == 0: **dataloader_kwargs)
block_sampler.set_output_context(to_dgl_context(device)) self.is_distributed = True
else:
self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs) if device is None:
self.use_scalar_batcher, self.scalar_batcher, self.dataloader, self.dist_sampler = \ # default to the same device the graph is on
_init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed) device = th.device(g.device)
self.use_ddp = use_ddp
self.device = device # if the sampler supports it, tell it to output to the
# specified device
num_workers = dataloader_kwargs.get('num_workers', 0)
if callable(getattr(block_sampler, "set_output_context", None)) and num_workers == 0:
block_sampler.set_output_context(to_dgl_context(device))
# Precompute the CSR and CSC representations so each subprocess does not self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs)
# duplicate. self.use_scalar_batcher, self.scalar_batcher, self.dataloader, self.dist_sampler = \
if num_workers > 0: _init_dataloader(self.collator, device, dataloader_kwargs, use_ddp, ddp_seed)
g.create_formats_() self.use_ddp = use_ddp
self.is_distributed = False
# Precompute the CSR and CSC representations so each subprocess does not duplicate.
if num_workers > 0:
g.create_formats_()
self.device = device
def __iter__(self): def __iter__(self):
"""Return the iterator of the data loader.""" """Return the iterator of the data loader."""
return _EdgeDataLoaderIter(self) if self.is_distributed:
# Directly use the iterator of DistDataLoader, which doesn't copy features anyway.
return iter(self.dataloader)
else:
return _EdgeDataLoaderIter(self)
def __len__(self): def __len__(self):
"""Return the number of batches of the data loader.""" """Return the number of batches of the data loader."""
......
...@@ -7,6 +7,9 @@ import os ...@@ -7,6 +7,9 @@ import os
import numpy as np import numpy as np
from ..heterograph import DGLHeteroGraph from ..heterograph import DGLHeteroGraph
from ..convert import heterograph as dgl_heterograph
from ..convert import graph as dgl_graph
from ..transform import compact_graphs
from .. import heterograph_index from .. import heterograph_index
from .. import backend as F from .. import backend as F
from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all from ..base import NID, EID, NTYPE, ETYPE, ALL, is_all
...@@ -502,6 +505,12 @@ class DistGraph: ...@@ -502,6 +505,12 @@ class DistGraph:
dst_tid = F.as_scalar(dst_tid) dst_tid = F.as_scalar(dst_tid)
self._canonical_etypes.append((self.ntypes[src_tid], self.etypes[etype_id], self._canonical_etypes.append((self.ntypes[src_tid], self.etypes[etype_id],
self.ntypes[dst_tid])) self.ntypes[dst_tid]))
self._etype2canonical = {}
for src_type, etype, dst_type in self._canonical_etypes:
if etype in self._etype2canonical:
self._etype2canonical[etype] = ()
else:
self._etype2canonical[etype] = (src_type, etype, dst_type)
def _init(self): def _init(self):
self._client = get_kvstore() self._client = get_kvstore()
...@@ -514,12 +523,18 @@ class DistGraph: ...@@ -514,12 +523,18 @@ class DistGraph:
self._client.map_shared_data(self._gpb) self._client.map_shared_data(self._gpb)
def __getstate__(self): def __getstate__(self):
return self.graph_name, self._gpb return self.graph_name, self._gpb, self._canonical_etypes
def __setstate__(self, state): def __setstate__(self, state):
self.graph_name, self._gpb_input = state self.graph_name, self._gpb_input, self._canonical_etypes = state
self._init() self._init()
self._etype2canonical = {}
for src_type, etype, dst_type in self._canonical_etypes:
if etype in self._etype2canonical:
self._etype2canonical[etype] = ()
else:
self._etype2canonical[etype] = (src_type, etype, dst_type)
self._ndata_store = {} self._ndata_store = {}
self._edata_store = {} self._edata_store = {}
self._ndata = NodeDataView(self) self._ndata = NodeDataView(self)
...@@ -969,15 +984,25 @@ class DistGraph: ...@@ -969,15 +984,25 @@ class DistGraph:
''' '''
return role.get_global_rank() return role.get_global_rank()
def find_edges(self, edges): def find_edges(self, edges, etype=None):
""" Given an edge ID array, return the source """ Given an edge ID array, return the source
and destination node ID array ``s`` and ``d``. ``s[i]`` and ``d[i]`` and destination node ID array ``s`` and ``d``. ``s[i]`` and ``d[i]``
are source and destination node ID for edge ``eid[i]``. are source and destination node ID for edge ``eid[i]``.
Parameters Parameters
---------- ----------
edges : tensor edges : Int Tensor
The edge ID array. Each element is an ID. The tensor must have the same device type
and ID data type as the graph's.
etype : str or (str, str, str), optional
The type names of the edges. The allowed type name formats are:
* ``(str, str, str)`` for source node type, edge type and destination node type.
* or one ``str`` edge type name if the name can uniquely identify a
triplet format in the graph.
Can be omitted if the graph has only one type of edges.
Returns Returns
------- -------
...@@ -986,8 +1011,75 @@ class DistGraph: ...@@ -986,8 +1011,75 @@ class DistGraph:
tensor tensor
The destination node ID array. The destination node ID array.
""" """
assert len(self.etypes) == 1, 'find_edges does not support heterogeneous graph for now.' if etype is None:
return dist_find_edges(self, edges) assert len(self.etypes) == 1, 'find_edges requires etype for heterogeneous graphs.'
gpb = self.get_partition_book()
if len(gpb.etypes) > 1:
edges = gpb.map_to_homo_eid(edges, etype)
src, dst = dist_find_edges(self, edges)
if len(gpb.ntypes) > 1:
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)
return src, dst
def edge_subgraph(self, edges, relabel_nodes=True, store_ids=True):
    """Return a subgraph induced on the given edges.

    The end points of the requested edges are looked up with
    :meth:`find_edges` and a new graph is built from them. The new graph is
    created with the full node count of this graph and is then optionally
    compacted.

    For a heterogeneous graph, the end points are looked up one relation at
    a time and the results are composed into a single heterograph.

    NOTE(review): unlike the single-machine ``edge_subgraph``, no node or edge
    features are copied by the code in this method -- confirm whether callers
    rely on that.

    Parameters
    ----------
    edges : Int Tensor or dict[str or (str, str, str), Int Tensor]
        The edges to form the subgraph. Each element is an edge ID.

        If the graph has a single edge type, one can directly pass an Int Tensor.
        Otherwise, the argument must be a dictionary whose keys are edge types,
        given either as a canonical ``(src_type, edge_type, dst_type)`` triplet
        or as an edge type name that identifies a single triplet, and whose
        values are the edge IDs of that type.
    relabel_nodes : bool, optional
        If True, the result is passed through ``compact_graphs``, which removes
        the isolated nodes and relabels the incident nodes.
    store_ids : bool, optional
        Must be True. The method asserts on any other value.

    Returns
    -------
    G : DGLGraph
        The subgraph.
    """
    # Fail on the unsupported option before issuing any remote edge lookups.
    assert store_ids, 'edge_subgraph always stores original node/edge IDs.'

    if isinstance(edges, dict):
        # TODO(zhengda) we need to directly generate subgraph of all relations with
        # one invocation.
        subg = {}
        for etype, eids in edges.items():
            if isinstance(etype, tuple):
                # The key is already a canonical triplet; find_edges is given
                # the edge type name, i.e. the middle element.
                canonical_etype = etype
                etype_name = etype[1]
            else:
                # _etype2canonical maps a name shared by several relations to
                # an empty tuple, so a length other than 3 means ambiguity.
                canonical_etype = self._etype2canonical[etype]
                assert len(canonical_etype) == 3, \
                    'the etype in input edges is ambiguous'
                etype_name = etype
            subg[canonical_etype] = self.find_edges(eids, etype_name)
        num_nodes = {ntype: self.number_of_nodes(ntype) for ntype in self.ntypes}
        subg = dgl_heterograph(subg, num_nodes_dict=num_nodes)
    else:
        assert len(self.etypes) == 1
        subg = self.find_edges(edges)
        subg = dgl_graph(subg, num_nodes=self.number_of_nodes())

    if relabel_nodes:
        subg = compact_graphs(subg)
    return subg
def get_partition_book(self): def get_partition_book(self):
"""Get the partition information. """Get the partition information.
......
...@@ -566,6 +566,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -566,6 +566,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
etype_ids, idx = F.sort_1d(etype_ids) etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx) src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx) eid = F.gather_row(frontier.edata[EID], idx)
assert len(eid) > 0
_, src = gpb.map_to_per_ntype(src) _, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst) _, dst = gpb.map_to_per_ntype(dst)
......
...@@ -39,14 +39,14 @@ def start_sample_client(rank, tmpdir, disable_shared_mem): ...@@ -39,14 +39,14 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
dgl.distributed.exit_client() dgl.distributed.exit_client()
return sampled_graph return sampled_graph
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids): def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', rank) _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_find_edges", gpb=gpb) dist_graph = DistGraph("test_find_edges", gpb=gpb)
try: try:
u, v = dist_graph.find_edges(eids) u, v = dist_graph.find_edges(eids, etype=etype)
except Exception as e: except Exception as e:
print(e) print(e)
u, v = None, None u, v = None, None
...@@ -116,8 +116,9 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server): ...@@ -116,8 +116,9 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
g.readonly() g.readonly()
num_parts = num_server num_parts = num_server
partition_graph(g, 'test_find_edges', num_parts, tmpdir, orig_nid, orig_eid = partition_graph(g, 'test_find_edges', num_parts, tmpdir,
num_hops=1, part_method='metis', reshuffle=True) num_hops=1, part_method='metis',
reshuffle=True, return_mapping=True)
pserver_list = [] pserver_list = []
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
...@@ -128,13 +129,6 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server): ...@@ -128,13 +129,6 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
time.sleep(1) time.sleep(1)
pserver_list.append(p) pserver_list.append(p)
orig_nid = F.zeros((g.number_of_nodes(),), dtype=F.int64, ctx=F.cpu())
orig_eid = F.zeros((g.number_of_edges(),), dtype=F.int64, ctx=F.cpu())
for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_find_edges.json', i)
orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']
time.sleep(3) time.sleep(3)
eids = F.tensor(np.random.randint(g.number_of_edges(), size=100)) eids = F.tensor(np.random.randint(g.number_of_edges(), size=100))
u, v = g.find_edges(orig_eid[eids]) u, v = g.find_edges(orig_eid[eids])
...@@ -144,6 +138,50 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server): ...@@ -144,6 +138,50 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
assert F.array_equal(u, du) assert F.array_equal(u, du)
assert F.array_equal(v, dv) assert F.array_equal(v, dv)
def create_random_hetero():
    """Build a random heterogeneous graph with three node types and three relations.

    Every relation gets a sparse random adjacency (density 0.001) generated
    with a fixed ``random_state``, so repeated calls use the same seed.
    """
    node_counts = {'n1': 10000, 'n2': 10010, 'n3': 10020}
    relations = (
        ('n1', 'r1', 'n2'),
        ('n1', 'r2', 'n3'),
        ('n2', 'r3', 'n3'),
    )
    edge_dict = {}
    for rel in relations:
        n_src = node_counts[rel[0]]
        n_dst = node_counts[rel[2]]
        # The COO row/col arrays are used as the source/destination node IDs.
        adj = spsp.random(n_src, n_dst, density=0.001, format='coo',
                          random_state=100)
        edge_dict[rel] = (adj.row, adj.col)
    return dgl.heterograph(edge_dict, node_counts)
def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
    """Check distributed ``find_edges`` on a partitioned heterogeneous graph.

    The function writes the server address file, partitions a random
    heterogeneous graph with ID reshuffling, starts ``num_server`` server
    processes, and then compares the end points returned by the distributed
    ``find_edges`` for relation ``'r1'`` against the original graph.

    Parameters
    ----------
    tmpdir : path-like
        Directory that receives the partition files.
    num_server : int
        Number of server processes; also used as the number of partitions.
    """
    # Use a context manager so the config file is closed even if a write fails.
    with open("rpc_ip_config.txt", "w") as ip_config:
        for _ in range(num_server):
            ip_config.write('{}\n'.format(get_local_usable_addr()))

    g = create_random_hetero()
    num_parts = num_server

    # return_mapping=True yields the ID mappings that are indexed below by
    # node type / edge type to translate between shuffled and original IDs.
    orig_nid, orig_eid = partition_graph(g, 'test_find_edges', num_parts, tmpdir,
                                         num_hops=1, part_method='metis',
                                         reshuffle=True, return_mapping=True)

    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(i, tmpdir, num_server > 1,
                                                   'test_find_edges', ['csr', 'coo']))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    eids = F.tensor(np.random.randint(g.number_of_edges('r1'), size=100))
    # Expected end points come from the original graph, queried with the
    # original IDs of the sampled (shuffled) edge IDs.
    u, v = g.find_edges(orig_eid['r1'][eids], etype='r1')
    du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids, etype='r1')
    # Relation 'r1' goes from 'n1' to 'n2'; map the distributed result back
    # to original node IDs before comparing.
    du = orig_nid['n1'][du]
    dv = orig_nid['n2'][dv]
    assert F.array_equal(u, du)
    assert F.array_equal(v, dv)
# Wait non shared memory graph store # Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now') @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
...@@ -153,6 +191,7 @@ def test_rpc_find_edges_shuffle(num_server): ...@@ -153,6 +191,7 @@ def test_rpc_find_edges_shuffle(num_server):
import tempfile import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed' os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), num_server)
check_rpc_find_edges_shuffle(Path(tmpdirname), num_server) check_rpc_find_edges_shuffle(Path(tmpdirname), num_server)
def check_rpc_get_degree_shuffle(tmpdir, num_server): def check_rpc_get_degree_shuffle(tmpdir, num_server):
...@@ -630,6 +669,8 @@ if __name__ == "__main__": ...@@ -630,6 +669,8 @@ if __name__ == "__main__":
check_rpc_get_degree_shuffle(Path(tmpdirname), 2) check_rpc_get_degree_shuffle(Path(tmpdirname), 2)
check_rpc_find_edges_shuffle(Path(tmpdirname), 2) check_rpc_find_edges_shuffle(Path(tmpdirname), 2)
check_rpc_find_edges_shuffle(Path(tmpdirname), 1) check_rpc_find_edges_shuffle(Path(tmpdirname), 1)
check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 1)
check_rpc_hetero_find_edges_shuffle(Path(tmpdirname), 2)
check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2) check_rpc_in_subgraph_shuffle(Path(tmpdirname), 2)
check_rpc_sampling_shuffle(Path(tmpdirname), 1) check_rpc_sampling_shuffle(Path(tmpdirname), 1)
check_rpc_sampling_shuffle(Path(tmpdirname), 2) check_rpc_sampling_shuffle(Path(tmpdirname), 2)
......
import dgl import dgl
import unittest import unittest
import os import os
from scipy import sparse as spsp
from dgl.data import CitationGraphDataset from dgl.data import CitationGraphDataset
from dgl.distributed import sample_neighbors from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book from dgl.distributed import partition_graph, load_partition, load_partition_book
...@@ -166,7 +167,7 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle): ...@@ -166,7 +167,7 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
p.join() p.join()
ptrainer.join() ptrainer.join()
def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_eid): def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_eid, groundtruth_g):
import dgl import dgl
import torch as th import torch as th
dgl.distributed.initialize("mp_ip_config.txt") dgl.distributed.initialize("mp_ip_config.txt")
...@@ -176,8 +177,13 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_ ...@@ -176,8 +177,13 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank) _, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
num_nodes_to_sample = 202 num_nodes_to_sample = 202
batch_size = 32 batch_size = 32
train_nid = th.arange(num_nodes_to_sample)
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json') dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1:
train_nid = th.arange(num_nodes_to_sample)
else:
train_nid = {'n3': th.arange(num_nodes_to_sample)}
for i in range(num_server): for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i) part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
...@@ -197,44 +203,88 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_ ...@@ -197,44 +203,88 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_
drop_last=False, drop_last=False,
num_workers=num_workers) num_workers=num_workers)
groundtruth_g = CitationGraphDataset("cora")[0]
max_nid = []
for epoch in range(2): for epoch in range(2):
for idx, (_, _, blocks) in zip(range(0, num_nodes_to_sample, batch_size), dataloader): for idx, (_, _, blocks) in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
block = blocks[-1] block = blocks[-1]
o_src, o_dst = block.edges() for src_type, etype, dst_type in block.canonical_etypes:
src_nodes_id = block.srcdata[dgl.NID][o_src] o_src, o_dst = block.edges(etype=etype)
dst_nodes_id = block.dstdata[dgl.NID][o_dst] src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
src_nodes_id = orig_nid[src_nodes_id] dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
dst_nodes_id = orig_nid[dst_nodes_id] src_nodes_id = orig_nid[src_type][src_nodes_id]
has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id) dst_nodes_id = orig_nid[dst_type][dst_nodes_id]
assert np.all(F.asnumpy(has_edges)) has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id, etype=etype)
max_nid.append(np.max(F.asnumpy(dst_nodes_id))) assert np.all(F.asnumpy(has_edges))
# assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size)) # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
del dataloader del dataloader
dgl.distributed.exit_client() # this is needed since there's two test here in one process dgl.distributed.exit_client() # this is needed since there's two test here in one process
def start_edge_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_eid, groundtruth_g):
import dgl
import torch as th
dgl.distributed.initialize("mp_ip_config.txt")
gpb = None
disable_shared_mem = num_server > 1
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(tmpdir / 'test_sampling.json', rank)
num_edges_to_sample = 202
batch_size = 32
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1:
train_eid = th.arange(num_edges_to_sample)
else:
train_eid = {dist_graph.etypes[0]: th.arange(num_edges_to_sample)}
for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
# Create sampler
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') # We need to test creating DistDataLoader multiple times.
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now') for i in range(2):
@pytest.mark.parametrize("num_server", [3]) # Create DataLoader for constructing blocks
@pytest.mark.parametrize("num_workers", [0, 4]) dataloader = dgl.dataloading.EdgeDataLoader(
@pytest.mark.parametrize("dataloader_type", ["node"]) dist_graph,
def test_dataloader(tmpdir, num_server, num_workers, dataloader_type): train_eid,
sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)
for epoch in range(2):
for idx, (input_nodes, pos_pair_graph, blocks) in zip(range(0, num_edges_to_sample, batch_size), dataloader):
block = blocks[-1]
for src_type, etype, dst_type in block.canonical_etypes:
o_src, o_dst = block.edges(etype=etype)
src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
src_nodes_id = orig_nid[src_type][src_nodes_id]
dst_nodes_id = orig_nid[dst_type][dst_nodes_id]
has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id, etype=etype)
assert np.all(F.asnumpy(has_edges))
assert np.all(F.asnumpy(block.dstnodes[dst_type].data[dgl.NID]) == F.asnumpy(pos_pair_graph.nodes[dst_type].data[dgl.NID]))
# assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
del dataloader
dgl.distributed.exit_client() # this is needed since there's two test here in one process
def check_dataloader(g, tmpdir, num_server, num_workers, dataloader_type):
ip_config = open("mp_ip_config.txt", "w") ip_config = open("mp_ip_config.txt", "w")
for _ in range(num_server): for _ in range(num_server):
ip_config.write('{}\n'.format(get_local_usable_addr())) ip_config.write('{}\n'.format(get_local_usable_addr()))
ip_config.close() ip_config.close()
g = CitationGraphDataset("cora")[0]
print(g.idtype)
num_parts = num_server num_parts = num_server
num_hops = 1 num_hops = 1
orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir, orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
num_hops=num_hops, part_method='metis', num_hops=num_hops, part_method='metis',
reshuffle=True, return_mapping=True) reshuffle=True, return_mapping=True)
if not isinstance(orig_nid, dict):
orig_nid = {g.ntypes[0]: orig_nid}
if not isinstance(orig_eid, dict):
orig_eid = {g.etypes[0]: orig_eid}
pserver_list = [] pserver_list = []
ctx = mp.get_context('spawn') ctx = mp.get_context('spawn')
...@@ -251,7 +301,13 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type): ...@@ -251,7 +301,13 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
ptrainer_list = [] ptrainer_list = []
if dataloader_type == 'node': if dataloader_type == 'node':
p = ctx.Process(target=start_node_dataloader, args=( p = ctx.Process(target=start_node_dataloader, args=(
0, tmpdir, num_server, num_workers, orig_nid, orig_eid)) 0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
p.start()
time.sleep(1)
ptrainer_list.append(p)
elif dataloader_type == 'edge':
p = ctx.Process(target=start_edge_dataloader, args=(
0, tmpdir, num_server, num_workers, orig_nid, orig_eid, g))
p.start() p.start()
time.sleep(1) time.sleep(1)
ptrainer_list.append(p) ptrainer_list.append(p)
...@@ -260,12 +316,40 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type): ...@@ -260,12 +316,40 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
for p in ptrainer_list: for p in ptrainer_list:
p.join() p.join()
def create_random_hetero():
    """Build a random heterogeneous graph carrying node and edge features.

    The graph has three node types and three relations, each with a sparse
    random adjacency (density 0.001) generated with a fixed ``random_state``.
    Node type ``'n1'`` and edge type ``'r1'`` each get a ``'feat'`` column
    holding their own IDs.
    """
    node_counts = {'n1': 10000, 'n2': 10010, 'n3': 10020}
    relations = (
        ('n1', 'r1', 'n2'),
        ('n1', 'r2', 'n3'),
        ('n2', 'r3', 'n3'),
    )
    edge_dict = {}
    for rel in relations:
        n_src = node_counts[rel[0]]
        n_dst = node_counts[rel[2]]
        # The COO row/col arrays are used as the source/destination node IDs.
        adj = spsp.random(n_src, n_dst, density=0.001, format='coo',
                          random_state=100)
        edge_dict[rel] = (adj.row, adj.col)

    hg = dgl.heterograph(edge_dict, node_counts)
    # Features are the IDs themselves, reshaped with a trailing axis of size 1.
    n1_ids = F.arange(0, hg.number_of_nodes('n1'))
    hg.nodes['n1'].data['feat'] = F.unsqueeze(n1_ids, 1)
    r1_ids = F.arange(0, hg.number_of_edges('r1'))
    hg.edges['r1'].data['feat'] = F.unsqueeze(r1_ids, 1)
    return hg
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
    """Run the dataloader check on the cora graph, then on a random heterograph."""
    # Factories keep graph construction lazy: each graph is built right
    # before its own check runs, in this order.
    graph_factories = (
        lambda: CitationGraphDataset("cora")[0],
        create_random_hetero,
    )
    for make_graph in graph_factories:
        check_dataloader(make_graph(), tmpdir, num_server, num_workers, dataloader_type)
if __name__ == "__main__": if __name__ == "__main__":
import tempfile import tempfile
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
test_standalone(Path(tmpdirname)) test_standalone(Path(tmpdirname))
test_dataloader(Path(tmpdirname), 3, 4, 'node')
test_dataloader(Path(tmpdirname), 3, 4, 'edge')
test_dist_dataloader(Path(tmpdirname), 3, 0, True, True) test_dist_dataloader(Path(tmpdirname), 3, 0, True, True)
test_dist_dataloader(Path(tmpdirname), 3, 4, True, True) test_dist_dataloader(Path(tmpdirname), 3, 4, True, True)
test_dist_dataloader(Path(tmpdirname), 3, 0, True, False) test_dist_dataloader(Path(tmpdirname), 3, 0, True, False)
test_dist_dataloader(Path(tmpdirname), 3, 4, True, False) test_dist_dataloader(Path(tmpdirname), 3, 4, True, False)
test_dataloader(Path(tmpdirname), 3, 4, 'node')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment