Unverified Commit 7bcc27ff authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Add to dgl (#6381)

parent 17198e9e
...@@ -209,10 +209,12 @@ class MiniBatch: ...@@ -209,10 +209,12 @@ class MiniBatch:
all node ids inside are compacted. all node ids inside are compacted.
""" """
def to_dgl_blocks(self): def __repr__(self) -> str:
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing a return _minibatch_str(self)
graphical structure and assigning features to the nodes and edges within
the blocks. def _to_dgl_blocks(self):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
a graphical structure and ID mappings.
""" """
if not self.sampled_subgraphs: if not self.sampled_subgraphs:
return None return None
...@@ -257,23 +259,6 @@ class MiniBatch: ...@@ -257,23 +259,6 @@ class MiniBatch:
) )
if is_heterogeneous: if is_heterogeneous:
# Assign node features to the outermost layer's source nodes.
if self.node_features:
for (
node_type,
feature_name,
), feature in self.node_features.items():
blocks[0].srcnodes[node_type].data[feature_name] = feature
# Assign edge features.
if self.edge_features:
for block, edge_feature in zip(blocks, self.edge_features):
for (
edge_type,
feature_name,
), feature in edge_feature.items():
block.edges[etype_str_to_tuple(edge_type)].data[
feature_name
] = feature
# Assign reverse node ids to the outermost layer's source nodes. # Assign reverse node ids to the outermost layer's source nodes.
for node_type, reverse_ids in self.sampled_subgraphs[ for node_type, reverse_ids in self.sampled_subgraphs[
0 0
...@@ -290,15 +275,6 @@ class MiniBatch: ...@@ -290,15 +275,6 @@ class MiniBatch:
dgl.EID dgl.EID
] = reverse_ids ] = reverse_ids
else: else:
# Assign node features to the outermost layer's source nodes.
if self.node_features:
for feature_name, feature in self.node_features.items():
blocks[0].srcdata[feature_name] = feature
# Assign edge features.
if self.edge_features:
for block, edge_feature in zip(blocks, self.edge_features):
for feature_name, feature in edge_feature.items():
block.edata[feature_name] = feature
blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[ blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[
0 0
].original_row_node_ids ].original_row_node_ids
...@@ -306,11 +282,96 @@ class MiniBatch: ...@@ -306,11 +282,96 @@ class MiniBatch:
for block, subgraph in zip(blocks, self.sampled_subgraphs): for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.original_edge_ids is not None: if subgraph.original_edge_ids is not None:
block.edata[dgl.EID] = subgraph.original_edge_ids block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks return blocks
def __repr__(self) -> str: def to_dgl(self):
return _minibatch_str(self) """Converting a `MiniBatch` into a DGL MiniBatch that contains
everything necessary for computation.
"""
minibatch = DGLMiniBatch(
blocks=self._to_dgl_blocks(),
input_nodes=self.input_nodes,
output_nodes=self.seed_nodes,
node_features=self.node_features,
edge_features=self.edge_features,
labels=self.labels,
)
assert (
minibatch.blocks is not None
), "Sampled subgraphs for computation are missing."
# For link prediction tasks.
if self.compacted_node_pairs is not None:
minibatch.positive_node_pairs = self.compacted_node_pairs
# Build negative graph.
if (
self.compacted_negative_srcs is not None
and self.compacted_negative_dsts is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
minibatch.negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_negative_dsts.view(-1),
)
# For heterogeneous graph.
else:
minibatch.negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_negative_dsts[etype].view(-1),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif self.compacted_negative_srcs is not None:
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1)
minibatch.negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_node_pairs[1].repeat_interleave(
negative_ratio
),
)
# For heterogeneous graph.
else:
negative_ratio = list(
self.compacted_negative_srcs.values()
)[0].size(1)
minibatch.negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_node_pairs[etype][
1
].repeat_interleave(negative_ratio),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif self.compacted_negative_dsts is not None:
# For homogeneous graph.
if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1)
minibatch.negative_node_pairs = (
self.compacted_node_pairs[0].repeat_interleave(
negative_ratio
),
self.compacted_negative_dsts.view(-1),
)
# For heterogeneous graph.
else:
negative_ratio = list(
self.compacted_negative_dsts.values()
)[0].size(1)
minibatch.negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][
0
].repeat_interleave(negative_ratio),
neg_dst.view(-1),
)
for etype, neg_dst in self.compacted_negative_dsts.items()
}
return minibatch
def _minibatch_str(minibatch: MiniBatch) -> str: def _minibatch_str(minibatch: MiniBatch) -> str:
......
import dgl import dgl
import dgl.graphbolt as gb import dgl.graphbolt as gb
import pytest
import torch import torch
def test_to_dgl_blocks_hetero(): relation = "A:r:B"
relation = "A:r:B" reverse_relation = "B:rr:A"
reverse_relation = "B:rr:A"
def create_homo_minibatch():
node_pairs = [ node_pairs = [
{ (
relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])), torch.tensor([0, 1, 2, 2, 2, 1]),
reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])), torch.tensor([0, 1, 1, 2, 3, 2]),
}, ),
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))}, (
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
] ]
original_column_node_ids = [ original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])}, torch.tensor([10, 11, 12, 13]),
{"B": torch.tensor([10, 11])}, torch.tensor([10, 11]),
] ]
original_row_node_ids = [ original_row_node_ids = [
{ torch.tensor([10, 11, 12, 13]),
"A": torch.tensor([5, 7, 9, 11]), torch.tensor([10, 11, 12]),
"B": torch.tensor([10, 11, 12]),
},
{
"A": torch.tensor([5, 7]),
"B": torch.tensor([10, 11]),
},
] ]
original_edge_ids = [ original_edge_ids = [
{ torch.tensor([19, 20, 21, 22, 25, 30]),
relation: torch.tensor([19, 20, 21]), torch.tensor([10, 15, 17]),
reverse_relation: torch.tensor([23, 26]),
},
{relation: torch.tensor([10, 12])},
] ]
node_features = { node_features = {"x": torch.randint(0, 10, (4,))}
("A", "x"): torch.randint(0, 10, (4,)),
}
edge_features = [ edge_features = [
{(relation, "x"): torch.randint(0, 10, (3,))}, {"x": torch.randint(0, 10, (6,))},
{(relation, "x"): torch.randint(0, 10, (2,))}, {"x": torch.randint(0, 10, (3,))},
] ]
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
...@@ -51,68 +46,48 @@ def test_to_dgl_blocks_hetero(): ...@@ -51,68 +46,48 @@ def test_to_dgl_blocks_hetero():
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
blocks = gb.MiniBatch( return gb.MiniBatch(
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
node_features=node_features, node_features=node_features,
edge_features=edge_features, edge_features=edge_features,
).to_dgl_blocks()
etype = gb.etype_str_to_tuple(relation)
for i, block in enumerate(blocks):
edges = block.edges(etype=etype)
assert torch.equal(edges[0], node_pairs[i][relation][0])
assert torch.equal(edges[1], node_pairs[i][relation][1])
assert torch.equal(
block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
)
assert torch.equal(
block.edges[etype].data["x"],
edge_features[i][(relation, "x")],
)
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
assert torch.equal(edges[0], node_pairs[0][reverse_relation][0])
assert torch.equal(edges[1], node_pairs[0][reverse_relation][1])
assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
) )
assert torch.equal(
blocks[0].srcdata[dgl.NID]["B"], original_row_node_ids[0]["B"]
)
assert torch.equal(
blocks[0].srcnodes["A"].data["x"], node_features[("A", "x")]
)
test_to_dgl_blocks_hetero()
def test_to_dgl_blocks_homo(): def create_hetero_minibatch():
node_pairs = [ node_pairs = [
( {
torch.tensor([0, 1, 2, 2, 2, 1]), relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
torch.tensor([0, 1, 1, 2, 3, 2]), reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])),
), },
( {relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
] ]
original_column_node_ids = [ original_column_node_ids = [
torch.tensor([10, 11, 12, 13]), {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
torch.tensor([10, 11]), {"B": torch.tensor([10, 11])},
] ]
original_row_node_ids = [ original_row_node_ids = [
torch.tensor([10, 11, 12, 13]), {
torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
{
"A": torch.tensor([5, 7]),
"B": torch.tensor([10, 11]),
},
] ]
original_edge_ids = [ original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]), {
torch.tensor([10, 15, 17]), relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]),
},
{relation: torch.tensor([10, 12])},
] ]
node_features = {"x": torch.randint(0, 10, (4,))} node_features = {
("A", "x"): torch.randint(0, 10, (4,)),
}
edge_features = [ edge_features = [
{"x": torch.randint(0, 10, (6,))}, {(relation, "x"): torch.randint(0, 10, (3,))},
{"x": torch.randint(0, 10, (3,))}, {(relation, "x"): torch.randint(0, 10, (2,))},
] ]
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
...@@ -124,19 +99,11 @@ def test_to_dgl_blocks_homo(): ...@@ -124,19 +99,11 @@ def test_to_dgl_blocks_homo():
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
blocks = gb.MiniBatch( return gb.MiniBatch(
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
node_features=node_features, node_features=node_features,
edge_features=edge_features, edge_features=edge_features,
).to_dgl_blocks() )
for i, block in enumerate(blocks):
assert torch.equal(block.edges()[0], node_pairs[i][0])
assert torch.equal(block.edges()[1], node_pairs[i][1])
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])
assert torch.equal(block.edata["x"], edge_features[i]["x"])
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
assert torch.equal(blocks[0].srcdata["x"], node_features["x"])
def test_representation(): def test_representation():
...@@ -251,3 +218,164 @@ def test_representation(): ...@@ -251,3 +218,164 @@ def test_representation():
) )
result = str(minibatch) result = str(minibatch)
assert result == expect_result, print(expect_result, result) assert result == expect_result, print(expect_result, result)
def check_dgl_blocks_hetero(minibatch, blocks):
    """Verify that heterogeneous DGL blocks mirror the sampled subgraphs.

    For every block, the edges and the ``dgl.EID`` edge data of ``relation``
    are compared with the subgraph of the same layer. The outermost block is
    additionally checked for its ``reverse_relation`` edges and for the
    ``dgl.NID`` data stored on its source nodes of types "A" and "B".
    """
    subgraphs = minibatch.sampled_subgraphs
    forward_etype = gb.etype_str_to_tuple(relation)
    for layer, block in enumerate(blocks):
        subgraph = subgraphs[layer]
        expected_src, expected_dst = subgraph.node_pairs[relation]
        src, dst = block.edges(etype=forward_etype)
        assert torch.equal(src, expected_src)
        assert torch.equal(dst, expected_dst)
        assert torch.equal(
            block.edges[forward_etype].data[dgl.EID],
            subgraph.original_edge_ids[relation],
        )

    # Only the outermost layer carries the reverse relation in the fixtures.
    outermost_block = blocks[0]
    outermost_subgraph = subgraphs[0]
    expected_src, expected_dst = outermost_subgraph.node_pairs[
        reverse_relation
    ]
    src, dst = outermost_block.edges(
        etype=gb.etype_str_to_tuple(reverse_relation)
    )
    assert torch.equal(src, expected_src)
    assert torch.equal(dst, expected_dst)

    source_node_ids = outermost_block.srcdata[dgl.NID]
    for node_type in ("A", "B"):
        assert torch.equal(
            source_node_ids[node_type],
            outermost_subgraph.original_row_node_ids[node_type],
        )
def check_dgl_blocks_homo(minibatch, blocks):
    """Verify that homogeneous DGL blocks mirror the sampled subgraphs.

    Each block's edges and ``dgl.EID`` edge data are compared with the
    subgraph of the same layer; the outermost block's source ``dgl.NID`` data
    is compared with the first subgraph's original row node ids.
    """
    subgraphs = minibatch.sampled_subgraphs
    for layer, block in enumerate(blocks):
        subgraph = subgraphs[layer]
        expected_src, expected_dst = subgraph.node_pairs
        assert torch.equal(block.edges()[0], expected_src)
        assert torch.equal(block.edges()[1], expected_dst)
        assert torch.equal(block.edata[dgl.EID], subgraph.original_edge_ids)
    assert torch.equal(
        blocks[0].srcdata[dgl.NID], subgraphs[0].original_row_node_ids
    )
def test_to_dgl_node_classification_homo():
    """`to_dgl` on a homogeneous node-classification minibatch.

    The converted minibatch must expose two blocks and hand over the feature,
    label and seed-node objects without copying them.
    """
    # Arrange
    source = create_homo_minibatch()
    source.seed_nodes = torch.tensor([10, 15])
    source.labels = torch.tensor([2, 5])
    # Act
    converted = source.to_dgl()
    # Assert
    assert len(converted.blocks) == 2
    for field in ("node_features", "edge_features", "labels"):
        assert getattr(source, field) is getattr(converted, field)
    # Seed nodes are exposed under a different name on the DGL side.
    assert source.seed_nodes is converted.output_nodes
    check_dgl_blocks_homo(source, converted.blocks)
def test_to_dgl_node_classification_hetero():
    """`to_dgl` on a heterogeneous node-classification minibatch.

    The converted minibatch must expose two blocks and hand over the feature,
    label and seed-node objects without copying them.
    """
    # Arrange
    source = create_hetero_minibatch()
    source.labels = {"B": torch.tensor([2, 5])}
    source.seed_nodes = {"B": torch.tensor([10, 15])}
    # Act
    converted = source.to_dgl()
    # Assert
    assert len(converted.blocks) == 2
    for field in ("node_features", "edge_features", "labels"):
        assert getattr(source, field) is getattr(converted, field)
    # Seed nodes are exposed under a different name on the DGL side.
    assert source.seed_nodes is converted.output_nodes
    check_dgl_blocks_hetero(source, converted.blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_homo(mode):
    """`to_dgl` on a homogeneous link-prediction minibatch.

    ``mode`` selects which negative samples are present:

    * ``"neg_graph"``: both negative sources and negative destinations.
    * ``"neg_src"``: only negative sources; the destinations of the negative
      pairs are the positive destinations repeated once per negative sample.
    * ``"neg_dst"``: only negative destinations; the sources of the negative
      pairs are the positive sources repeated once per negative sample.

    Both sides of the negative pairs are verified in every mode.
    """
    has_negative_srcs = mode in ("neg_graph", "neg_src")
    has_negative_dsts = mode in ("neg_graph", "neg_dst")
    # Arrange
    minibatch = create_homo_minibatch()
    minibatch.compacted_node_pairs = (
        torch.tensor([0, 1]),
        torch.tensor([1, 0]),
    )
    if has_negative_srcs:
        minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
    if has_negative_dsts:
        minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
    # Act
    dgl_minibatch = minibatch.to_dgl()
    # Assert
    assert len(dgl_minibatch.blocks) == 2
    assert minibatch.node_features is dgl_minibatch.node_features
    assert minibatch.edge_features is dgl_minibatch.edge_features
    assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
    check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)

    positive_src, positive_dst = minibatch.compacted_node_pairs
    negative_src, negative_dst = dgl_minibatch.negative_node_pairs
    if has_negative_srcs:
        expected_src = minibatch.compacted_negative_srcs.view(-1)
    else:
        # Sources are derived from the positive pairs, one copy per
        # negative destination sampled for that pair.
        negative_ratio = minibatch.compacted_negative_dsts.size(1)
        expected_src = positive_src.repeat_interleave(negative_ratio)
    if has_negative_dsts:
        expected_dst = minibatch.compacted_negative_dsts.view(-1)
    else:
        # Destinations are derived from the positive pairs, one copy per
        # negative source sampled for that pair.
        negative_ratio = minibatch.compacted_negative_srcs.size(1)
        expected_dst = positive_dst.repeat_interleave(negative_ratio)
    assert torch.equal(negative_src, expected_src)
    assert torch.equal(negative_dst, expected_dst)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_hetero(mode):
    """`to_dgl` on a heterogeneous link-prediction minibatch.

    ``mode`` selects which negative samples are present:

    * ``"neg_graph"``: both negative sources and negative destinations.
    * ``"neg_src"``: only negative sources; the destinations of the negative
      pairs are the positive destinations repeated once per negative sample.
    * ``"neg_dst"``: only negative destinations; the sources of the negative
      pairs are the positive sources repeated once per negative sample.

    Both sides of the negative pairs are verified for every edge type in
    every mode.
    """
    has_negative_srcs = mode in ("neg_graph", "neg_src")
    has_negative_dsts = mode in ("neg_graph", "neg_dst")
    # Arrange
    minibatch = create_hetero_minibatch()
    minibatch.compacted_node_pairs = {
        relation: (
            torch.tensor([1, 1]),
            torch.tensor([1, 0]),
        ),
        reverse_relation: (
            torch.tensor([0, 1]),
            torch.tensor([1, 0]),
        ),
    }
    if has_negative_srcs:
        minibatch.compacted_negative_srcs = {
            relation: torch.tensor([[2, 0], [1, 2]]),
            reverse_relation: torch.tensor([[1, 2], [0, 2]]),
        }
    if has_negative_dsts:
        minibatch.compacted_negative_dsts = {
            relation: torch.tensor([[1, 3], [2, 1]]),
            reverse_relation: torch.tensor([[2, 1], [3, 1]]),
        }
    # Act
    dgl_minibatch = minibatch.to_dgl()
    # Assert
    assert len(dgl_minibatch.blocks) == 2
    assert minibatch.node_features is dgl_minibatch.node_features
    assert minibatch.edge_features is dgl_minibatch.edge_features
    assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
    check_dgl_blocks_hetero(minibatch, dgl_minibatch.blocks)

    for etype, positive_pair in minibatch.compacted_node_pairs.items():
        positive_src, positive_dst = positive_pair
        negative_src, negative_dst = dgl_minibatch.negative_node_pairs[etype]
        if has_negative_srcs:
            expected_src = minibatch.compacted_negative_srcs[etype].view(-1)
        else:
            # Sources are derived from the positive pairs, one copy per
            # negative destination sampled for that pair.
            negative_ratio = minibatch.compacted_negative_dsts[etype].size(1)
            expected_src = positive_src.repeat_interleave(negative_ratio)
        if has_negative_dsts:
            expected_dst = minibatch.compacted_negative_dsts[etype].view(-1)
        else:
            # Destinations are derived from the positive pairs, one copy per
            # negative source sampled for that pair.
            negative_ratio = minibatch.compacted_negative_srcs[etype].size(1)
            expected_dst = positive_dst.repeat_interleave(negative_ratio)
        assert torch.equal(negative_src, expected_src)
        assert torch.equal(negative_dst, expected_dst)
...@@ -142,8 +142,7 @@ def test_SubgraphSampler_Node_Hetero(labor): ...@@ -142,8 +142,7 @@ def test_SubgraphSampler_Node_Hetero(labor):
sampler_dp = Sampler(item_sampler, graph, fanouts) sampler_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2 assert len(list(sampler_dp)) == 2
for minibatch in sampler_dp: for minibatch in sampler_dp:
blocks = minibatch.to_dgl_blocks() assert len(minibatch.sampled_subgraphs) == num_layer
assert len(blocks) == num_layer
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment