Unverified commit e3bf1c0e, authored by peizhou001 and committed by GitHub
Browse files

[Graphbolt] Convert minibatch to dgl block instead of graph (#6318)

parent eb40ed55
......@@ -126,10 +126,10 @@ class MiniBatch:
all node ids inside are compacted.
"""
def to_dgl_graphs(self):
"""Transforming a data graph into DGL graphs necessitates constructing a
def to_dgl_blocks(self):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing a
graphical structure and assigning features to the nodes and edges within
the graphs.
the blocks.
"""
if not self.sampled_subgraphs:
return None
......@@ -138,77 +138,90 @@ class MiniBatch:
self.sampled_subgraphs[0].node_pairs, Dict
)
if is_heterogeneous:
graphs = []
blocks = []
for subgraph in self.sampled_subgraphs:
graphs.append(
dgl.heterograph(
{
reverse_row_node_ids = subgraph.reverse_row_node_ids
assert (
reverse_row_node_ids is not None
), "Missing `reverse_row_node_ids` in sampled subgraph."
reverse_column_node_ids = subgraph.reverse_column_node_ids
assert (
reverse_column_node_ids is not None
), "Missing `reverse_column_node_ids` in sampled subgraph."
if is_heterogeneous:
node_pairs = {
etype_str_to_tuple(etype): v
for etype, v in subgraph.node_pairs.items()
}
num_src_nodes = {
ntype: nodes.size(0)
for ntype, nodes in reverse_row_node_ids.items()
}
num_dst_nodes = {
ntype: nodes.size(0)
for ntype, nodes in reverse_column_node_ids.items()
}
else:
node_pairs = subgraph.node_pairs
num_src_nodes = reverse_row_node_ids.size(0)
num_dst_nodes = reverse_column_node_ids.size(0)
blocks.append(
dgl.create_block(
node_pairs,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
)
)
else:
graphs = [
dgl.graph(subgraph.node_pairs)
for subgraph in self.sampled_subgraphs
]
if is_heterogeneous:
# Assign node features to the outermost layer's nodes.
# Assign node features to the outermost layer's source nodes.
if self.node_features:
for (
node_type,
feature_name,
), feature in self.node_features.items():
graphs[0].nodes[node_type].data[feature_name] = feature
blocks[0].srcnodes[node_type].data[feature_name] = feature
# Assign edge features.
if self.edge_features:
for graph, edge_feature in zip(graphs, self.edge_features):
for block, edge_feature in zip(blocks, self.edge_features):
for (
edge_type,
feature_name,
), feature in edge_feature.items():
graph.edges[etype_str_to_tuple(edge_type)].data[
block.edges[etype_str_to_tuple(edge_type)].data[
feature_name
] = feature
# Assign reverse node ids to the outermost layer's nodes.
reverse_row_node_ids = self.sampled_subgraphs[
# Assign reverse node ids to the outermost layer's source nodes.
for node_type, reverse_ids in self.sampled_subgraphs[
0
].reverse_row_node_ids
if reverse_row_node_ids:
for node_type, reverse_ids in reverse_row_node_ids.items():
graphs[0].nodes[node_type].data[dgl.NID] = reverse_ids
].reverse_row_node_ids.items():
blocks[0].srcnodes[node_type].data[dgl.NID] = reverse_ids
# Assign reverse edges ids.
for graph, subgraph in zip(graphs, self.sampled_subgraphs):
for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids:
for (
edge_type,
reverse_ids,
) in subgraph.reverse_edge_ids.items():
graph.edges[etype_str_to_tuple(edge_type)].data[
block.edges[etype_str_to_tuple(edge_type)].data[
dgl.EID
] = reverse_ids
else:
# Assign node features to the outermost layer's nodes.
# Assign node features to the outermost layer's source nodes.
if self.node_features:
for feature_name, feature in self.node_features.items():
graphs[0].ndata[feature_name] = feature
blocks[0].srcdata[feature_name] = feature
# Assign edge features.
if self.edge_features:
for graph, edge_feature in zip(graphs, self.edge_features):
for block, edge_feature in zip(blocks, self.edge_features):
for feature_name, feature in edge_feature.items():
graph.edata[feature_name] = feature
# Assign reverse node ids.
reverse_row_node_ids = self.sampled_subgraphs[
block.edata[feature_name] = feature
blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[
0
].reverse_row_node_ids
if reverse_row_node_ids is not None:
graphs[0].ndata[dgl.NID] = reverse_row_node_ids
# Assign reverse edges ids.
for graph, subgraph in zip(graphs, self.sampled_subgraphs):
for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids is not None:
graph.edata[dgl.EID] = subgraph.reverse_edge_ids
block.edata[dgl.EID] = subgraph.reverse_edge_ids
return graphs
return blocks
import dgl
import dgl.graphbolt as gb
import torch
def test_to_dgl_blocks_hetero():
    """Convert a two-layer heterogeneous `MiniBatch` and verify the blocks.

    For every returned block the edges, `dgl.EID` and the "x" edge feature
    of the "A:r:B" relation are compared with the inputs. For the first
    block only, the "B:rr:A" edges, the per-type `dgl.NID` of the source
    nodes and the "x" feature of the "A" source nodes are compared as well.
    """
    fwd_rel = "A:r:B"
    bwd_rel = "B:rr:A"
    # One entry per sampled layer; only the first layer has both relations.
    pairs_per_layer = [
        {
            fwd_rel: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
            bwd_rel: (torch.tensor([1, 0]), torch.tensor([2, 3])),
        },
        {fwd_rel: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
    ]
    col_ids_per_layer = [
        {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
        {"B": torch.tensor([10, 11])},
    ]
    row_ids_per_layer = [
        {
            "A": torch.tensor([5, 7, 9, 11]),
            "B": torch.tensor([10, 11, 12]),
        },
        {
            "A": torch.tensor([5, 7]),
            "B": torch.tensor([10, 11]),
        },
    ]
    eids_per_layer = [
        {
            fwd_rel: torch.tensor([19, 20, 21]),
            bwd_rel: torch.tensor([23, 26]),
        },
        {fwd_rel: torch.tensor([10, 12])},
    ]
    node_feats = {
        ("A", "x"): torch.randint(0, 10, (4,)),
    }
    efeat_per_layer = [
        {(fwd_rel, "x"): torch.randint(0, 10, (3,))},
        {(fwd_rel, "x"): torch.randint(0, 10, (2,))},
    ]

    # Build one sampled subgraph per layer from the parallel lists above.
    layers = [
        gb.SampledSubgraphImpl(
            node_pairs=pairs_per_layer[layer],
            reverse_column_node_ids=col_ids_per_layer[layer],
            reverse_row_node_ids=row_ids_per_layer[layer],
            reverse_edge_ids=eids_per_layer[layer],
        )
        for layer in range(2)
    ]
    minibatch = gb.MiniBatch(
        sampled_subgraphs=layers,
        node_features=node_feats,
        edge_features=efeat_per_layer,
    )
    blocks = minibatch.to_dgl_blocks()

    # Per-block checks on the forward relation.
    fwd_etype = gb.etype_str_to_tuple(fwd_rel)
    for layer, block in enumerate(blocks):
        src, dst = block.edges(etype=fwd_etype)
        want_src, want_dst = pairs_per_layer[layer][fwd_rel]
        assert torch.equal(src, want_src)
        assert torch.equal(dst, want_dst)
        fwd_edata = block.edges[fwd_etype].data
        assert torch.equal(fwd_edata[dgl.EID], eids_per_layer[layer][fwd_rel])
        assert torch.equal(
            fwd_edata["x"], efeat_per_layer[layer][(fwd_rel, "x")]
        )

    # Checks that only apply to the first block.
    first = blocks[0]
    src, dst = first.edges(etype=gb.etype_str_to_tuple(bwd_rel))
    want_src, want_dst = pairs_per_layer[0][bwd_rel]
    assert torch.equal(src, want_src)
    assert torch.equal(dst, want_dst)
    src_nids = first.srcdata[dgl.NID]
    for ntype in ("A", "B"):
        assert torch.equal(src_nids[ntype], row_ids_per_layer[0][ntype])
    assert torch.equal(first.srcnodes["A"].data["x"], node_feats[("A", "x")])
# NOTE(review): this module-level call runs the hetero test whenever the module
# is imported — presumably a debugging leftover; confirm before removing.
test_to_dgl_blocks_hetero()
def test_to_dgl_blocks_homo():
    """Convert a two-layer homogeneous `MiniBatch` and verify the blocks.

    Every returned block is checked for its edges, `dgl.EID` and the "x"
    edge feature. The first block is additionally checked for the `dgl.NID`
    and "x" feature stored on its source nodes.
    """
    # One entry per sampled layer.
    pairs_per_layer = [
        (
            torch.tensor([0, 1, 2, 2, 2, 1]),
            torch.tensor([0, 1, 1, 2, 3, 2]),
        ),
        (
            torch.tensor([0, 1, 2]),
            torch.tensor([1, 0, 0]),
        ),
    ]
    col_ids_per_layer = [
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11]),
    ]
    row_ids_per_layer = [
        torch.tensor([10, 11, 12, 13]),
        torch.tensor([10, 11, 12]),
    ]
    eids_per_layer = [
        torch.tensor([19, 20, 21, 22, 25, 30]),
        torch.tensor([10, 15, 17]),
    ]
    node_feats = {"x": torch.randint(0, 10, (4,))}
    efeat_per_layer = [
        {"x": torch.randint(0, 10, (6,))},
        {"x": torch.randint(0, 10, (3,))},
    ]

    # Build one sampled subgraph per layer from the parallel lists above.
    layers = [
        gb.SampledSubgraphImpl(
            node_pairs=pairs_per_layer[layer],
            reverse_column_node_ids=col_ids_per_layer[layer],
            reverse_row_node_ids=row_ids_per_layer[layer],
            reverse_edge_ids=eids_per_layer[layer],
        )
        for layer in range(2)
    ]
    minibatch = gb.MiniBatch(
        sampled_subgraphs=layers,
        node_features=node_feats,
        edge_features=efeat_per_layer,
    )
    blocks = minibatch.to_dgl_blocks()

    # Per-block checks.
    for layer, block in enumerate(blocks):
        want_src, want_dst = pairs_per_layer[layer]
        assert torch.equal(block.edges()[0], want_src)
        assert torch.equal(block.edges()[1], want_dst)
        assert torch.equal(block.edata[dgl.EID], eids_per_layer[layer])
        assert torch.equal(block.edata["x"], efeat_per_layer[layer]["x"])

    # Checks that only apply to the first block's source nodes.
    first = blocks[0]
    assert torch.equal(first.srcdata[dgl.NID], row_ids_per_layer[0])
    assert torch.equal(first.srcdata["x"], node_feats["x"])
import dgl
import dgl.graphbolt as gb
import torch
def test_to_dgl_graphs_hetero():
    """Convert a single-layer heterogeneous `MiniBatch` and verify the graph.

    Takes the first graph returned by `to_dgl_graphs()` and compares its
    edges, per-type `dgl.NID`, `dgl.EID`, node features and the "x" edge
    feature with the inputs used to build the mini-batch.
    """
    rel = "A:relation:B"
    pairs = {rel: (torch.tensor([0, 1, 2]), torch.tensor([0, 4, 5]))}
    col_ids = {"B": torch.tensor([10, 11, 12, 13, 14, 16])}
    row_ids = {
        "A": torch.tensor([5, 9, 7]),
        "B": torch.tensor([10, 11, 12, 13, 14, 16]),
    }
    eids = {rel: torch.tensor([19, 20, 21])}
    node_feats = {
        ("A", "x"): torch.randint(0, 10, (3,)),
        ("B", "y"): torch.randint(0, 10, (6,)),
    }
    edge_feats = {(rel, "x"): torch.randint(0, 10, (3,))}

    sampled = gb.SampledSubgraphImpl(
        node_pairs=pairs,
        reverse_column_node_ids=col_ids,
        reverse_row_node_ids=row_ids,
        reverse_edge_ids=eids,
    )
    minibatch = gb.MiniBatch(
        sampled_subgraphs=[sampled],
        node_features=node_feats,
        edge_features=[edge_feats],
    )
    g = minibatch.to_dgl_graphs()[0]

    # Structure.
    want_src, want_dst = pairs[rel]
    assert torch.equal(g.edges()[0], want_src)
    assert torch.equal(g.edges()[1], want_dst)
    # Reverse ids.
    for ntype in ("A", "B"):
        assert torch.equal(g.ndata[dgl.NID][ntype], row_ids[ntype])
    assert torch.equal(g.edata[dgl.EID], eids[rel])
    # Features.
    assert torch.equal(g.nodes["A"].data["x"], node_feats[("A", "x")])
    assert torch.equal(g.nodes["B"].data["y"], node_feats[("B", "y")])
    rel_edata = g.edges[gb.etype_str_to_tuple(rel)].data
    assert torch.equal(rel_edata["x"], edge_feats[(rel, "x")])
def test_to_dgl_graphs_homo():
    """Convert a single-layer homogeneous `MiniBatch` and verify the graph.

    Takes the first graph returned by `to_dgl_graphs()` and compares its
    edges, `dgl.NID`, `dgl.EID` and the "x" node and edge features with
    the inputs used to build the mini-batch.
    """
    pairs = (torch.tensor([0, 1, 2]), torch.tensor([0, 4, 5]))
    col_ids = torch.tensor([10, 11, 12])
    row_ids = torch.tensor([10, 11, 12, 13, 14, 16])
    eids = torch.tensor([19, 20, 21])
    node_feats = {"x": torch.randint(0, 10, (6,))}
    edge_feats = {"x": torch.randint(0, 10, (3,))}

    sampled = gb.SampledSubgraphImpl(
        node_pairs=pairs,
        reverse_column_node_ids=col_ids,
        reverse_row_node_ids=row_ids,
        reverse_edge_ids=eids,
    )
    minibatch = gb.MiniBatch(
        sampled_subgraphs=[sampled],
        node_features=node_feats,
        edge_features=[edge_feats],
    )
    g = minibatch.to_dgl_graphs()[0]

    # Structure.
    want_src, want_dst = pairs
    assert torch.equal(g.edges()[0], want_src)
    assert torch.equal(g.edges()[1], want_dst)
    # Reverse ids.
    assert torch.equal(g.ndata[dgl.NID], row_ids)
    assert torch.equal(g.edata[dgl.EID], eids)
    # Features.
    assert torch.equal(g.ndata["x"], node_feats["x"])
    assert torch.equal(g.edata["x"], edge_feats["x"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment