Unverified Commit e3bf1c0e authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Convert minibatch to dgl block instead of graph (#6318)

parent eb40ed55
...@@ -126,10 +126,10 @@ class MiniBatch: ...@@ -126,10 +126,10 @@ class MiniBatch:
all node ids inside are compacted. all node ids inside are compacted.
""" """
def to_dgl_graphs(self): def to_dgl_blocks(self):
"""Transforming a data graph into DGL graphs necessitates constructing a """Transforming a `MiniBatch` into DGL blocks necessitates constructing a
graphical structure and assigning features to the nodes and edges within graphical structure and assigning features to the nodes and edges within
the graphs. the blocks.
""" """
if not self.sampled_subgraphs: if not self.sampled_subgraphs:
return None return None
...@@ -138,77 +138,90 @@ class MiniBatch: ...@@ -138,77 +138,90 @@ class MiniBatch:
self.sampled_subgraphs[0].node_pairs, Dict self.sampled_subgraphs[0].node_pairs, Dict
) )
if is_heterogeneous: blocks = []
graphs = []
for subgraph in self.sampled_subgraphs: for subgraph in self.sampled_subgraphs:
graphs.append( reverse_row_node_ids = subgraph.reverse_row_node_ids
dgl.heterograph( assert (
{ reverse_row_node_ids is not None
), "Missing `reverse_row_node_ids` in sampled subgraph."
reverse_column_node_ids = subgraph.reverse_column_node_ids
assert (
reverse_column_node_ids is not None
), "Missing `reverse_column_node_ids` in sampled subgraph."
if is_heterogeneous:
node_pairs = {
etype_str_to_tuple(etype): v etype_str_to_tuple(etype): v
for etype, v in subgraph.node_pairs.items() for etype, v in subgraph.node_pairs.items()
} }
num_src_nodes = {
ntype: nodes.size(0)
for ntype, nodes in reverse_row_node_ids.items()
}
num_dst_nodes = {
ntype: nodes.size(0)
for ntype, nodes in reverse_column_node_ids.items()
}
else:
node_pairs = subgraph.node_pairs
num_src_nodes = reverse_row_node_ids.size(0)
num_dst_nodes = reverse_column_node_ids.size(0)
blocks.append(
dgl.create_block(
node_pairs,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
) )
) )
else:
graphs = [
dgl.graph(subgraph.node_pairs)
for subgraph in self.sampled_subgraphs
]
if is_heterogeneous: if is_heterogeneous:
# Assign node features to the outermost layer's nodes. # Assign node features to the outermost layer's source nodes.
if self.node_features: if self.node_features:
for ( for (
node_type, node_type,
feature_name, feature_name,
), feature in self.node_features.items(): ), feature in self.node_features.items():
graphs[0].nodes[node_type].data[feature_name] = feature blocks[0].srcnodes[node_type].data[feature_name] = feature
# Assign edge features. # Assign edge features.
if self.edge_features: if self.edge_features:
for graph, edge_feature in zip(graphs, self.edge_features): for block, edge_feature in zip(blocks, self.edge_features):
for ( for (
edge_type, edge_type,
feature_name, feature_name,
), feature in edge_feature.items(): ), feature in edge_feature.items():
graph.edges[etype_str_to_tuple(edge_type)].data[ block.edges[etype_str_to_tuple(edge_type)].data[
feature_name feature_name
] = feature ] = feature
# Assign reverse node ids to the outermost layer's nodes. # Assign reverse node ids to the outermost layer's source nodes.
reverse_row_node_ids = self.sampled_subgraphs[ for node_type, reverse_ids in self.sampled_subgraphs[
0 0
].reverse_row_node_ids ].reverse_row_node_ids.items():
if reverse_row_node_ids: blocks[0].srcnodes[node_type].data[dgl.NID] = reverse_ids
for node_type, reverse_ids in reverse_row_node_ids.items():
graphs[0].nodes[node_type].data[dgl.NID] = reverse_ids
# Assign reverse edges ids. # Assign reverse edges ids.
for graph, subgraph in zip(graphs, self.sampled_subgraphs): for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids: if subgraph.reverse_edge_ids:
for ( for (
edge_type, edge_type,
reverse_ids, reverse_ids,
) in subgraph.reverse_edge_ids.items(): ) in subgraph.reverse_edge_ids.items():
graph.edges[etype_str_to_tuple(edge_type)].data[ block.edges[etype_str_to_tuple(edge_type)].data[
dgl.EID dgl.EID
] = reverse_ids ] = reverse_ids
else: else:
# Assign node features to the outermost layer's nodes. # Assign node features to the outermost layer's source nodes.
if self.node_features: if self.node_features:
for feature_name, feature in self.node_features.items(): for feature_name, feature in self.node_features.items():
graphs[0].ndata[feature_name] = feature blocks[0].srcdata[feature_name] = feature
# Assign edge features. # Assign edge features.
if self.edge_features: if self.edge_features:
for graph, edge_feature in zip(graphs, self.edge_features): for block, edge_feature in zip(blocks, self.edge_features):
for feature_name, feature in edge_feature.items(): for feature_name, feature in edge_feature.items():
graph.edata[feature_name] = feature block.edata[feature_name] = feature
# Assign reverse node ids. blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[
reverse_row_node_ids = self.sampled_subgraphs[
0 0
].reverse_row_node_ids ].reverse_row_node_ids
if reverse_row_node_ids is not None:
graphs[0].ndata[dgl.NID] = reverse_row_node_ids
# Assign reverse edges ids. # Assign reverse edges ids.
for graph, subgraph in zip(graphs, self.sampled_subgraphs): for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids is not None: if subgraph.reverse_edge_ids is not None:
graph.edata[dgl.EID] = subgraph.reverse_edge_ids block.edata[dgl.EID] = subgraph.reverse_edge_ids
return graphs return blocks
import dgl
import dgl.graphbolt as gb
import torch
def test_to_dgl_blocks_hetero():
relation = "A:r:B"
reverse_relation = "B:rr:A"
node_pairs = [
{
relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])),
},
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
]
reverse_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
]
reverse_row_node_ids = [
{
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
{
"A": torch.tensor([5, 7]),
"B": torch.tensor([10, 11]),
},
]
reverse_edge_ids = [
{
relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]),
},
{relation: torch.tensor([10, 12])},
]
node_features = {
("A", "x"): torch.randint(0, 10, (4,)),
}
edge_features = [
{(relation, "x"): torch.randint(0, 10, (3,))},
{(relation, "x"): torch.randint(0, 10, (2,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i],
)
)
blocks = gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
).to_dgl_blocks()
etype = gb.etype_str_to_tuple(relation)
for i, block in enumerate(blocks):
edges = block.edges(etype=etype)
assert torch.equal(edges[0], node_pairs[i][relation][0])
assert torch.equal(edges[1], node_pairs[i][relation][1])
assert torch.equal(
block.edges[etype].data[dgl.EID], reverse_edge_ids[i][relation]
)
assert torch.equal(
block.edges[etype].data["x"],
edge_features[i][(relation, "x")],
)
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
assert torch.equal(edges[0], node_pairs[0][reverse_relation][0])
assert torch.equal(edges[1], node_pairs[0][reverse_relation][1])
assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], reverse_row_node_ids[0]["A"]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID]["B"], reverse_row_node_ids[0]["B"]
)
assert torch.equal(
blocks[0].srcnodes["A"].data["x"], node_features[("A", "x")]
)
test_to_dgl_blocks_hetero()
def test_to_dgl_blocks_homo():
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
torch.tensor([0, 1, 1, 2, 3, 2]),
),
(
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
]
reverse_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
reverse_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
reverse_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
node_features = {"x": torch.randint(0, 10, (4,))}
edge_features = [
{"x": torch.randint(0, 10, (6,))},
{"x": torch.randint(0, 10, (3,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i],
)
)
blocks = gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
).to_dgl_blocks()
for i, block in enumerate(blocks):
assert torch.equal(block.edges()[0], node_pairs[i][0])
assert torch.equal(block.edges()[1], node_pairs[i][1])
assert torch.equal(block.edata[dgl.EID], reverse_edge_ids[i])
assert torch.equal(block.edata["x"], edge_features[i]["x"])
assert torch.equal(blocks[0].srcdata[dgl.NID], reverse_row_node_ids[0])
assert torch.equal(blocks[0].srcdata["x"], node_features["x"])
import dgl
import dgl.graphbolt as gb
import torch
def test_to_dgl_graphs_hetero():
relation = "A:relation:B"
node_pairs = {relation: (torch.tensor([0, 1, 2]), torch.tensor([0, 4, 5]))}
reverse_column_node_ids = {"B": torch.tensor([10, 11, 12, 13, 14, 16])}
reverse_row_node_ids = {
"A": torch.tensor([5, 9, 7]),
"B": torch.tensor([10, 11, 12, 13, 14, 16]),
}
reverse_edge_ids = {relation: torch.tensor([19, 20, 21])}
node_features = {
("A", "x"): torch.randint(0, 10, (3,)),
("B", "y"): torch.randint(0, 10, (6,)),
}
edge_features = {(relation, "x"): torch.randint(0, 10, (3,))}
subgraph = gb.SampledSubgraphImpl(
node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids,
reverse_row_node_ids=reverse_row_node_ids,
reverse_edge_ids=reverse_edge_ids,
)
g = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features=node_features,
edge_features=[edge_features],
).to_dgl_graphs()[0]
assert torch.equal(g.edges()[0], node_pairs[relation][0])
assert torch.equal(g.edges()[1], node_pairs[relation][1])
assert torch.equal(g.ndata[dgl.NID]["A"], reverse_row_node_ids["A"])
assert torch.equal(g.ndata[dgl.NID]["B"], reverse_row_node_ids["B"])
assert torch.equal(g.edata[dgl.EID], reverse_edge_ids[relation])
assert torch.equal(g.nodes["A"].data["x"], node_features[("A", "x")])
assert torch.equal(g.nodes["B"].data["y"], node_features[("B", "y")])
assert torch.equal(
g.edges[gb.etype_str_to_tuple(relation)].data["x"],
edge_features[(relation, "x")],
)
def test_to_dgl_graphs_homo():
node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([0, 4, 5]))
reverse_column_node_ids = torch.tensor([10, 11, 12])
reverse_row_node_ids = torch.tensor([10, 11, 12, 13, 14, 16])
reverse_edge_ids = torch.tensor([19, 20, 21])
node_features = {"x": torch.randint(0, 10, (6,))}
edge_features = {"x": torch.randint(0, 10, (3,))}
subgraph = gb.SampledSubgraphImpl(
node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids,
reverse_row_node_ids=reverse_row_node_ids,
reverse_edge_ids=reverse_edge_ids,
)
g = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features=node_features,
edge_features=[edge_features],
).to_dgl_graphs()[0]
assert torch.equal(g.edges()[0], node_pairs[0])
assert torch.equal(g.edges()[1], node_pairs[1])
assert torch.equal(g.ndata[dgl.NID], reverse_row_node_ids)
assert torch.equal(g.edata[dgl.EID], reverse_edge_ids)
assert torch.equal(g.ndata["x"], node_features["x"])
assert torch.equal(g.edata["x"], edge_features["x"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment