Unverified commit 32be4a8e, authored by LastWhisper and committed by GitHub
Browse files

[GraphBolt] Store `ORIGINAL_EDGE_ID` in CSCSamplingGraph's `edge_attributes`. (#6399)

parent 241760a5
......@@ -8,7 +8,7 @@ from typing import Dict, Optional, Union
import torch
from ...base import ETYPE
from ...base import EID, ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
......@@ -810,13 +810,16 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
# Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
# Assign edge attributes according to the original eids mapping.
edge_attributes = {ORIGINAL_EDGE_ID: homo_g.edata[EID][edge_ids]}
return CSCSamplingGraph(
torch.ops.graphbolt.from_csc(
indptr,
indices,
node_type_offset,
type_per_edge,
None,
edge_attributes,
),
metadata,
)
......@@ -1296,6 +1296,15 @@ def test_from_dglgraph_homogeneous():
dgl_g = dgl.rand_graph(1000, 10 * 1000)
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=True)
# Get the COO representation of the CSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
rows = gb_g.indices
columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
original_edge_ids = gb_g.edge_attributes[gb.ORIGINAL_EDGE_ID]
assert torch.all(dgl_g.edges()[0][original_edge_ids] == rows)
assert torch.all(dgl_g.edges()[1][original_edge_ids] == columns)
assert gb_g.total_num_nodes == dgl_g.num_nodes()
assert gb_g.total_num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 1000]))
......@@ -1328,6 +1337,40 @@ def test_from_dglgraph_heterogeneous():
)
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=False)
# `reverse_node_id` maps each node ID in the CSCSamplingGraph to the
# corresponding per-type node ID in the heterogeneous DGLGraph.
num_ntypes = gb_g.node_type_offset[1:] - gb_g.node_type_offset[:-1]
reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes])
# Get the COO representation of the CSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
rows = reverse_node_id[gb_g.indices]
columns = reverse_node_id[
torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
]
# Check that the order of etypes in DGLGraph is the same as in CSCSamplingGraph.
assert (
# Since the etypes in CSCSamplingGraph are of the form "srctype:etype:dsttype",
# we need to split each string and take the middle part.
list(
map(
lambda ss: ss.split(":")[1],
gb_g.metadata.edge_type_to_id.keys(),
)
)
== dgl_g.etypes
)
# Use ORIGINAL_EDGE_ID to check if the edge mapping is correct.
for edge_idx in range(gb_g.total_num_edges):
hetero_graph_idx = gb_g.type_per_edge[edge_idx]
original_edge_id = gb_g.edge_attributes[gb.ORIGINAL_EDGE_ID][edge_idx]
edge_type = dgl_g.etypes[hetero_graph_idx]
dgl_edge_pairs = dgl_g.edges(etype=edge_type)
assert dgl_edge_pairs[0][original_edge_id] == rows[edge_idx]
assert dgl_edge_pairs[1][original_edge_id] == columns[edge_idx]
assert gb_g.total_num_nodes == dgl_g.num_nodes()
assert gb_g.total_num_edges == dgl_g.num_edges()
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 6, 12, 18, 25]))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment