Unverified Commit 6735a3ae authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable sample etype neighbors on heterograph (#7095)

parent 3ebdee77
......@@ -143,8 +143,6 @@ def _sample_neighbors_graphbolt(
if isinstance(fanout, int):
fanout = torch.LongTensor([fanout])
assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout."
return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids)
......@@ -237,15 +235,15 @@ def _sample_neighbors(use_graphbolt, *args, **kwargs):
return func(*args, **kwargs)
def _sample_etype_neighbors(
def _sample_etype_neighbors_dgl(
local_g,
partition_book,
seed_nodes,
etype_offset,
fan_out,
edge_dir,
prob,
replace,
edge_dir="in",
prob=None,
replace=False,
etype_offset=None,
etype_sorted=False,
):
"""Sample from local partition.
......@@ -255,6 +253,8 @@ def _sample_etype_neighbors(
The sampled results are stored in three vectors that store source nodes, destination nodes
and edge IDs.
"""
assert etype_offset is not None, "The etype offset is not provided."
local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype)
......@@ -278,6 +278,43 @@ def _sample_etype_neighbors(
return LocalSampledGraph(global_src, global_dst, global_eids)
def _sample_etype_neighbors(use_graphbolt, *args, **kwargs):
"""Wrapper for sampling etype neighbors.
The actual sampling function depends on whether to use GraphBolt.
Parameters
----------
use_graphbolt : bool
Whether to use GraphBolt for sampling.
args : list
The arguments for the sampling function.
kwargs : dict
The keyword arguments for the sampling function.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
func = (
_sample_neighbors_graphbolt
if use_graphbolt
else _sample_etype_neighbors_dgl
)
if use_graphbolt:
# GraphBolt does not require `etype_offset` and `etype_sorted`.
kwargs.pop("etype_offset", None)
kwargs.pop("etype_sorted", None)
return func(*args, **kwargs)
def _find_edges(local_g, partition_book, seed_edges):
"""Given an edge ID array, return the source
and destination node ID array ``s`` and ``d`` in the local partition.
......@@ -426,6 +463,7 @@ class SamplingRequestEtype(Request):
prob=None,
replace=False,
etype_sorted=True,
use_graphbolt=False,
):
self.seed_nodes = nodes
self.edge_dir = edge_dir
......@@ -433,6 +471,7 @@ class SamplingRequestEtype(Request):
self.replace = replace
self.fan_out = fan_out
self.etype_sorted = etype_sorted
self.use_graphbolt = use_graphbolt
def __setstate__(self, state):
(
......@@ -442,6 +481,7 @@ class SamplingRequestEtype(Request):
self.replace,
self.fan_out,
self.etype_sorted,
self.use_graphbolt,
) = state
def __getstate__(self):
......@@ -452,6 +492,7 @@ class SamplingRequestEtype(Request):
self.replace,
self.fan_out,
self.etype_sorted,
self.use_graphbolt,
)
def process_request(self, server_state):
......@@ -468,15 +509,16 @@ class SamplingRequestEtype(Request):
else:
probs = None
res = _sample_etype_neighbors(
self.use_graphbolt,
local_g,
partition_book,
self.seed_nodes,
etype_offset,
self.fan_out,
self.edge_dir,
probs,
self.replace,
self.etype_sorted,
edge_dir=self.edge_dir,
prob=probs,
replace=self.replace,
etype_offset=etype_offset,
etype_sorted=self.etype_sorted,
)
return SubgraphResponse(
res.global_src,
......@@ -772,6 +814,7 @@ def sample_etype_neighbors(
prob=None,
replace=False,
etype_sorted=True,
use_graphbolt=False,
):
"""Sample from the neighbors of the given nodes from a distributed graph.
......@@ -825,6 +868,8 @@ def sample_etype_neighbors(
neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional
Indicates whether etypes are sorted.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns
-------
......@@ -882,6 +927,7 @@ def sample_etype_neighbors(
prob=_prob,
replace=replace,
etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
)
def local_access(local_g, partition_book, local_nids):
......@@ -897,14 +943,15 @@ def sample_etype_neighbors(
for etype in g.canonical_etypes
]
return _sample_etype_neighbors(
use_graphbolt,
local_g,
partition_book,
local_nids,
etype_offset,
fanout,
edge_dir,
_prob,
replace,
edge_dir=edge_dir,
prob=_prob,
replace=replace,
etype_offset=etype_offset,
etype_sorted=etype_sorted,
)
......
......@@ -508,6 +508,8 @@ def start_hetero_etype_sample_client(
fanout=3,
nodes={"n3": [0, 10, 99, 66, 124, 208]},
etype_sorted=False,
use_graphbolt=False,
return_eids=False,
):
gpb = None
if disable_shared_mem:
......@@ -515,12 +517,14 @@ def start_hetero_etype_sample_client(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data
if dist_graph.local_partition is not None:
if (not use_graphbolt) and dist_graph.local_partition is not None:
# Check whether etypes are sorted in dist_graph
local_g = dist_graph.local_partition
local_nids = np.arange(local_g.num_nodes())
......@@ -533,10 +537,18 @@ def start_hetero_etype_sample_client(
if gpb is None:
gpb = dist_graph.get_partition_book()
try:
# Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_etype_neighbors(
dist_graph, nodes, fanout, etype_sorted=etype_sorted
dist_graph,
nodes,
fanout,
etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
)
block = dgl.to_block(sampled_graph, nodes)
if sampled_graph.num_edges() > 0:
if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e:
print(traceback.format_exc())
......@@ -689,7 +701,11 @@ def check_rpc_hetero_sampling_empty_shuffle(
def check_rpc_hetero_etype_sampling_shuffle(
tmpdir, num_server, graph_formats=None
tmpdir,
num_server,
graph_formats=None,
use_graphbolt=False,
return_eids=False,
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -706,6 +722,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
part_method="metis",
return_mapping=True,
graph_formats=graph_formats,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -713,7 +731,14 @@ def check_rpc_hetero_etype_sampling_shuffle(
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling", ["csc", "coo"]),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
......@@ -730,6 +755,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
fanout,
nodes={"n3": [0, 10, 99, 66, 124, 208]},
etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
for p in pserver_list:
......@@ -747,19 +774,26 @@ def check_rpc_hetero_etype_sampling_shuffle(
# These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
# Check the node Ids and edge Ids.
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst)
def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
def check_rpc_hetero_etype_sampling_empty_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero(dense=True, empty=True)
......@@ -774,6 +808,8 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -781,7 +817,14 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
......@@ -791,7 +834,13 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nids["n3"], "n3")
empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_hetero_etype_sample_client(
0, tmpdir, num_server > 1, fanout, nodes={"n3": empty_nids}
0,
tmpdir,
num_server > 1,
fanout,
nodes={"n3": empty_nids},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
for p in pserver_list:
......@@ -848,7 +897,13 @@ def start_bipartite_sample_client(
def start_bipartite_etype_sample_client(
rank, tmpdir, disable_shared_mem, fanout=3, nodes={}
rank,
tmpdir,
disable_shared_mem,
fanout=3,
nodes={},
use_graphbolt=False,
return_eids=False,
):
gpb = None
if disable_shared_mem:
......@@ -856,11 +911,13 @@ def start_bipartite_etype_sample_client(
tmpdir / "test_sampling.json", rank
)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb)
dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data
if dist_graph.local_partition is not None:
if not use_graphbolt and dist_graph.local_partition is not None:
# Check whether etypes are sorted in dist_graph
local_g = dist_graph.local_partition
local_nids = np.arange(local_g.num_nodes())
......@@ -872,9 +929,12 @@ def start_bipartite_etype_sample_client(
if gpb is None:
gpb = dist_graph.get_partition_book()
sampled_graph = sample_etype_neighbors(dist_graph, nodes, fanout)
sampled_graph = sample_etype_neighbors(
dist_graph, nodes, fanout, use_graphbolt=use_graphbolt
)
block = dgl.to_block(sampled_graph, nodes)
if sampled_graph.num_edges() > 0:
if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
dgl.distributed.exit_client()
return block, gpb
......@@ -1019,7 +1079,9 @@ def check_rpc_bipartite_sampling_shuffle(
assert np.all(F.asnumpy(orig_dst1) == orig_dst)
def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
def check_rpc_bipartite_etype_sampling_empty(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_etype_neighbors() which yields empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -1035,6 +1097,8 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -1042,7 +1106,14 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
......@@ -1050,8 +1121,13 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_bipartite_etype_sample_client(
0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]}
block, _ = start_bipartite_etype_sample_client(
0,
tmpdir,
num_server > 1,
nodes={"game": empty_nids, "user": [1]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
......@@ -1064,7 +1140,9 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
assert len(block.etypes) == len(g.etypes)
def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
def check_rpc_bipartite_etype_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_etype_neighbors() which yields non-empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
......@@ -1080,6 +1158,8 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
pserver_list = []
......@@ -1087,7 +1167,14 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"),
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
)
p.start()
time.sleep(1)
......@@ -1097,7 +1184,13 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nid_map["game"], "game")
nids = F.nonzero_1d(deg > 0)
block, gpb = start_bipartite_etype_sample_client(
0, tmpdir, num_server > 1, fanout, nodes={"game": nids, "user": [0]}
0,
tmpdir,
num_server > 1,
fanout,
nodes={"game": nids, "user": [0]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
print("Done sampling")
for p in pserver_list:
......@@ -1110,13 +1203,18 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
# Check the node Ids and edge Ids.
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst)
......@@ -1173,7 +1271,7 @@ def test_rpc_hetero_sampling_empty_shuffle(
@pytest.mark.parametrize(
"graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
)
def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
def test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
......@@ -1183,12 +1281,33 @@ def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
@pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_etype_sampling_empty_shuffle(num_server):
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_etype_sampling_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=True,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_etype_sampling_empty_shuffle(
Path(tmpdirname), num_server
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
......@@ -1219,19 +1338,37 @@ def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_empty_shuffle(num_server):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_etype_sampling_empty(Path(tmpdirname), num_server)
check_rpc_bipartite_etype_sampling_empty(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_shuffle(num_server):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_etype_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_bipartite_etype_sampling_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
def check_standalone_sampling(tmpdir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment