Unverified Commit 8ab27b05 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable replacement sampling with GraphBolt API (#7202)

parent b0982feb
......@@ -130,13 +130,10 @@ def _sample_neighbors_graphbolt(
nodes = nodes.to(dtype=g.indices.dtype)
# 2. Perform sampling.
# [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now.
# [Rui][TODO] `prob` is not tested yet. Skip for now.
assert (
prob is None
), "DistGraphBolt does not support sampling with probability."
assert (
not replace
), "DistGraphBolt does not support sampling with replacement."
# Sanity checks.
assert isinstance(
......@@ -148,7 +145,9 @@ def _sample_neighbors_graphbolt(
assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout."
return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids)
subgraph = g._sample_neighbors(
nodes, fanout, replace=replace, return_eids=return_eids
)
# 3. Map local node IDs to global node IDs.
local_src = subgraph.indices
......
......@@ -80,6 +80,7 @@ def start_sample_client_shuffle(
use_graphbolt=False,
return_eids=False,
node_id_dtype=None,
replace=False,
):
os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None
......@@ -93,6 +94,7 @@ def start_sample_client_shuffle(
dist_graph,
torch.tensor([0, 10, 99, 66, 1024, 2008], dtype=node_id_dtype),
3,
replace=replace,
use_graphbolt=use_graphbolt,
)
assert sampled_graph.idtype == dist_graph.idtype
......@@ -102,6 +104,7 @@ def start_sample_client_shuffle(
dgl.ETYPE not in sampled_graph.edata
), "Etype should not be in homogeneous sampled graph."
src, dst = sampled_graph.edges()
sampled_in_degrees = sampled_graph.in_degrees(dst)
src = orig_nid[src]
dst = orig_nid[dst]
assert sampled_graph.num_nodes() == g.num_nodes()
......@@ -114,6 +117,14 @@ def start_sample_client_shuffle(
eids = g.edge_ids(src, dst)
eids1 = orig_eid[sampled_graph.edata[dgl.EID]]
assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids))
# Verify replace argument.
orig_in_degrees = g.in_degrees(dst)
if replace:
assert torch.all(
(sampled_in_degrees == 3) | (sampled_in_degrees == orig_in_degrees)
)
else:
assert torch.all(sampled_in_degrees <= 3)
def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None):
......@@ -408,6 +419,7 @@ def check_rpc_sampling_shuffle(
use_graphbolt=False,
return_eids=False,
node_id_dtype=None,
replace=False,
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -463,6 +475,7 @@ def check_rpc_sampling_shuffle(
use_graphbolt,
return_eids,
node_id_dtype,
replace,
),
)
p.start()
......@@ -483,6 +496,7 @@ def start_hetero_sample_client(
nodes,
use_graphbolt=False,
return_eids=False,
replace=False,
):
gpb = None
if disable_shared_mem:
......@@ -503,7 +517,7 @@ def start_hetero_sample_client(
# Enable sanity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_neighbors(
dist_graph, nodes, 3, use_graphbolt=use_graphbolt
dist_graph, nodes, 3, replace=replace, use_graphbolt=use_graphbolt
)
block = dgl.to_block(sampled_graph, nodes)
if not use_graphbolt or return_eids:
......@@ -573,7 +587,7 @@ def start_hetero_etype_sample_client(
def check_rpc_hetero_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
tmpdir, num_server, use_graphbolt=False, return_eids=False, replace=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -619,6 +633,7 @@ def check_rpc_hetero_sampling_shuffle(
nodes=nodes,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
replace=replace,
)
for p in pserver_list:
p.join()
......@@ -1247,8 +1262,9 @@ def check_rpc_bipartite_etype_sampling_shuffle(
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
@pytest.mark.parametrize("node_id_dtype", [torch.int64])
@pytest.mark.parametrize("replace", [False, True])
def test_rpc_sampling_shuffle(
num_server, use_graphbolt, return_eids, node_id_dtype
num_server, use_graphbolt, return_eids, node_id_dtype, replace
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
......@@ -1259,13 +1275,17 @@ def test_rpc_sampling_shuffle(
use_graphbolt=use_graphbolt,
return_eids=return_eids,
node_id_dtype=node_id_dtype,
replace=replace,
)
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt,", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
@pytest.mark.parametrize("replace", [False, True])
def test_rpc_hetero_sampling_shuffle(
num_server, use_graphbolt, return_eids, replace
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -1274,6 +1294,7 @@ def test_rpc_hetero_sampling_shuffle(num_server, use_graphbolt, return_eids):
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
replace=replace,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment