Unverified commit 8e6cbd62, authored by Rhett Ying and committed by GitHub
Browse files

[DistGB] sample with graphbolt on heterograph via DistNodeDataLoader (#7112)

parent 924c5669
......@@ -622,18 +622,7 @@ class DistGraph:
self._init_ndata_store()
self._init_edata_store()
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
self._num_nodes += int(part_md["num_nodes"])
self._num_edges += int(part_md["num_edges"])
# When we store node/edge types in a list, they are stored in the order of type IDs.
self._ntype_map = {ntype: i for i, ntype in enumerate(self.ntypes)}
self._etype_map = {
etype: i for i, etype in enumerate(self.canonical_etypes)
}
self._init_metadata()
def _init(self, gpb):
self._client = get_kvstore()
......@@ -698,6 +687,19 @@ class DistGraph:
else:
self._edata_store[etype] = data
def _init_metadata(self):
    """Rebuild the cached graph-wide counters and type-name lookup tables.

    Adds up the ``"num_nodes"`` and ``"num_edges"`` entries of every record
    returned by ``self._gpb.metadata()`` into ``self._num_nodes`` and
    ``self._num_edges``, then maps each name in ``self.ntypes`` and each
    entry of ``self.canonical_etypes`` to its position in that sequence.
    """
    # Reset both totals before accumulating so repeated calls start clean.
    self._num_nodes = self._num_edges = 0
    for partition_meta in self._gpb.metadata():
        self._num_nodes += int(partition_meta["num_nodes"])
        self._num_edges += int(partition_meta["num_edges"])

    # When we store node/edge types in a list, they are stored in the
    # order of type IDs, so the list position doubles as the type ID.
    self._ntype_map = {
        ntype_name: type_id for type_id, ntype_name in enumerate(self.ntypes)
    }
    self._etype_map = {
        c_etype: type_id
        for type_id, c_etype in enumerate(self.canonical_etypes)
    }
def __getstate__(self):
    """Return the state tuple used to serialize this object.

    The tuple is ``(graph_name, _gpb, _use_graphbolt)``, in that order.
    """
    state = (self.graph_name, self._gpb, self._use_graphbolt)
    return state
......@@ -707,11 +709,7 @@ class DistGraph:
self._init_ndata_store()
self._init_edata_store()
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
self._num_nodes += int(part_md["num_nodes"])
self._num_edges += int(part_md["num_edges"])
self._init_metadata()
@property
def local_partition(self):
......@@ -1403,6 +1401,7 @@ class DistGraph:
replace=replace,
etype_sorted=etype_sorted,
prob=prob,
use_graphbolt=self._use_graphbolt,
)
else:
frontier = graph_services.sample_neighbors(
......
......@@ -487,22 +487,23 @@ def start_node_dataloader(
range(0, num_nodes_to_sample, batch_size), dataloader
):
block = blocks[-1]
for src_type, etype, dst_type in block.canonical_etypes:
o_src, o_dst = block.edges(etype=etype)
for c_etype in block.canonical_etypes:
src_type, _, dst_type = c_etype
o_src, o_dst = block.edges(etype=c_etype)
src_nodes_id = block.srcnodes[src_type].data[dgl.NID][o_src]
dst_nodes_id = block.dstnodes[dst_type].data[dgl.NID][o_dst]
src_nodes_id = orig_nid[src_type][src_nodes_id]
dst_nodes_id = orig_nid[dst_type][dst_nodes_id]
has_edges = groundtruth_g.has_edges_between(
src_nodes_id, dst_nodes_id, etype=etype
src_nodes_id, dst_nodes_id, etype=c_etype
)
assert np.all(F.asnumpy(has_edges))
if use_graphbolt and not return_eids:
continue
eids = orig_eid[etype][block.edata[dgl.EID]]
eids = orig_eid[c_etype][block.edges[c_etype].data[dgl.EID]]
expected_eids = groundtruth_g.edge_ids(
src_nodes_id, dst_nodes_id
src_nodes_id, dst_nodes_id, etype=c_etype
)
assert th.equal(
eids, expected_eids
......@@ -610,7 +611,7 @@ def check_dataloader(
if not isinstance(orig_nid, dict):
orig_nid = {g.ntypes[0]: orig_nid}
if not isinstance(orig_eid, dict):
orig_eid = {g.etypes[0]: orig_eid}
orig_eid = {g.canonical_etypes[0]: orig_eid}
pserver_list = []
ctx = mp.get_context("spawn")
......@@ -718,14 +719,27 @@ def test_dataloader_homograph(
)
@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader_heterograph(num_server, num_workers, dataloader_type):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_dataloader_heterograph(
num_server, num_workers, dataloader_type, use_graphbolt, return_eids
):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs()
g = create_random_hetero()
check_dataloader(g, num_server, num_workers, dataloader_type)
check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@unittest.skip(reason="Skip due to glitch in CI")
......@@ -740,10 +754,18 @@ def test_neg_dataloader(num_server, num_workers):
def start_multiple_dataloaders(
ip_config, part_config, graph_name, orig_g, num_dataloaders, dataloader_type
ip_config,
part_config,
graph_name,
orig_g,
num_dataloaders,
dataloader_type,
use_graphbolt,
):
dgl.distributed.initialize(ip_config)
dist_g = dgl.distributed.DistGraph(graph_name, part_config=part_config)
dist_g = dgl.distributed.DistGraph(
graph_name, part_config=part_config, use_graphbolt=use_graphbolt
)
if dataloader_type == "node":
train_ids = th.arange(orig_g.num_nodes())
batch_size = orig_g.num_nodes() // 100
......@@ -777,13 +799,17 @@ def start_multiple_dataloaders(
dgl.distributed.exit_client()
@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_dataloaders", [1, 4])
@pytest.mark.parametrize("num_workers", [0, 1, 4])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_multiple_dist_dataloaders(
num_dataloaders, num_workers, dataloader_type
num_dataloaders, num_workers, dataloader_type, use_graphbolt, return_eids
):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
......@@ -794,8 +820,15 @@ def test_multiple_dist_dataloaders(
generate_ip_config(ip_config, num_parts, num_servers)
orig_g = dgl.rand_graph(1000, 10000)
graph_name = "test"
partition_graph(orig_g, graph_name, num_parts, test_dir)
graph_name = "test_multiple_dataloaders"
partition_graph(
orig_g,
graph_name,
num_parts,
test_dir,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
part_config = os.path.join(test_dir, f"{graph_name}.json")
p_servers = []
......@@ -809,6 +842,7 @@ def test_multiple_dist_dataloaders(
part_config,
num_servers > 1,
num_workers + 1,
use_graphbolt,
),
)
p.start()
......@@ -824,6 +858,7 @@ def test_multiple_dist_dataloaders(
orig_g,
num_dataloaders,
dataloader_type,
use_graphbolt,
),
)
p_client.start()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment