"...text-generation-inference.git" did not exist on "1869ee2f5767db42aafa20712736aea99ded2bfc"
Unverified Commit 3ebdee77 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] sample with graphbolt on homograph via DistNodeDataLoader (#7108)

parent 7f7967b3
......@@ -192,9 +192,11 @@ class NeighborSampler(BlockSampler):
output_device=self.output_device,
exclude_edges=exclude_eids,
)
eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes)
block.edata[EID] = eid
# If sampled from graphbolt-backed DistGraph, `EID` may not be in
# the block.
if EID in frontier.edata.keys():
block.edata[EID] = frontier.edata[EID]
seed_nodes = block.srcdata[NID]
blocks.insert(0, block)
......
......@@ -1406,7 +1406,12 @@ class DistGraph:
)
else:
frontier = graph_services.sample_neighbors(
self, seed_nodes, fanout, replace=replace, prob=prob
self,
seed_nodes,
fanout,
replace=replace,
prob=prob,
use_graphbolt=self._use_graphbolt,
)
return frontier
......
......@@ -342,7 +342,7 @@ def check_neg_dataloader(g, num_server, num_workers):
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
......@@ -429,6 +429,8 @@ def start_node_dataloader(
orig_nid,
orig_eid,
groundtruth_g,
use_graphbolt=False,
return_eids=False,
):
dgl.distributed.initialize(ip_config)
gpb = None
......@@ -437,7 +439,12 @@ def start_node_dataloader(
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_nodes_to_sample = 202
batch_size = 32
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config)
dist_graph = DistGraph(
"test_sampling",
gpb=gpb,
part_config=part_config,
use_graphbolt=use_graphbolt,
)
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1:
......@@ -459,6 +466,9 @@ def start_node_dataloader(
]
) # test int for hetero
# Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
# We need to test creating DistDataLoader multiple times.
for i in range(2):
# Create DataLoader for constructing blocks
......@@ -472,7 +482,7 @@ def start_node_dataloader(
num_workers=num_workers,
)
for epoch in range(2):
for _ in range(2):
for idx, (_, _, blocks) in zip(
range(0, num_nodes_to_sample, batch_size), dataloader
):
......@@ -487,6 +497,16 @@ def start_node_dataloader(
src_nodes_id, dst_nodes_id, etype=etype
)
assert np.all(F.asnumpy(has_edges))
if use_graphbolt and not return_eids:
continue
eids = orig_eid[etype][block.edata[dgl.EID]]
expected_eids = groundtruth_g.edge_ids(
src_nodes_id, dst_nodes_id
)
assert th.equal(
eids, expected_eids
), f"{eids} != {expected_eids}"
del dataloader
# this is needed since there's two test here in one process
dgl.distributed.exit_client()
......@@ -509,7 +529,7 @@ def start_edge_dataloader(
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_edges_to_sample = 202
batch_size = 32
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config)
dist_graph = DistGraph("test_sampling", gpb=gpb, part_config=part_config)
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1:
......@@ -561,7 +581,14 @@ def start_edge_dataloader(
dgl.distributed.exit_client()
def check_dataloader(g, num_server, num_workers, dataloader_type):
def check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=False,
return_eids=False,
):
with tempfile.TemporaryDirectory() as test_dir:
ip_config = "ip_config.txt"
generate_ip_config(ip_config, num_server, num_server)
......@@ -576,6 +603,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
part_config = os.path.join(test_dir, "test_sampling.json")
if not isinstance(orig_nid, dict):
......@@ -594,6 +623,7 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
part_config,
num_server > 1,
num_workers + 1,
use_graphbolt,
),
)
p.start()
......@@ -615,6 +645,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
orig_nid,
orig_eid,
g,
use_graphbolt,
return_eids,
),
)
p.start()
......@@ -663,14 +695,35 @@ def create_random_hetero():
return g
@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader(num_server, num_workers, dataloader_type):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_dataloader_homograph(
num_server, num_workers, dataloader_type, use_graphbolt, return_eids
):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs()
g = CitationGraphDataset("cora")[0]
check_dataloader(g, num_server, num_workers, dataloader_type)
check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader_heterograph(num_server, num_workers, dataloader_type):
reset_envs()
g = create_random_hetero()
check_dataloader(g, num_server, num_workers, dataloader_type)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment