"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "dd07b19e27b737d844f62a8107228591f8d7bca8"
Unverified commit a93b6fec, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Modify random hetero graph generation in tests. (#6500)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 1d86b796
import os
import dgl
import dgl.graphbolt as gb
import numpy as np
......@@ -43,24 +44,45 @@ def get_metadata(num_ntypes, num_etypes):
return gb.GraphMetadata(ntypes, etypes)
def get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes):
    """Build node-type sizes and edge-type triples for a random hetero graph.

    Nodes are split evenly across the node types ``n0``, ``n1``, ...; any
    remainder is added to ``n0``. Edge types are ``(src, f"e{k}", dst)``
    triples that walk the (source, destination) node-type pairs in row-major
    order and wrap around once every pair has been used.

    Parameters
    ----------
    num_nodes : int
        Total number of nodes to distribute over the node types.
    num_ntypes : int
        Number of node types. Must be positive.
    num_etypes : int
        Number of edge types to generate. Non-positive values yield no edge
        types.

    Returns
    -------
    tuple[dict[str, int], list[tuple[str, str, str]]]
        Mapping from node type name to node count, and the list of edge-type
        triples.

    Raises
    ------
    ValueError
        If ``num_ntypes`` is not positive.
    """
    if num_ntypes <= 0:
        # Zero would divide by zero below; a negative value used to make the
        # original pair-walking loop spin forever.
        raise ValueError(
            f"num_ntypes must be a positive integer, got {num_ntypes}."
        )

    nodes_per_type, leftover = divmod(num_nodes, num_ntypes)
    ntypes = {f"n{i}": nodes_per_type for i in range(num_ntypes)}
    # The first node type absorbs whatever does not divide evenly.
    ntypes["n0"] += leftover

    # Pair index p maps to (p // num_ntypes, p % num_ntypes), i.e. the same
    # row-major order as two nested loops over the node types; the modulo
    # wraps around when more edge types than pairs are requested.
    num_pairs = num_ntypes * num_ntypes
    etypes = []
    for count in range(num_etypes):
        n1, n2 = divmod(count % num_pairs, num_ntypes)
        etypes.append((f"n{n1}", f"e{count}", f"n{n2}"))
    return ntypes, etypes
def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
    """Generate a random heterogeneous graph and return its CSC components.

    Node and edge types come from ``get_ntypes_and_etypes``. For every edge
    type, random source and destination node ids are drawn within the size of
    the corresponding node types. The edges are assembled with
    ``dgl.heterograph`` and passed through ``gb.from_dglgraph``.

    Parameters
    ----------
    num_nodes : int
        Total number of nodes, split across the node types.
    num_edges : int
        Target number of edges, split evenly across the edge types; the first
        edge type additionally receives the remainder.
    num_ntypes : int
        Number of node types.
    num_etypes : int
        Number of edge types.

    Returns
    -------
    tuple
        ``(csc_indptr, indices, node_type_offset, type_per_edge, metadata)``
        read from the graph returned by ``gb.from_dglgraph``.

    Notes
    -----
    Edge types whose source or destination node type has no nodes are
    skipped, so the resulting graph can hold fewer than ``num_edges`` edges.
    """
    ntypes, etypes = get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes)
    edges = {}
    for step, etype in enumerate(etypes):
        src_ntype, _, dst_ntype = etype
        # Even share per edge type; only the first type takes the remainder.
        num_e = num_edges // num_etypes + (
            0 if step != 0 else num_edges % num_etypes
        )
        # torch.randint needs a non-empty [0, high) range, so an empty node
        # type cannot serve as an endpoint.
        if ntypes[src_ntype] == 0 or ntypes[dst_ntype] == 0:
            continue
        src = torch.randint(0, ntypes[src_ntype], (num_e,))
        dst = torch.randint(0, ntypes[dst_ntype], (num_e,))
        edges[etype] = (src, dst)

    gb_g = gb.from_dglgraph(dgl.heterograph(edges, ntypes))
    return (
        gb_g.csc_indptr,
        gb_g.indices,
        gb_g.node_type_offset,
        gb_g.type_per_edge,
        gb_g.metadata,
    )
def random_homo_graphbolt_graph(
......
......@@ -192,3 +192,66 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
assert len(list(neighbor_dp)) == 5
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Random_Hetero_Graph(labor):
    """Sample from a random hetero graph and check that all ids are >= 0.

    Builds a small random heterogeneous graph, runs two layers of either
    ``gb.NeighborSampler`` or ``gb.LayerNeighborSampler`` (chosen by
    ``labor``) with replacement, and asserts that every node id reported in
    the sampled subgraphs is non-negative.
    """
    num_nodes = 5
    num_edges = 9
    num_ntypes = 3
    num_etypes = 3
    (
        csc_indptr,
        indices,
        node_type_offset,
        type_per_edge,
        metadata,
    ) = gb_test_utils.random_hetero_graph(
        num_nodes, num_edges, num_ntypes, num_etypes
    )
    edge_attributes = {
        "A1": torch.randn(num_edges),
        "A2": torch.randn(num_edges),
    }
    graph = gb.from_csc(
        csc_indptr,
        indices,
        node_type_offset,
        type_per_edge,
        edge_attributes,
        metadata,
    )
    itemset = gb.ItemSetDict(
        {
            "n2": gb.ItemSet(torch.tensor([0]), names="seed_nodes"),
            "n1": gb.ItemSet(torch.tensor([1]), names="seed_nodes"),
        }
    )

    item_sampler = gb.ItemSampler(itemset, batch_size=2)
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler

    sampler_dp = Sampler(item_sampler, graph, fanouts, replace=True)

    # Check the boolean mask directly instead of comparing it against a
    # float tensor of ones, which relied on cross-dtype torch.equal.
    for data in sampler_dp:
        for sampledsubgraph in data.sampled_subgraphs:
            for value in sampledsubgraph.node_pairs.values():
                # Each entry holds two id tensors, indexed as [0] and [1].
                assert torch.all(value[0] >= 0)
                assert torch.all(value[1] >= 0)
            for value in sampledsubgraph.original_column_node_ids.values():
                assert torch.all(value >= 0)
            for value in sampledsubgraph.original_row_node_ids.values():
                assert torch.all(value >= 0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment