Unverified Commit 8ab27b05 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable replacement sampling with GraphBolt API (#7202)

parent b0982feb
...@@ -130,13 +130,10 @@ def _sample_neighbors_graphbolt( ...@@ -130,13 +130,10 @@ def _sample_neighbors_graphbolt(
nodes = nodes.to(dtype=g.indices.dtype) nodes = nodes.to(dtype=g.indices.dtype)
# 2. Perform sampling. # 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now. # [Rui][TODO] `prob` is not tested yet. Skip for now.
assert ( assert (
prob is None prob is None
), "DistGraphBolt does not support sampling with probability." ), "DistGraphBolt does not support sampling with probability."
assert (
not replace
), "DistGraphBolt does not support sampling with replacement."
# Sanity checks. # Sanity checks.
assert isinstance( assert isinstance(
...@@ -148,7 +145,9 @@ def _sample_neighbors_graphbolt( ...@@ -148,7 +145,9 @@ def _sample_neighbors_graphbolt(
assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout." assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout."
return_eids = g.edge_attributes is not None and EID in g.edge_attributes return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids) subgraph = g._sample_neighbors(
nodes, fanout, replace=replace, return_eids=return_eids
)
# 3. Map local node IDs to global node IDs. # 3. Map local node IDs to global node IDs.
local_src = subgraph.indices local_src = subgraph.indices
......
...@@ -80,6 +80,7 @@ def start_sample_client_shuffle( ...@@ -80,6 +80,7 @@ def start_sample_client_shuffle(
use_graphbolt=False, use_graphbolt=False,
return_eids=False, return_eids=False,
node_id_dtype=None, node_id_dtype=None,
replace=False,
): ):
os.environ["DGL_GROUP_ID"] = str(group_id) os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None gpb = None
...@@ -93,6 +94,7 @@ def start_sample_client_shuffle( ...@@ -93,6 +94,7 @@ def start_sample_client_shuffle(
dist_graph, dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype), torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype),
3, 3,
replace=replace,
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
) )
assert sampled_graph.idtype == dist_graph.idtype assert sampled_graph.idtype == dist_graph.idtype
...@@ -102,6 +104,7 @@ def start_sample_client_shuffle( ...@@ -102,6 +104,7 @@ def start_sample_client_shuffle(
dgl.ETYPE not in sampled_graph.edata dgl.ETYPE not in sampled_graph.edata
), "Etype should not be in homogeneous sampled graph." ), "Etype should not be in homogeneous sampled graph."
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
sampled_in_degrees = sampled_graph.in_degrees(dst)
src = orig_nid[src] src = orig_nid[src]
dst = orig_nid[dst] dst = orig_nid[dst]
assert sampled_graph.num_nodes() == g.num_nodes() assert sampled_graph.num_nodes() == g.num_nodes()
...@@ -114,6 +117,14 @@ def start_sample_client_shuffle( ...@@ -114,6 +117,14 @@ def start_sample_client_shuffle(
eids = g.edge_ids(src, dst) eids = g.edge_ids(src, dst)
eids1 = orig_eid[sampled_graph.edata[dgl.EID]] eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids)) assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
# Verify replace argument.
orig_in_degrees = g.in_degrees(dst)
if replace:
assert torch.all(
(sampled_in_degrees == 3) | (sampled_in_degrees == orig_in_degrees)
)
else:
assert torch.all(sampled_in_degrees <= 3)
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None): def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
...@@ -408,6 +419,7 @@ def check_rpc_sampling_shuffle( ...@@ -408,6 +419,7 @@ def check_rpc_sampling_shuffle(
use_graphbolt=False, use_graphbolt=False,
return_eids=False, return_eids=False,
node_id_dtype=None, node_id_dtype=None,
replace=False,
): ):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -463,6 +475,7 @@ def check_rpc_sampling_shuffle( ...@@ -463,6 +475,7 @@ def check_rpc_sampling_shuffle(
use_graphbolt, use_graphbolt,
return_eids, return_eids,
node_id_dtype, node_id_dtype,
replace,
), ),
) )
p.start() p.start()
...@@ -483,6 +496,7 @@ def start_hetero_sample_client( ...@@ -483,6 +496,7 @@ def start_hetero_sample_client(
nodes, nodes,
use_graphbolt=False, use_graphbolt=False,
return_eids=False, return_eids=False,
replace=False,
): ):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
...@@ -503,7 +517,7 @@ def start_hetero_sample_client( ...@@ -503,7 +517,7 @@ def start_hetero_sample_client(
# Enable santity check in distributed sampling. # Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1" os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_neighbors( sampled_graph = sample_neighbors(
dist_graph, nodes, 3, use_graphbolt=use_graphbolt dist_graph, nodes, 3, replace=replace, use_graphbolt=use_graphbolt
) )
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
if not use_graphbolt or return_eids: if not use_graphbolt or return_eids:
...@@ -573,7 +587,7 @@ def start_hetero_etype_sample_client( ...@@ -573,7 +587,7 @@ def start_hetero_etype_sample_client(
def check_rpc_hetero_sampling_shuffle( def check_rpc_hetero_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False tmpdir, num_server, use_graphbolt=False, return_eids=False, replace=False
): ):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -619,6 +633,7 @@ def check_rpc_hetero_sampling_shuffle( ...@@ -619,6 +633,7 @@ def check_rpc_hetero_sampling_shuffle(
nodes=nodes, nodes=nodes,
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
return_eids=return_eids, return_eids=return_eids,
replace=replace,
) )
for p in pserver_list: for p in pserver_list:
p.join() p.join()
...@@ -1247,8 +1262,9 @@ def check_rpc_bipartite_etype_sampling_shuffle( ...@@ -1247,8 +1262,9 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@pytest.mark.parametrize("use_graphbolt", [False, True]) @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True]) @pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("node_id_dtype", [torch.int64]) @pytest.mark.parametrize("node_id_dtype", [torch.int64])
@pytest.mark.parametrize("replace", [False, True])
def test_rpc_sampling_shuffle( def test_rpc_sampling_shuffle(
num_server, use_graphbolt, return_eids, node_id_dtype num_server, use_graphbolt, return_eids, node_id_dtype, replace
): ):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
...@@ -1259,13 +1275,17 @@ def test_rpc_sampling_shuffle( ...@@ -1259,13 +1275,17 @@ def test_rpc_sampling_shuffle(
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
return_eids=return_eids, return_eids=return_eids,
node_id_dtype=node_id_dtype, node_id_dtype=node_id_dtype,
replace=replace,
) )
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt,", [False, True]) @pytest.mark.parametrize("use_graphbolt,", [False, True])
@pytest.mark.parametrize("return_eids", [False, True]) @pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids): @pytest.mark.parametrize("replace", [False, True])
def test_rpc_hetero_sampling_shuffle(
num_server, use_graphbolt, return_eids, replace
):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
...@@ -1274,6 +1294,7 @@ def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids): ...@@ -1274,6 +1294,7 @@ def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
num_server, num_server,
use_graphbolt=use_graphbolt, use_graphbolt=use_graphbolt,
return_eids=return_eids, return_eids=return_eids,
replace=replace,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment