Unverified Commit 3ebdee77 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] sample with graphbolt on homograph via DistNodeDataLoader (#7108)

parent 7f7967b3
...@@ -192,9 +192,11 @@ class NeighborSampler(BlockSampler): ...@@ -192,9 +192,11 @@ class NeighborSampler(BlockSampler):
output_device=self.output_device, output_device=self.output_device,
exclude_edges=exclude_eids, exclude_edges=exclude_eids,
) )
eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes) block = to_block(frontier, seed_nodes)
block.edata[EID] = eid # If sampled from graphbolt-backed DistGraph, `EID` may not be in
# the block.
if EID in frontier.edata.keys():
block.edata[EID] = frontier.edata[EID]
seed_nodes = block.srcdata[NID] seed_nodes = block.srcdata[NID]
blocks.insert(0, block) blocks.insert(0, block)
......
...@@ -1406,7 +1406,12 @@ class DistGraph: ...@@ -1406,7 +1406,12 @@ class DistGraph:
) )
else: else:
frontier = graph_services.sample_neighbors( frontier = graph_services.sample_neighbors(
self, seed_nodes, fanout, replace=replace, prob=prob self,
seed_nodes,
fanout,
replace=replace,
prob=prob,
use_graphbolt=self._use_graphbolt,
) )
return frontier return frontier
......
...@@ -342,7 +342,7 @@ def check_neg_dataloader(g, num_server, num_workers): ...@@ -342,7 +342,7 @@ def check_neg_dataloader(g, num_server, num_workers):
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 4]) @pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("drop_last", [False, True]) @pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("use_graphbolt", [False, True]) @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True]) @pytest.mark.parametrize("return_eids", [False, True])
...@@ -429,6 +429,8 @@ def start_node_dataloader( ...@@ -429,6 +429,8 @@ def start_node_dataloader(
orig_nid, orig_nid,
orig_eid, orig_eid,
groundtruth_g, groundtruth_g,
use_graphbolt=False,
return_eids=False,
): ):
dgl.distributed.initialize(ip_config) dgl.distributed.initialize(ip_config)
gpb = None gpb = None
...@@ -437,7 +439,12 @@ def start_node_dataloader( ...@@ -437,7 +439,12 @@ def start_node_dataloader(
_, _, _, gpb, _, _, _ = load_partition(part_config, rank) _, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_nodes_to_sample = 202 num_nodes_to_sample = 202
batch_size = 32 batch_size = 32
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config) dist_graph = DistGraph(
"test_sampling",
gpb=gpb,
part_config=part_config,
use_graphbolt=use_graphbolt,
)
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes) assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes) assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1: if len(dist_graph.etypes) == 1:
...@@ -459,6 +466,9 @@ def start_node_dataloader( ...@@ -459,6 +466,9 @@ def start_node_dataloader(
] ]
) # test int for hetero ) # test int for hetero
# Enable sanity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
# We need to test creating DistDataLoader multiple times. # We need to test creating DistDataLoader multiple times.
for i in range(2): for i in range(2):
# Create DataLoader for constructing blocks # Create DataLoader for constructing blocks
...@@ -472,7 +482,7 @@ def start_node_dataloader( ...@@ -472,7 +482,7 @@ def start_node_dataloader(
num_workers=num_workers, num_workers=num_workers,
) )
for epoch in range(2): for _ in range(2):
for idx, (_, _, blocks) in zip( for idx, (_, _, blocks) in zip(
range(0, num_nodes_to_sample, batch_size), dataloader range(0, num_nodes_to_sample, batch_size), dataloader
): ):
...@@ -487,6 +497,16 @@ def start_node_dataloader( ...@@ -487,6 +497,16 @@ def start_node_dataloader(
src_nodes_id, dst_nodes_id, etype=etype src_nodes_id, dst_nodes_id, etype=etype
) )
assert np.all(F.asnumpy(has_edges)) assert np.all(F.asnumpy(has_edges))
if use_graphbolt and not return_eids:
continue
eids = orig_eid[etype][block.edata[dgl.EID]]
expected_eids = groundtruth_g.edge_ids(
src_nodes_id, dst_nodes_id
)
assert th.equal(
eids, expected_eids
), f"{eids} != {expected_eids}"
del dataloader del dataloader
# this is needed since there are two tests here in one process # this is needed since there are two tests here in one process
dgl.distributed.exit_client() dgl.distributed.exit_client()
...@@ -509,7 +529,7 @@ def start_edge_dataloader( ...@@ -509,7 +529,7 @@ def start_edge_dataloader(
_, _, _, gpb, _, _, _ = load_partition(part_config, rank) _, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_edges_to_sample = 202 num_edges_to_sample = 202
batch_size = 32 batch_size = 32
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config) dist_graph = DistGraph("test_sampling", gpb=gpb, part_config=part_config)
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes) assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes) assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1: if len(dist_graph.etypes) == 1:
...@@ -561,7 +581,14 @@ def start_edge_dataloader( ...@@ -561,7 +581,14 @@ def start_edge_dataloader(
dgl.distributed.exit_client() dgl.distributed.exit_client()
def check_dataloader(g, num_server, num_workers, dataloader_type): def check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=False,
return_eids=False,
):
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
ip_config = "ip_config.txt" ip_config = "ip_config.txt"
generate_ip_config(ip_config, num_server, num_server) generate_ip_config(ip_config, num_server, num_server)
...@@ -576,6 +603,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type): ...@@ -576,6 +603,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
part_config = os.path.join(test_dir, "test_sampling.json") part_config = os.path.join(test_dir, "test_sampling.json")
if not isinstance(orig_nid, dict): if not isinstance(orig_nid, dict):
...@@ -594,6 +623,7 @@ def check_dataloader(g, num_server, num_workers, dataloader_type): ...@@ -594,6 +623,7 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
part_config, part_config,
num_server > 1, num_server > 1,
num_workers + 1, num_workers + 1,
use_graphbolt,
), ),
) )
p.start() p.start()
...@@ -615,6 +645,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type): ...@@ -615,6 +645,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
orig_nid, orig_nid,
orig_eid, orig_eid,
g, g,
use_graphbolt,
return_eids,
), ),
) )
p.start() p.start()
...@@ -663,14 +695,35 @@ def create_random_hetero(): ...@@ -663,14 +695,35 @@ def create_random_hetero():
return g return g
@unittest.skip(reason="Skip due to glitch in CI") @pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_server", [3]) @pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"]) @pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader(num_server, num_workers, dataloader_type): @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_dataloader_homograph(
num_server, num_workers, dataloader_type, use_graphbolt, return_eids
):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs() reset_envs()
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
check_dataloader(g, num_server, num_workers, dataloader_type) check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader_heterograph(num_server, num_workers, dataloader_type):
reset_envs()
g = create_random_hetero() g = create_random_hetero()
check_dataloader(g, num_server, num_workers, dataloader_type) check_dataloader(g, num_server, num_workers, dataloader_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment