Unverified commit 763bd39f, authored by Rhett Ying and committed by GitHub

[DistGB] sample with graphbolt on homograph via DistDataLoader (#7098)

parent 870d8d02
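For orientation, the client-side path this test now exercises, condensed into one sketch from the calls that appear in the diff below. The ip_config file, the partition path, and the seed tensor are placeholders, and a real run also needs the server processes and partitioned data set up as in the test, so this is not runnable standalone:

# Sketch only: GraphBolt-backed distributed sampling through the APIs
# exercised by this test (placeholders noted inline).
import dgl
import torch as th

dgl.distributed.initialize("ip_config.txt")           # placeholder ip config
g = dgl.distributed.DistGraph(
    "test_sampling",
    part_config="part_dir/test_sampling.json",        # placeholder path
    use_graphbolt=True,                               # serve GraphBolt partitions
)
seeds = th.arange(32)
frontier = dgl.distributed.sample_neighbors(
    g, seeds, 10, use_graphbolt=True                  # new kwarg covered by this commit
)
block = dgl.to_block(frontier, seeds)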
@@ -22,10 +22,19 @@ from utils import generate_ip_config, reset_envs
 class NeighborSampler(object):
-    def __init__(self, g, fanouts, sample_neighbors):
+    def __init__(
+        self,
+        g,
+        fanouts,
+        sample_neighbors,
+        use_graphbolt=False,
+        return_eids=False,
+    ):
         self.g = g
         self.fanouts = fanouts
         self.sample_neighbors = sample_neighbors
+        self.use_graphbolt = use_graphbolt
+        self.return_eids = return_eids

     def sample_blocks(self, seeds):
         import torch as th
@@ -35,13 +44,16 @@ class NeighborSampler(object):
         for fanout in self.fanouts:
             # For each seed node, sample ``fanout`` neighbors.
             frontier = self.sample_neighbors(
-                self.g, seeds, fanout, replace=True
+                self.g, seeds, fanout, use_graphbolt=self.use_graphbolt
             )
             # Then we compact the frontier into a bipartite graph for
             # message passing.
             block = dgl.to_block(frontier, seeds)
             # Obtain the seed nodes for next layer.
             seeds = block.srcdata[dgl.NID]
+            if frontier.num_edges() > 0:
+                if not self.use_graphbolt or self.return_eids:
+                    block.edata[dgl.EID] = frontier.edata[dgl.EID]
             blocks.insert(0, block)
         return blocks
@@ -53,6 +65,7 @@ def start_server(
     part_config,
     disable_shared_mem,
     num_clients,
+    use_graphbolt=False,
 ):
     print("server: #clients=" + str(num_clients))
     g = DistGraphServer(
@@ -63,6 +76,7 @@ def start_server(
         part_config,
         disable_shared_mem=disable_shared_mem,
         graph_format=["csc", "coo"],
+        use_graphbolt=use_graphbolt,
     )
     g.start()
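The head of the DistGraphServer(...) call sits between the two hunks above. As a hedged reconstruction of what start_server roughly does after this change: the positional arguments are not shown in the diff, so their names and order below are assumptions based on DGL's DistGraphServer API, not a verbatim copy of the test.

# Sketch (assumed positional args): one server process serving
# GraphBolt-format partitions to the sampler clients.
from dgl.distributed import DistGraphServer

def start_server_sketch(
    rank, ip_config, num_servers, num_clients, part_config, use_graphbolt=False
):
    g = DistGraphServer(
        rank,                          # server id (assumed)
        ip_config,
        num_servers,                   # assumed positional arguments
        num_clients,
        part_config,
        disable_shared_mem=num_servers > 1,
        graph_format=["csc", "coo"],
        use_graphbolt=use_graphbolt,   # flag threaded through by this commit
    )
    g.start()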
@@ -75,30 +89,36 @@ def start_dist_dataloader(
     drop_last,
     orig_nid,
     orig_eid,
-    group_id=0,
+    use_graphbolt=False,
+    return_eids=False,
 ):
-    import dgl
-    import torch as th
-
-    os.environ["DGL_GROUP_ID"] = str(group_id)
     dgl.distributed.initialize(ip_config)
     gpb = None
-    disable_shared_mem = num_server > 0
+    disable_shared_mem = num_server > 1
     if disable_shared_mem:
         _, _, _, gpb, _, _, _ = load_partition(part_config, rank)
     num_nodes_to_sample = 202
     batch_size = 32
     train_nid = th.arange(num_nodes_to_sample)
-    dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config)
-
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(part_config, i)
+    dist_graph = DistGraph(
+        "test_sampling",
+        gpb=gpb,
+        part_config=part_config,
+        use_graphbolt=use_graphbolt,
+    )

     # Create sampler
     sampler = NeighborSampler(
-        dist_graph, [5, 10], dgl.distributed.sample_neighbors
+        dist_graph,
+        [5, 10],
+        dgl.distributed.sample_neighbors,
+        use_graphbolt=use_graphbolt,
+        return_eids=return_eids,
     )
+    # Enable sanity check in distributed sampling.
+    os.environ["DGL_DIST_DEBUG"] = "1"

     # We need to test creating DistDataLoader multiple times.
     for i in range(2):
         # Create DataLoader for constructing blocks
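The DistDataLoader construction inside the `for i in range(2)` loop falls outside the hunks shown here. A hedged sketch of what that construction typically looks like, based on DGL's dgl.distributed.DistDataLoader API rather than copied from this file, with sample_blocks plugged in as the collate function:

# Sketch, assuming DistDataLoader's standard keyword arguments: the loader
# batches seed node IDs and calls sampler.sample_blocks on each batch.
dataloader = dgl.distributed.DistDataLoader(
    dataset=train_nid.numpy(),
    batch_size=batch_size,
    collate_fn=sampler.sample_blocks,
    shuffle=False,
    drop_last=drop_last,
)
for blocks in dataloader:
    pass  # list of message-passing blocks, innermost layer first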
@@ -113,7 +133,7 @@ def start_dist_dataloader(
         groundtruth_g = CitationGraphDataset("cora")[0]
         max_nid = []

-        for epoch in range(2):
+        for _ in range(2):
             for idx, blocks in zip(
                 range(0, num_nodes_to_sample, batch_size), dataloader
             ):
@@ -129,6 +149,16 @@ def start_dist_dataloader(
                     src_nodes_id, dst_nodes_id
                 )
                 assert np.all(F.asnumpy(has_edges))
+                if use_graphbolt and not return_eids:
+                    continue
+                eids = orig_eid[block.edata[dgl.EID]]
+                expected_eids = groundtruth_g.edge_ids(
+                    src_nodes_id, dst_nodes_id
+                )
+                assert th.equal(
+                    eids, expected_eids
+                ), f"{eids} != {expected_eids}"

         if drop_last:
             assert (
                 np.max(max_nid)
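The new assertion relies on the orig_eid mapping returned by partition_graph(..., return_mapping=True): edge IDs stored on the sampled block are partition-local, and orig_eid maps them back to IDs in the original Cora graph, where they can be compared with edge_ids(). The same round trip can be illustrated on a single machine without any partitioning; a toy sketch, not part of the test:

# Toy illustration of the edge-ID check above: sampled subgraph edges
# carry dgl.EID pointing back into the full graph.
import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1, 2, 3]), th.tensor([1, 2, 3, 0])))
sg = dgl.sampling.sample_neighbors(g, th.tensor([1, 2]), fanout=-1)
src, dst = sg.edges()
eids = sg.edata[dgl.EID]                      # IDs into the original graph
assert th.equal(eids, g.edge_ids(src, dst))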
@@ -311,23 +341,22 @@ def check_neg_dataloader(g, num_server, num_workers):
     assert p.exitcode == 0


-@unittest.skip(reason="Skip due to glitch in CI")
-@pytest.mark.parametrize("num_server", [3])
+@pytest.mark.parametrize("num_server", [1])
 @pytest.mark.parametrize("num_workers", [0, 4])
-@pytest.mark.parametrize("drop_last", [True, False])
-@pytest.mark.parametrize("num_groups", [1])
-def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
+@pytest.mark.parametrize("drop_last", [False, True])
+@pytest.mark.parametrize("use_graphbolt", [False, True])
+@pytest.mark.parametrize("return_eids", [False, True])
+def test_dist_dataloader(
+    num_server, num_workers, drop_last, use_graphbolt, return_eids
+):
     reset_envs()
-    # No multiple partitions on single machine for
-    # multiple client groups in case of race condition.
-    if num_groups > 1:
-        num_server = 1
+    os.environ["DGL_DIST_MODE"] = "distributed"
+    os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
     with tempfile.TemporaryDirectory() as test_dir:
         ip_config = "ip_config.txt"
         generate_ip_config(ip_config, num_server, num_server)

         g = CitationGraphDataset("cora")[0]
-        print(g.idtype)
         num_parts = num_server
         num_hops = 1
@@ -339,6 +368,8 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
             num_hops=num_hops,
             part_method="metis",
             return_mapping=True,
+            use_graphbolt=use_graphbolt,
+            store_eids=return_eids,
         )

         part_config = os.path.join(test_dir, "test_sampling.json")
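Only the tail of the partition_graph(...) call is visible in the hunk above. A hedged reconstruction of the whole call, following dgl.distributed.partition_graph; the positional arguments and their order are assumptions where the diff does not show them:

# Sketch: partitioning Cora into GraphBolt-format partitions for this test.
# The graph name and output dir are assumed to match the part_config path below.
orig_nid, orig_eid = dgl.distributed.partition_graph(
    g,
    "test_sampling",
    num_parts,
    test_dir,
    num_hops=num_hops,
    part_method="metis",
    return_mapping=True,
    use_graphbolt=use_graphbolt,   # emit GraphBolt partitions
    store_eids=return_eids,        # keep original edge IDs on the partitions
)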
@@ -353,36 +384,33 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
                     part_config,
                     num_server > 1,
                     num_workers + 1,
+                    use_graphbolt,
                 ),
             )
             p.start()
             time.sleep(1)
             pserver_list.append(p)

-        os.environ["DGL_DIST_MODE"] = "distributed"
-        os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
         ptrainer_list = []
         num_trainers = 1
         for trainer_id in range(num_trainers):
-            for group_id in range(num_groups):
-                p = ctx.Process(
-                    target=start_dist_dataloader,
-                    args=(
-                        trainer_id,
-                        ip_config,
-                        part_config,
-                        num_server,
-                        drop_last,
-                        orig_nid,
-                        orig_eid,
-                        group_id,
-                    ),
-                )
-                p.start()
-                time.sleep(
-                    1
-                )  # avoid race condition when instantiating DistGraph
-                ptrainer_list.append(p)
+            p = ctx.Process(
+                target=start_dist_dataloader,
+                args=(
+                    trainer_id,
+                    ip_config,
+                    part_config,
+                    num_server,
+                    drop_last,
+                    orig_nid,
+                    orig_eid,
+                    use_graphbolt,
+                    return_eids,
+                ),
+            )
+            p.start()
+            time.sleep(1)  # avoid race condition when instantiating DistGraph
+            ptrainer_list.append(p)
         for p in ptrainer_list:
             p.join()
...