"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "0afc8e9e2f2a0a2ca707057fe6523bed98451bb6"
Unverified Commit ac49220c authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Convert a data block to dgl graphs (#6228)

parent 268f4569
...@@ -5,6 +5,9 @@ from typing import Dict, List, Tuple, Union ...@@ -5,6 +5,9 @@ from typing import Dict, List, Tuple, Union
import torch import torch
import dgl
from .base import etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
__all__ = ["MiniBatch"] __all__ = ["MiniBatch"]
...@@ -122,3 +125,90 @@ class MiniBatch: ...@@ -122,3 +125,90 @@ class MiniBatch:
Representation of compacted nodes corresponding to 'negative_tail', where Representation of compacted nodes corresponding to 'negative_tail', where
all node ids inside are compacted. all node ids inside are compacted.
""" """
def to_dgl_graphs(self):
    """Convert the sampled subgraphs of this mini-batch into DGL graphs.

    One DGL graph is built per entry of ``self.sampled_subgraphs``, in the
    same order. Whether the batch is heterogeneous is decided from the first
    subgraph: a ``dict`` of ``node_pairs`` (keyed by edge type string) yields
    ``dgl.heterograph`` objects, anything else yields ``dgl.graph`` objects.

    After construction the graphs are decorated as follows:

    * ``self.node_features`` and the first subgraph's
      ``reverse_row_node_ids`` (stored under ``dgl.NID``) are attached to
      the first (outermost) graph only.
    * ``self.edge_features[i]`` and ``sampled_subgraphs[i].reverse_edge_ids``
      (stored under ``dgl.EID``) are attached to the ``i``-th graph.

    Returns
    -------
    list or None
        The list of constructed DGL graphs, or ``None`` when the mini-batch
        holds no sampled subgraphs.
    """
    if not self.sampled_subgraphs:
        return None

    is_heterogeneous = isinstance(self.sampled_subgraphs[0].node_pairs, dict)

    # Build the graph structures. When a subgraph carries its row node ids,
    # size the graph from them explicitly: DGL otherwise infers the node
    # count from the largest id that appears in an edge, which undercounts
    # whenever trailing nodes have no sampled edge and then makes the node
    # data assignments below fail on a size mismatch.
    graphs = []
    for subgraph in self.sampled_subgraphs:
        row_ids = subgraph.reverse_row_node_ids
        if is_heterogeneous:
            data_dict = {
                etype_str_to_tuple(etype): pairs
                for etype, pairs in subgraph.node_pairs.items()
            }
            num_nodes_dict = None
            if row_ids:
                counts = {ntype: len(ids) for ntype, ids in row_ids.items()}
                endpoint_types = {
                    ntype
                    for src_type, _, dst_type in data_dict
                    for ntype in (src_type, dst_type)
                }
                # DGL requires a count for every node type used by an edge
                # type; fall back to inference if any of them is missing.
                if endpoint_types <= counts.keys():
                    num_nodes_dict = counts
            graphs.append(
                dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)
            )
        else:
            num_nodes = None if row_ids is None else len(row_ids)
            graphs.append(dgl.graph(subgraph.node_pairs, num_nodes=num_nodes))

    outer_graph = graphs[0]
    outer_row_ids = self.sampled_subgraphs[0].reverse_row_node_ids

    if is_heterogeneous:
        # Assign node features to the outermost layer's nodes.
        if self.node_features:
            for (
                node_type,
                feature_name,
            ), feature in self.node_features.items():
                outer_graph.nodes[node_type].data[feature_name] = feature
        # Assign edge features, layer by layer.
        if self.edge_features:
            for graph, edge_feature in zip(graphs, self.edge_features):
                for (
                    edge_type,
                    feature_name,
                ), feature in edge_feature.items():
                    graph.edges[etype_str_to_tuple(edge_type)].data[
                        feature_name
                    ] = feature
        # Assign reverse node ids to the outermost layer's nodes.
        if outer_row_ids:
            for node_type, reverse_ids in outer_row_ids.items():
                outer_graph.nodes[node_type].data[dgl.NID] = reverse_ids
        # Assign reverse edge ids, layer by layer.
        for graph, subgraph in zip(graphs, self.sampled_subgraphs):
            if subgraph.reverse_edge_ids:
                for edge_type, reverse_ids in subgraph.reverse_edge_ids.items():
                    graph.edges[etype_str_to_tuple(edge_type)].data[
                        dgl.EID
                    ] = reverse_ids
    else:
        # Assign node features to the outermost layer's nodes.
        if self.node_features:
            for feature_name, feature in self.node_features.items():
                outer_graph.ndata[feature_name] = feature
        # Assign edge features, layer by layer.
        if self.edge_features:
            for graph, edge_feature in zip(graphs, self.edge_features):
                for feature_name, feature in edge_feature.items():
                    graph.edata[feature_name] = feature
        # Assign reverse node ids to the outermost layer's nodes.
        if outer_row_ids is not None:
            outer_graph.ndata[dgl.NID] = outer_row_ids
        # Assign reverse edge ids, layer by layer.
        for graph, subgraph in zip(graphs, self.sampled_subgraphs):
            if subgraph.reverse_edge_ids is not None:
                graph.edata[dgl.EID] = subgraph.reverse_edge_ids

    return graphs
import dgl
import dgl.graphbolt as gb
import torch
def test_to_dgl_graphs_hetero():
    """A heterogeneous mini-batch converts to a DGL graph that carries the
    expected edges, reverse ids and node/edge features."""
    etype = "A:relation:B"
    src_ids = torch.tensor([0, 1, 2])
    dst_ids = torch.tensor([0, 4, 5])
    node_pairs = {etype: (src_ids, dst_ids)}
    reverse_column_node_ids = {"B": torch.tensor([10, 11, 12, 13, 14, 16])}
    reverse_row_node_ids = {
        "A": torch.tensor([5, 9, 7]),
        "B": torch.tensor([10, 11, 12, 13, 14, 16]),
    }
    reverse_edge_ids = {etype: torch.tensor([19, 20, 21])}
    node_features = {
        ("A", "x"): torch.randint(0, 10, (3,)),
        ("B", "y"): torch.randint(0, 10, (6,)),
    }
    edge_features = {(etype, "x"): torch.randint(0, 10, (3,))}

    sampled = gb.SampledSubgraphImpl(
        node_pairs=node_pairs,
        reverse_column_node_ids=reverse_column_node_ids,
        reverse_row_node_ids=reverse_row_node_ids,
        reverse_edge_ids=reverse_edge_ids,
    )
    minibatch = gb.MiniBatch(
        sampled_subgraphs=[sampled],
        node_features=node_features,
        edge_features=[edge_features],
    )
    g = minibatch.to_dgl_graphs()[0]

    # Graph structure mirrors the input node pairs.
    actual_src, actual_dst = g.edges()
    assert torch.equal(actual_src, src_ids)
    assert torch.equal(actual_dst, dst_ids)

    # Original node and edge ids are stored under the reserved DGL keys.
    for ntype, expected_ids in reverse_row_node_ids.items():
        assert torch.equal(g.ndata[dgl.NID][ntype], expected_ids)
    assert torch.equal(g.edata[dgl.EID], reverse_edge_ids[etype])

    # Node features land on the matching node type under their own name.
    for (ntype, feat_name), expected in node_features.items():
        assert torch.equal(g.nodes[ntype].data[feat_name], expected)

    # Edge features land on the canonical edge type.
    canonical_etype = gb.etype_str_to_tuple(etype)
    assert torch.equal(
        g.edges[canonical_etype].data["x"],
        edge_features[(etype, "x")],
    )
def test_to_dgl_graphs_homo():
    """A homogeneous mini-batch converts to a DGL graph that carries the
    expected edges, reverse ids and node/edge features."""
    src_ids = torch.tensor([0, 1, 2])
    dst_ids = torch.tensor([0, 4, 5])
    node_pairs = (src_ids, dst_ids)
    reverse_column_node_ids = torch.tensor([10, 11, 12])
    reverse_row_node_ids = torch.tensor([10, 11, 12, 13, 14, 16])
    reverse_edge_ids = torch.tensor([19, 20, 21])
    node_features = {"x": torch.randint(0, 10, (6,))}
    edge_features = {"x": torch.randint(0, 10, (3,))}

    sampled = gb.SampledSubgraphImpl(
        node_pairs=node_pairs,
        reverse_column_node_ids=reverse_column_node_ids,
        reverse_row_node_ids=reverse_row_node_ids,
        reverse_edge_ids=reverse_edge_ids,
    )
    minibatch = gb.MiniBatch(
        sampled_subgraphs=[sampled],
        node_features=node_features,
        edge_features=[edge_features],
    )
    g = minibatch.to_dgl_graphs()[0]

    # Graph structure mirrors the input node pairs.
    actual_src, actual_dst = g.edges()
    assert torch.equal(actual_src, src_ids)
    assert torch.equal(actual_dst, dst_ids)

    # Original node and edge ids are stored under the reserved DGL keys.
    assert torch.equal(g.ndata[dgl.NID], reverse_row_node_ids)
    assert torch.equal(g.edata[dgl.EID], reverse_edge_ids)

    # Features are attached under their own names.
    for feat_name, expected in node_features.items():
        assert torch.equal(g.ndata[feat_name], expected)
    for feat_name, expected in edge_features.items():
        assert torch.equal(g.edata[feat_name], expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment