"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5afbcce176cd4e8ec08f43ee9fae2d6562edf54c"
Unverified Commit 6735a3ae authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[DistGB] enable sample etype neighbors on heterograph (#7095)

parent 3ebdee77
...@@ -143,8 +143,6 @@ def _sample_neighbors_graphbolt( ...@@ -143,8 +143,6 @@ def _sample_neighbors_graphbolt(
if isinstance(fanout, int): if isinstance(fanout, int):
fanout = torch.LongTensor([fanout]) fanout = torch.LongTensor([fanout])
assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout." assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout."
# [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout."
return_eids = g.edge_attributes is not None and EID in g.edge_attributes return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids) subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids)
...@@ -237,15 +235,15 @@ def _sample_neighbors(use_graphbolt, *args, **kwargs): ...@@ -237,15 +235,15 @@ def _sample_neighbors(use_graphbolt, *args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
def _sample_etype_neighbors( def _sample_etype_neighbors_dgl(
local_g, local_g,
partition_book, partition_book,
seed_nodes, seed_nodes,
etype_offset,
fan_out, fan_out,
edge_dir, edge_dir="in",
prob, prob=None,
replace, replace=False,
etype_offset=None,
etype_sorted=False, etype_sorted=False,
): ):
"""Sample from local partition. """Sample from local partition.
...@@ -255,6 +253,8 @@ def _sample_etype_neighbors( ...@@ -255,6 +253,8 @@ def _sample_etype_neighbors(
The sampled results are stored in three vectors that store source nodes, destination nodes The sampled results are stored in three vectors that store source nodes, destination nodes
and edge IDs. and edge IDs.
""" """
assert etype_offset is not None, "The etype offset is not provided."
local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid) local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype) local_ids = F.astype(local_ids, local_g.idtype)
...@@ -278,6 +278,43 @@ def _sample_etype_neighbors( ...@@ -278,6 +278,43 @@ def _sample_etype_neighbors(
return LocalSampledGraph(global_src, global_dst, global_eids) return LocalSampledGraph(global_src, global_dst, global_eids)
def _sample_etype_neighbors(use_graphbolt, *args, **kwargs):
"""Wrapper for sampling etype neighbors.
The actual sampling function depends on whether to use GraphBolt.
Parameters
----------
use_graphbolt : bool
Whether to use GraphBolt for sampling.
args : list
The arguments for the sampling function.
kwargs : dict
The keyword arguments for the sampling function.
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
tensor
The edge ID array.
tensor
The edge type ID array.
"""
func = (
_sample_neighbors_graphbolt
if use_graphbolt
else _sample_etype_neighbors_dgl
)
if use_graphbolt:
# GraphBolt does not require `etype_offset` and `etype_sorted`.
kwargs.pop("etype_offset", None)
kwargs.pop("etype_sorted", None)
return func(*args, **kwargs)
def _find_edges(local_g, partition_book, seed_edges): def _find_edges(local_g, partition_book, seed_edges):
"""Given an edge ID array, return the source """Given an edge ID array, return the source
and destination node ID array ``s`` and ``d`` in the local partition. and destination node ID array ``s`` and ``d`` in the local partition.
...@@ -426,6 +463,7 @@ class SamplingRequestEtype(Request): ...@@ -426,6 +463,7 @@ class SamplingRequestEtype(Request):
prob=None, prob=None,
replace=False, replace=False,
etype_sorted=True, etype_sorted=True,
use_graphbolt=False,
): ):
self.seed_nodes = nodes self.seed_nodes = nodes
self.edge_dir = edge_dir self.edge_dir = edge_dir
...@@ -433,6 +471,7 @@ class SamplingRequestEtype(Request): ...@@ -433,6 +471,7 @@ class SamplingRequestEtype(Request):
self.replace = replace self.replace = replace
self.fan_out = fan_out self.fan_out = fan_out
self.etype_sorted = etype_sorted self.etype_sorted = etype_sorted
self.use_graphbolt = use_graphbolt
def __setstate__(self, state): def __setstate__(self, state):
( (
...@@ -442,6 +481,7 @@ class SamplingRequestEtype(Request): ...@@ -442,6 +481,7 @@ class SamplingRequestEtype(Request):
self.replace, self.replace,
self.fan_out, self.fan_out,
self.etype_sorted, self.etype_sorted,
self.use_graphbolt,
) = state ) = state
def __getstate__(self): def __getstate__(self):
...@@ -452,6 +492,7 @@ class SamplingRequestEtype(Request): ...@@ -452,6 +492,7 @@ class SamplingRequestEtype(Request):
self.replace, self.replace,
self.fan_out, self.fan_out,
self.etype_sorted, self.etype_sorted,
self.use_graphbolt,
) )
def process_request(self, server_state): def process_request(self, server_state):
...@@ -468,15 +509,16 @@ class SamplingRequestEtype(Request): ...@@ -468,15 +509,16 @@ class SamplingRequestEtype(Request):
else: else:
probs = None probs = None
res = _sample_etype_neighbors( res = _sample_etype_neighbors(
self.use_graphbolt,
local_g, local_g,
partition_book, partition_book,
self.seed_nodes, self.seed_nodes,
etype_offset,
self.fan_out, self.fan_out,
self.edge_dir, edge_dir=self.edge_dir,
probs, prob=probs,
self.replace, replace=self.replace,
self.etype_sorted, etype_offset=etype_offset,
etype_sorted=self.etype_sorted,
) )
return SubgraphResponse( return SubgraphResponse(
res.global_src, res.global_src,
...@@ -772,6 +814,7 @@ def sample_etype_neighbors( ...@@ -772,6 +814,7 @@ def sample_etype_neighbors(
prob=None, prob=None,
replace=False, replace=False,
etype_sorted=True, etype_sorted=True,
use_graphbolt=False,
): ):
"""Sample from the neighbors of the given nodes from a distributed graph. """Sample from the neighbors of the given nodes from a distributed graph.
...@@ -825,6 +868,8 @@ def sample_etype_neighbors( ...@@ -825,6 +868,8 @@ def sample_etype_neighbors(
neighbors are sampled. If fanout == -1, all neighbors are collected. neighbors are sampled. If fanout == -1, all neighbors are collected.
etype_sorted : bool, optional etype_sorted : bool, optional
Indicates whether etypes are sorted. Indicates whether etypes are sorted.
use_graphbolt : bool, optional
Whether to use GraphBolt for sampling.
Returns Returns
------- -------
...@@ -882,6 +927,7 @@ def sample_etype_neighbors( ...@@ -882,6 +927,7 @@ def sample_etype_neighbors(
prob=_prob, prob=_prob,
replace=replace, replace=replace,
etype_sorted=etype_sorted, etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
) )
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
...@@ -897,14 +943,15 @@ def sample_etype_neighbors( ...@@ -897,14 +943,15 @@ def sample_etype_neighbors(
for etype in g.canonical_etypes for etype in g.canonical_etypes
] ]
return _sample_etype_neighbors( return _sample_etype_neighbors(
use_graphbolt,
local_g, local_g,
partition_book, partition_book,
local_nids, local_nids,
etype_offset,
fanout, fanout,
edge_dir, edge_dir=edge_dir,
_prob, prob=_prob,
replace, replace=replace,
etype_offset=etype_offset,
etype_sorted=etype_sorted, etype_sorted=etype_sorted,
) )
......
...@@ -508,6 +508,8 @@ def start_hetero_etype_sample_client( ...@@ -508,6 +508,8 @@ def start_hetero_etype_sample_client(
fanout=3, fanout=3,
nodes={"n3": [0, 10, 99, 66, 124, 208]}, nodes={"n3": [0, 10, 99, 66, 124, 208]},
etype_sorted=False, etype_sorted=False,
use_graphbolt=False,
return_eids=False,
): ):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
...@@ -515,12 +517,14 @@ def start_hetero_etype_sample_client( ...@@ -515,12 +517,14 @@ def start_hetero_etype_sample_client(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["n1"].data assert "feat" in dist_graph.nodes["n1"].data
assert "feat" not in dist_graph.nodes["n2"].data assert "feat" not in dist_graph.nodes["n2"].data
assert "feat" not in dist_graph.nodes["n3"].data assert "feat" not in dist_graph.nodes["n3"].data
if dist_graph.local_partition is not None: if (not use_graphbolt) and dist_graph.local_partition is not None:
# Check whether etypes are sorted in dist_graph # Check whether etypes are sorted in dist_graph
local_g = dist_graph.local_partition local_g = dist_graph.local_partition
local_nids = np.arange(local_g.num_nodes()) local_nids = np.arange(local_g.num_nodes())
...@@ -533,11 +537,19 @@ def start_hetero_etype_sample_client( ...@@ -533,11 +537,19 @@ def start_hetero_etype_sample_client(
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
try: try:
# Enable santity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"
sampled_graph = sample_etype_neighbors( sampled_graph = sample_etype_neighbors(
dist_graph, nodes, fanout, etype_sorted=etype_sorted dist_graph,
nodes,
fanout,
etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
) )
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] if sampled_graph.num_edges() > 0:
if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
block = None block = None
...@@ -689,7 +701,11 @@ def check_rpc_hetero_sampling_empty_shuffle( ...@@ -689,7 +701,11 @@ def check_rpc_hetero_sampling_empty_shuffle(
def check_rpc_hetero_etype_sampling_shuffle( def check_rpc_hetero_etype_sampling_shuffle(
tmpdir, num_server, graph_formats=None tmpdir,
num_server,
graph_formats=None,
use_graphbolt=False,
return_eids=False,
): ):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -706,6 +722,8 @@ def check_rpc_hetero_etype_sampling_shuffle( ...@@ -706,6 +722,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
graph_formats=graph_formats, graph_formats=graph_formats,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -713,7 +731,14 @@ def check_rpc_hetero_etype_sampling_shuffle( ...@@ -713,7 +731,14 @@ def check_rpc_hetero_etype_sampling_shuffle(
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling", ["csc", "coo"]), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
...@@ -730,6 +755,8 @@ def check_rpc_hetero_etype_sampling_shuffle( ...@@ -730,6 +755,8 @@ def check_rpc_hetero_etype_sampling_shuffle(
fanout, fanout,
nodes={"n3": [0, 10, 99, 66, 124, 208]}, nodes={"n3": [0, 10, 99, 66, 124, 208]},
etype_sorted=etype_sorted, etype_sorted=etype_sorted,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling") print("Done sampling")
for p in pserver_list: for p in pserver_list:
...@@ -747,19 +774,26 @@ def check_rpc_hetero_etype_sampling_shuffle( ...@@ -747,19 +774,26 @@ def check_rpc_hetero_etype_sampling_shuffle(
# These are global Ids after shuffling. # These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src) shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst) shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src)) orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst)) orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid)) assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
# Check the node Ids and edge Ids. # Check the node Ids and edge Ids.
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype) orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
assert np.all(F.asnumpy(orig_src1) == orig_src) assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst) assert np.all(F.asnumpy(orig_dst1) == orig_dst)
def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server): def check_rpc_hetero_etype_sampling_empty_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero(dense=True, empty=True) g = create_random_hetero(dense=True, empty=True)
...@@ -774,6 +808,8 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server): ...@@ -774,6 +808,8 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -781,7 +817,14 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server): ...@@ -781,7 +817,14 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
...@@ -791,7 +834,13 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server): ...@@ -791,7 +834,13 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nids["n3"], "n3") deg = get_degrees(g, orig_nids["n3"], "n3")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_hetero_etype_sample_client( block, gpb = start_hetero_etype_sample_client(
0, tmpdir, num_server > 1, fanout, nodes={"n3": empty_nids} 0,
tmpdir,
num_server > 1,
fanout,
nodes={"n3": empty_nids},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling") print("Done sampling")
for p in pserver_list: for p in pserver_list:
...@@ -848,7 +897,13 @@ def start_bipartite_sample_client( ...@@ -848,7 +897,13 @@ def start_bipartite_sample_client(
def start_bipartite_etype_sample_client( def start_bipartite_etype_sample_client(
rank, tmpdir, disable_shared_mem, fanout=3, nodes={} rank,
tmpdir,
disable_shared_mem,
fanout=3,
nodes={},
use_graphbolt=False,
return_eids=False,
): ):
gpb = None gpb = None
if disable_shared_mem: if disable_shared_mem:
...@@ -856,11 +911,13 @@ def start_bipartite_etype_sample_client( ...@@ -856,11 +911,13 @@ def start_bipartite_etype_sample_client(
tmpdir / "test_sampling.json", rank tmpdir / "test_sampling.json", rank
) )
dgl.distributed.initialize("rpc_ip_config.txt") dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("test_sampling", gpb=gpb) dist_graph = DistGraph(
"test_sampling", gpb=gpb, use_graphbolt=use_graphbolt
)
assert "feat" in dist_graph.nodes["user"].data assert "feat" in dist_graph.nodes["user"].data
assert "feat" in dist_graph.nodes["game"].data assert "feat" in dist_graph.nodes["game"].data
if dist_graph.local_partition is not None: if not use_graphbolt and dist_graph.local_partition is not None:
# Check whether etypes are sorted in dist_graph # Check whether etypes are sorted in dist_graph
local_g = dist_graph.local_partition local_g = dist_graph.local_partition
local_nids = np.arange(local_g.num_nodes()) local_nids = np.arange(local_g.num_nodes())
...@@ -872,10 +929,13 @@ def start_bipartite_etype_sample_client( ...@@ -872,10 +929,13 @@ def start_bipartite_etype_sample_client(
if gpb is None: if gpb is None:
gpb = dist_graph.get_partition_book() gpb = dist_graph.get_partition_book()
sampled_graph = sample_etype_neighbors(dist_graph, nodes, fanout) sampled_graph = sample_etype_neighbors(
dist_graph, nodes, fanout, use_graphbolt=use_graphbolt
)
block = dgl.to_block(sampled_graph, nodes) block = dgl.to_block(sampled_graph, nodes)
if sampled_graph.num_edges() > 0: if sampled_graph.num_edges() > 0:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID] if not use_graphbolt or return_eids:
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
dgl.distributed.exit_client() dgl.distributed.exit_client()
return block, gpb return block, gpb
...@@ -1019,7 +1079,9 @@ def check_rpc_bipartite_sampling_shuffle( ...@@ -1019,7 +1079,9 @@ def check_rpc_bipartite_sampling_shuffle(
assert np.all(F.asnumpy(orig_dst1) == orig_dst) assert np.all(F.asnumpy(orig_dst1) == orig_dst)
def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server): def check_rpc_bipartite_etype_sampling_empty(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_etype_neighbors() which yields empty sample results""" """sample on bipartite via sample_etype_neighbors() which yields empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -1035,6 +1097,8 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server): ...@@ -1035,6 +1097,8 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -1042,7 +1106,14 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server): ...@@ -1042,7 +1106,14 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
...@@ -1050,8 +1121,13 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server): ...@@ -1050,8 +1121,13 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
deg = get_degrees(g, orig_nids["game"], "game") deg = get_degrees(g, orig_nids["game"], "game")
empty_nids = F.nonzero_1d(deg == 0) empty_nids = F.nonzero_1d(deg == 0)
block, gpb = start_bipartite_etype_sample_client( block, _ = start_bipartite_etype_sample_client(
0, tmpdir, num_server > 1, nodes={"game": empty_nids, "user": [1]} 0,
tmpdir,
num_server > 1,
nodes={"game": empty_nids, "user": [1]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling") print("Done sampling")
...@@ -1064,7 +1140,9 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server): ...@@ -1064,7 +1140,9 @@ def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
assert len(block.etypes) == len(g.etypes) assert len(block.etypes) == len(g.etypes)
def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): def check_rpc_bipartite_etype_sampling_shuffle(
tmpdir, num_server, use_graphbolt=False, return_eids=False
):
"""sample on bipartite via sample_etype_neighbors() which yields non-empty sample results""" """sample on bipartite via sample_etype_neighbors() which yields non-empty sample results"""
generate_ip_config("rpc_ip_config.txt", num_server, num_server) generate_ip_config("rpc_ip_config.txt", num_server, num_server)
...@@ -1080,6 +1158,8 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): ...@@ -1080,6 +1158,8 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
num_hops=num_hops, num_hops=num_hops,
part_method="metis", part_method="metis",
return_mapping=True, return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
) )
pserver_list = [] pserver_list = []
...@@ -1087,7 +1167,14 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): ...@@ -1087,7 +1167,14 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
for i in range(num_server): for i in range(num_server):
p = ctx.Process( p = ctx.Process(
target=start_server, target=start_server,
args=(i, tmpdir, num_server > 1, "test_sampling"), args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
use_graphbolt,
),
) )
p.start() p.start()
time.sleep(1) time.sleep(1)
...@@ -1097,7 +1184,13 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): ...@@ -1097,7 +1184,13 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
deg = get_degrees(g, orig_nid_map["game"], "game") deg = get_degrees(g, orig_nid_map["game"], "game")
nids = F.nonzero_1d(deg > 0) nids = F.nonzero_1d(deg > 0)
block, gpb = start_bipartite_etype_sample_client( block, gpb = start_bipartite_etype_sample_client(
0, tmpdir, num_server > 1, fanout, nodes={"game": nids, "user": [0]} 0,
tmpdir,
num_server > 1,
fanout,
nodes={"game": nids, "user": [0]},
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
print("Done sampling") print("Done sampling")
for p in pserver_list: for p in pserver_list:
...@@ -1110,13 +1203,18 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): ...@@ -1110,13 +1203,18 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
# These are global Ids after shuffling. # These are global Ids after shuffling.
shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src) shuffled_src = F.gather_row(block.srcnodes[src_type].data[dgl.NID], src)
shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst) shuffled_dst = F.gather_row(block.dstnodes[dst_type].data[dgl.NID], dst)
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src)) orig_src = F.asnumpy(F.gather_row(orig_nid_map[src_type], shuffled_src))
orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst)) orig_dst = F.asnumpy(F.gather_row(orig_nid_map[dst_type], shuffled_dst))
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid)) assert np.all(
F.asnumpy(g.has_edges_between(orig_src, orig_dst, etype=etype))
)
if use_graphbolt and not return_eids:
continue
# Check the node Ids and edge Ids. # Check the node Ids and edge Ids.
shuffled_eid = block.edges[etype].data[dgl.EID]
orig_eid = F.asnumpy(F.gather_row(orig_eid_map[c_etype], shuffled_eid))
orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype) orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
assert np.all(F.asnumpy(orig_src1) == orig_src) assert np.all(F.asnumpy(orig_src1) == orig_src)
assert np.all(F.asnumpy(orig_dst1) == orig_dst) assert np.all(F.asnumpy(orig_dst1) == orig_dst)
...@@ -1173,7 +1271,7 @@ def test_rpc_hetero_sampling_empty_shuffle( ...@@ -1173,7 +1271,7 @@ def test_rpc_hetero_sampling_empty_shuffle(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]] "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]]
) )
def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats): def test_rpc_hetero_etype_sampling_shuffle_dgl(num_server, graph_formats):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
...@@ -1183,12 +1281,33 @@ def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats): ...@@ -1183,12 +1281,33 @@ def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats):
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_hetero_etype_sampling_empty_shuffle(num_server): @pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_shuffle_graphbolt(num_server, return_eids):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_etype_sampling_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=True,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_hetero_etype_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_hetero_etype_sampling_empty_shuffle( check_rpc_hetero_etype_sampling_empty_shuffle(
Path(tmpdirname), num_server Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
) )
...@@ -1219,19 +1338,37 @@ def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids): ...@@ -1219,19 +1338,37 @@ def test_rpc_bipartite_sampling_shuffle(num_server, use_graphbolt, return_eids):
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_empty_shuffle(num_server): @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_empty_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_etype_sampling_empty(Path(tmpdirname), num_server) check_rpc_bipartite_etype_sampling_empty(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
@pytest.mark.parametrize("num_server", [1]) @pytest.mark.parametrize("num_server", [1])
def test_rpc_bipartite_etype_sampling_shuffle(num_server): @pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_rpc_bipartite_etype_sampling_shuffle(
num_server, use_graphbolt, return_eids
):
reset_envs() reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed" os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_bipartite_etype_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_bipartite_etype_sampling_shuffle(
Path(tmpdirname),
num_server,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)
def check_standalone_sampling(tmpdir): def check_standalone_sampling(tmpdir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment