Unverified Commit 5eda8440 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix type_per_edge (#6317)

parent edcecdd0
......@@ -725,14 +725,12 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC matrix.
indptr, indices, _ = homo_g.adj_tensors("csc")
indptr, indices, edge_ids = homo_g.adj_tensors("csc")
ntype_count.insert(0, 0)
node_type_offset = torch.cumsum(torch.LongTensor(ntype_count), 0)
# Sort edge type according to columns as `csc` is used.
_, dst = homo_g.edges()
_, dst_indices = torch.sort(dst)
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][dst_indices]
# Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
......
......@@ -1065,25 +1065,43 @@ def test_from_dglgraph_homogeneous():
def test_from_dglgraph_heterogeneous():
dgl_g = dgl.heterograph(
{
("author", "writes", "paper"): ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5]),
("paper", "cites", "paper"): ([2, 3, 4, 5, 6], [1, 2, 3, 4, 5]),
("author", "writes", "paper"): (
[1, 2, 3, 4, 5, 2],
[1, 2, 3, 4, 5, 4],
),
("author", "affiliated_with", "institution"): (
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
),
("paper", "has_topic", "field"): ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5]),
("paper", "cites", "paper"): (
[2, 3, 4, 5, 6, 1],
[1, 2, 3, 4, 5, 4],
),
}
)
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=False)
assert gb_g.num_nodes == dgl_g.num_nodes()
assert gb_g.num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 6, 13]))
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 6, 12, 18, 25]))
assert torch.equal(
gb_g.type_per_edge, torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
gb_g.type_per_edge,
torch.tensor(
[3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2]
),
)
assert gb_g.metadata.node_type_to_id == {
"author": 0,
"paper": 1,
"field": 1,
"institution": 2,
"paper": 3,
}
assert gb_g.metadata.edge_type_to_id == {
"author:writes:paper": 0,
"paper:cites:paper": 1,
"author:affiliated_with:institution": 0,
"author:writes:paper": 1,
"paper:cites:paper": 2,
"paper:has_topic:field": 3,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment