Unverified Commit 5eda8440 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix type_per_edge (#6317)

parent edcecdd0
...@@ -725,14 +725,12 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph: ...@@ -725,14 +725,12 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
metadata = GraphMetadata(node_type_to_id, edge_type_to_id) metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC matrix. # Obtain CSC matrix.
indptr, indices, _ = homo_g.adj_tensors("csc") indptr, indices, edge_ids = homo_g.adj_tensors("csc")
ntype_count.insert(0, 0) ntype_count.insert(0, 0)
node_type_offset = torch.cumsum(torch.LongTensor(ntype_count), 0) node_type_offset = torch.cumsum(torch.LongTensor(ntype_count), 0)
# Sort edge type according to columns as `csc` is used. # Assign edge type according to the order of CSC matrix.
_, dst = homo_g.edges() type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
_, dst_indices = torch.sort(dst)
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][dst_indices]
return CSCSamplingGraph( return CSCSamplingGraph(
torch.ops.graphbolt.from_csc( torch.ops.graphbolt.from_csc(
......
...@@ -1065,25 +1065,43 @@ def test_from_dglgraph_homogeneous(): ...@@ -1065,25 +1065,43 @@ def test_from_dglgraph_homogeneous():
def test_from_dglgraph_heterogeneous(): def test_from_dglgraph_heterogeneous():
dgl_g = dgl.heterograph( dgl_g = dgl.heterograph(
{ {
("author", "writes", "paper"): ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5]), ("author", "writes", "paper"): (
("paper", "cites", "paper"): ([2, 3, 4, 5, 6], [1, 2, 3, 4, 5]), [1, 2, 3, 4, 5, 2],
[1, 2, 3, 4, 5, 4],
),
("author", "affiliated_with", "institution"): (
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
),
("paper", "has_topic", "field"): ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5]),
("paper", "cites", "paper"): (
[2, 3, 4, 5, 6, 1],
[1, 2, 3, 4, 5, 4],
),
} }
) )
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=False) gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=False)
assert gb_g.num_nodes == dgl_g.num_nodes() assert gb_g.num_nodes == dgl_g.num_nodes()
assert gb_g.num_edges == dgl_g.num_edges() assert gb_g.num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 6, 13])) assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 6, 12, 18, 25]))
assert torch.equal( assert torch.equal(
gb_g.type_per_edge, torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) gb_g.type_per_edge,
torch.tensor(
[3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2]
),
) )
assert gb_g.metadata.node_type_to_id == { assert gb_g.metadata.node_type_to_id == {
"author": 0, "author": 0,
"paper": 1, "field": 1,
"institution": 2,
"paper": 3,
} }
assert gb_g.metadata.edge_type_to_id == { assert gb_g.metadata.edge_type_to_id == {
"author:writes:paper": 0, "author:affiliated_with:institution": 0,
"paper:cites:paper": 1, "author:writes:paper": 1,
"paper:cites:paper": 2,
"paper:has_topic:field": 3,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment