Unverified Commit 8cf5ad84 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Add `to_pyg_data` for MiniBatch (#7076)

parent 0504bc2c
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import dgl import dgl
from dgl.utils import recursive_apply from dgl.utils import recursive_apply
from .base import etype_str_to_tuple from .base import etype_str_to_tuple, expand_indptr
from .internal import get_attributes from .internal import get_attributes
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
...@@ -474,6 +474,50 @@ class MiniBatch: ...@@ -474,6 +474,50 @@ class MiniBatch:
else: else:
return None return None
def to_pyg_data(self):
"""Construct a PyG Data from `MiniBatch`. This function only supports
node classification task on a homogeneous graph and the number of
features cannot be more than one.
"""
from torch_geometric.data import Data
if self.sampled_subgraphs is None:
edge_index = None
else:
col_nodes = []
row_nodes = []
for subgraph in self.sampled_subgraphs:
if subgraph is None:
continue
sampled_csc = subgraph.sampled_csc
indptr = sampled_csc.indptr
indices = sampled_csc.indices
expanded_indptr = expand_indptr(
indptr, dtype=indices.dtype, output_size=len(indices)
)
col_nodes.append(expanded_indptr)
row_nodes.append(indices)
col_nodes = torch.cat(col_nodes)
row_nodes = torch.cat(row_nodes)
edge_index = torch.unique(
torch.stack((col_nodes, row_nodes)), dim=1
)
if self.node_features is None:
node_features = None
else:
assert (
len(self.node_features) == 1
), "`to_pyg_data` only supports single feature homogeneous graph."
node_features = next(iter(self.node_features.values()))
pyg_data = Data(
x=node_features,
edge_index=edge_index,
y=self.labels,
)
return pyg_data
def to(self, device: torch.device): # pylint: disable=invalid-name def to(self, device: torch.device): # pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection.""" """Copy `MiniBatch` to the specified device using reflection."""
......
...@@ -859,3 +859,86 @@ def test_dgl_link_predication_hetero(mode): ...@@ -859,3 +859,86 @@ def test_dgl_link_predication_hetero(mode):
minibatch.negative_node_pairs[etype][1], minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype], minibatch.compacted_negative_dsts[etype],
) )
def test_to_pyg_data():
test_subgraph_a = gb.SampledSubgraphImpl(
sampled_csc=gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
indices=torch.tensor([0, 1, 2, 2, 1, 2]),
),
original_column_node_ids=torch.tensor([10, 11, 12, 13]),
original_row_node_ids=torch.tensor([19, 20, 21, 22, 25, 30]),
original_edge_ids=torch.tensor([10, 11, 12, 13]),
)
test_subgraph_b = gb.SampledSubgraphImpl(
sampled_csc=gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3]),
indices=torch.tensor([1, 2, 0]),
),
original_row_node_ids=torch.tensor([10, 11, 12]),
original_edge_ids=torch.tensor([10, 15, 17]),
original_column_node_ids=torch.tensor([10, 11]),
)
expected_edge_index = torch.tensor(
[[0, 0, 1, 1, 1, 2, 2, 3], [0, 1, 0, 1, 2, 1, 2, 2]]
)
expected_node_features = torch.tensor([[1], [2], [3], [4]])
expected_labels = torch.tensor([0, 1])
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a, test_subgraph_b],
node_features={"feat": expected_node_features},
labels=expected_labels,
)
pyg_data = test_minibatch.to_pyg_data()
pyg_data.validate()
assert torch.equal(pyg_data.edge_index, expected_edge_index)
assert torch.equal(pyg_data.x, expected_node_features)
assert torch.equal(pyg_data.y, expected_labels)
# Test with sampled_csc as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=None,
node_features={"feat": expected_node_features},
labels=expected_labels,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.edge_index is None, "Edge index should be none."
# Test with node_features as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a],
node_features=None,
labels=expected_labels,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.x is None, "Node features should be None."
# Test with labels as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a],
node_features={"feat": expected_node_features},
labels=None,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.y is None, "Labels should be None."
# Test with multiple features.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a],
node_features={
"feat": expected_node_features,
"extra_feat": torch.tensor([[3], [4]]),
},
labels=expected_labels,
)
try:
pyg_data = test_minibatch.to_pyg_data()
assert (
pyg_data.x is None,
), "Multiple features case should raise an error."
except AssertionError as e:
assert (
str(e)
== "`to_pyg_data` only supports single feature homogeneous graph."
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment