Unverified commit 8e6cbd62, authored by Rhett Ying and committed by GitHub
Browse files

[DistGB] sample with graphbolt on heterograph via DistNodeDataLoader (#7112)

parent 924c5669
...@@ -622,18 +622,7 @@ class DistGraph: ...@@ -622,18 +622,7 @@ class DistGraph:
self._init_ndata_store() self._init_ndata_store()
self._init_edata_store() self._init_edata_store()
self._init_metadata()
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
self._num_nodes += int(part_md["num_nodes"])
self._num_edges += int(part_md["num_edges"])
# When we store node/edge types in a list, they are stored in the order of type IDs.
self._ntype_map = {ntype: i for i, ntype in enumerate(self.ntypes)}
self._etype_map = {
etype: i for i, etype in enumerate(self.canonical_etypes)
}
def _init(self, gpb): def _init(self, gpb):
self._client = get_kvstore() self._client = get_kvstore()
...@@ -698,6 +687,19 @@ class DistGraph: ...@@ -698,6 +687,19 @@ class DistGraph:
else: else:
self._edata_store[etype] = data self._edata_store[etype] = data
def _init_metadata(self):
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
self._num_nodes += int(part_md["num_nodes"])
self._num_edges += int(part_md["num_edges"])
# When we store node/edge types in a list, they are stored in the order of type IDs.
self._ntype_map = {ntype: i for i, ntype in enumerate(self.ntypes)}
self._etype_map = {
etype: i for i, etype in enumerate(self.canonical_etypes)
}
def __getstate__(self): def __getstate__(self):
return self.graph_name, self._gpb, self._use_graphbolt return self.graph_name, self._gpb, self._use_graphbolt
...@@ -707,11 +709,7 @@ class DistGraph: ...@@ -707,11 +709,7 @@ class DistGraph:
self._init_ndata_store() self._init_ndata_store()
self._init_edata_store() self._init_edata_store()
self._num_nodes = 0 self._init_metadata()
self._num_edges = 0
for part_md in self._gpb.metadata():
self._num_nodes += int(part_md["num_nodes"])
self._num_edges += int(part_md["num_edges"])
@property @property
def local_partition(self): def local_partition(self):
...@@ -1403,6 +1401,7 @@ class DistGraph: ...@@ -1403,6 +1401,7 @@ class DistGraph:
replace=replace, replace=replace,
etype_sorted=etype_sorted, etype_sorted=etype_sorted,
prob=prob, prob=prob,
use_graphbolt=self._use_graphbolt,
) )
else: else:
frontier = graph_services.sample_neighbors( frontier = graph_services.sample_neighbors(
......
...@@ -487,22 +487,23 @@ def start_node_dataloader( ...@@ -487,22 +487,23 @@ def start_node_dataloader(
range(0, num_nodes_to_sample, batch_size), dataloader range(0, num_nodes_to_sample, batch_size), dataloader
): ):
block = blocks[-1] block = blocks[-1]
for src_type, etype, dst_type in block.canonical_etypes: for c_etype in block.canonical_etypes:
o_src, o_dst = block.edges(etype=etype) src_type, _, dst_type = c_etype
o_src, o_dst = block.edges(etype=c_etype)
src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src] src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst] dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
src_nodes_id = orig_nid[src_type][src_nodes_id] src_nodes_id = orig_nid[src_type][src_nodes_id]
dst_nodes_id = orig_nid[dst_type][dst_nodes_id] dst_nodes_id = orig_nid[dst_type][dst_nodes_id]
has_edges = groundtruth_g.has_edges_between( has_edges = groundtruth_g.has_edges_between(
src_nodes_id, dst_nodes_id, etype=etype src_nodes_id, dst_nodes_id, etype=c_etype
) )
assert np.all(F.asnumpy(has_edges)) assert np.all(F.asnumpy(has_edges))
if use_graphbolt and not return_eids: if use_graphbolt and not return_eids:
continue continue
eids = orig_eid[etype][block.edata[dgl.EID]] eids = orig_eid[c_etype][block.edges[c_etype].data[dgl.EID]]
expected_eids = groundtruth_g.edge_ids( expected_eids = groundtruth_g.edge_ids(
src_nodes_id, dst_nodes_id src_nodes_id, dst_nodes_id, etype=c_etype
) )
assert th.equal( assert th.equal(
eids, expected_eids eids, expected_eids
...@@ -610,7 +611,7 @@ def check_dataloader( ...@@ -610,7 +611,7 @@ def check_dataloader(
if not isinstance(orig_nid, dict): if not isinstance(orig_nid, dict):
orig_nid = {g.ntypes[0]: orig_nid} orig_nid = {g.ntypes[0]: orig_nid}
if not isinstance(orig_eid, dict): if not isinstance(orig_eid, dict):
orig_eid = {g.etypes[0]: orig_eid} orig_eid = {g.canonical_etypes[0]: orig_eid}
pserver_list = [] pserver_list = []
ctx = mp.get_context("spawn") ctx = mp.get_context("spawn")
...@@ -718,14 +719,27 @@ def test_dataloader_homograph( ...@@ -718,14 +719,27 @@ def test_dataloader_homograph(
) )
@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1]) @pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"]) @pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader_heterograph(num_server, num_workers, dataloader_type): @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_dataloader_heterograph(
num_server, num_workers, dataloader_type, use_graphbolt, return_eids
):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs() reset_envs()
g = create_random_hetero() g = create_random_hetero()
check_dataloader(g, num_server, num_workers, dataloader_type) check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@unittest.skip(reason="Skip due to glitch in CI") @unittest.skip(reason="Skip due to glitch in CI")
...@@ -740,10 +754,18 @@ def test_neg_dataloader(num_server, num_workers): ...@@ -740,10 +754,18 @@ def test_neg_dataloader(num_server, num_workers):
def start_multiple_dataloaders( def start_multiple_dataloaders(
ip_config, part_config, graph_name, orig_g, num_dataloaders, dataloader_type ip_config,
part_config,
graph_name,
orig_g,
num_dataloaders,
dataloader_type,
use_graphbolt,
): ):
dgl.distributed.initialize(ip_config) dgl.distributed.initialize(ip_config)
dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config) dist_g = dgl.distributed.DistGraph(
graph_name, part_config=part_config, use_graphbolt=use_graphbolt
)
if dataloader_type == "node": if dataloader_type == "node":
train_ids = th.arange(orig_g.num_nodes()) train_ids = th.arange(orig_g.num_nodes())
batch_size = orig_g.num_nodes() // 100 batch_size = orig_g.num_nodes() // 100
...@@ -777,13 +799,17 @@ def start_multiple_dataloaders( ...@@ -777,13 +799,17 @@ def start_multiple_dataloaders(
dgl.distributed.exit_client() dgl.distributed.exit_client()
@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_dataloaders", [1, 4]) @pytest.mark.parametrize("num_dataloaders", [1, 4])
@pytest.mark.parametrize("num_workers", [0, 1, 4]) @pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"]) @pytest.mark.parametrize("dataloader_type", ["node", "edge"])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_multiple_dist_dataloaders( def test_multiple_dist_dataloaders(
num_dataloaders, num_workers, dataloader_type num_dataloaders, num_workers, dataloader_type, use_graphbolt, return_eids
): ):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
os.environ["DGL_NUM_SAMPLER"] = str(num_workers) os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
...@@ -794,8 +820,15 @@ def test_multiple_dist_dataloaders( ...@@ -794,8 +820,15 @@ def test_multiple_dist_dataloaders(
generate_ip_config(ip_config, num_parts, num_servers) generate_ip_config(ip_config, num_parts, num_servers)
orig_g = dgl.rand_graph(1000, 10000) orig_g = dgl.rand_graph(1000, 10000)
graph_name = "test" graph_name = "test_multiple_dataloaders"
partition_graph(orig_g, graph_name, num_parts, test_dir) partition_graph(
orig_g,
graph_name,
num_parts,
test_dir,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
part_config = os.path.join(test_dir, f"{graph_name}.json") part_config = os.path.join(test_dir, f"{graph_name}.json")
p_servers = [] p_servers = []
...@@ -809,6 +842,7 @@ def test_multiple_dist_dataloaders( ...@@ -809,6 +842,7 @@ def test_multiple_dist_dataloaders(
part_config, part_config,
num_servers > 1, num_servers > 1,
num_workers + 1, num_workers + 1,
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -824,6 +858,7 @@ def test_multiple_dist_dataloaders( ...@@ -824,6 +858,7 @@ def test_multiple_dist_dataloaders(
orig_g, orig_g,
num_dataloaders, num_dataloaders,
dataloader_type, dataloader_type,
use_graphbolt,
), ),
) )
p_client.start() p_client.start()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment