Unverified commit c51516a8, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] Fix type_per_edge when converting DGLGraph to CSCSamplingGraph (#6314)

parent 53153b42
......@@ -729,7 +729,10 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
ntype_count.insert(0, 0)
node_type_offset = torch.cumsum(torch.LongTensor(ntype_count), 0)
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE]
# Reorder the edge types by destination node (column), because the graph is stored in `csc` format.
_, dst = homo_g.edges()
_, dst_indices = torch.sort(dst)
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][dst_indices]
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
......
......@@ -1048,12 +1048,12 @@ def test_multiprocessing_with_shared_memory():
)
def test_from_dglgraph_homogeneous():
dgl_g = dgl.rand_graph(1000, 10 * 1000)
gb_g = gb.from_dglgraph(dgl_g)
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=True)
assert gb_g.num_nodes == dgl_g.num_nodes()
assert gb_g.num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 1000]))
assert torch.all(gb_g.type_per_edge == 0)
assert gb_g.type_per_edge is None
assert gb_g.metadata.node_type_to_id == {"_N": 0}
assert gb_g.metadata.edge_type_to_id == {"_N:_E:_N": 0}
......@@ -1063,46 +1063,27 @@ def test_from_dglgraph_homogeneous():
reason="Graph on GPU is not supported yet.",
)
def test_from_dglgraph_heterogeneous():
def create_random_hetero():
num_nodes = {"n1": 1000, "n2": 1010, "n3": 1020}
etypes = [
("n1", "r12", "n2"),
("n2", "r21", "n1"),
("n1", "r13", "n3"),
("n2", "r23", "n3"),
]
edges = {}
for etype in etypes:
src_ntype, _, dst_ntype = etype
arr = spsp.random(
num_nodes[src_ntype],
num_nodes[dst_ntype],
density=0.001,
format="coo",
random_state=100,
dgl_g = dgl.heterograph(
{
("author", "writes", "paper"): ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5]),
("paper", "cites", "paper"): ([2, 3, 4, 5, 6], [1, 2, 3, 4, 5]),
}
)
edges[etype] = (arr.row, arr.col)
return dgl.heterograph(edges, num_nodes)
dgl_g = create_random_hetero()
gb_g = gb.from_dglgraph(dgl_g)
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=False)
assert gb_g.num_nodes == dgl_g.num_nodes()
assert gb_g.num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 6, 13]))
assert torch.equal(
gb_g.node_type_offset, torch.tensor([0, 1000, 2010, 3030])
gb_g.type_per_edge, torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
)
assert torch.all(gb_g.type_per_edge[:-1] <= gb_g.type_per_edge[1:])
assert gb_g.metadata.node_type_to_id == {
"n1": 0,
"n2": 1,
"n3": 2,
"author": 0,
"paper": 1,
}
assert gb_g.metadata.edge_type_to_id == {
"n1:r12:n2": 0,
"n1:r13:n3": 1,
"n2:r21:n1": 2,
"n2:r23:n3": 3,
"author:writes:paper": 0,
"paper:cites:paper": 1,
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment