Unverified Commit 72b3e078 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Remove redundant data in to_dgl (#6466)

parent 625f8a6b
...@@ -314,12 +314,16 @@ class MiniBatch: ...@@ -314,12 +314,16 @@ class MiniBatch:
""" """
minibatch = DGLMiniBatch( minibatch = DGLMiniBatch(
blocks=self._to_dgl_blocks(), blocks=self._to_dgl_blocks(),
input_nodes=self.input_nodes,
output_nodes=self.seed_nodes,
node_features=self.node_features, node_features=self.node_features,
edge_features=self.edge_features, edge_features=self.edge_features,
labels=self.labels, labels=self.labels,
) )
# Need input nodes to fetch feature.
if self.node_features is None:
minibatch.input_nodes = self.input_nodes
# Need output nodes to fetch label.
if self.labels is None:
minibatch.output_nodes = self.seed_nodes
assert ( assert (
minibatch.blocks is not None minibatch.blocks is not None
), "Sampled subgraphs for computation are missing." ), "Sampled subgraphs for computation are missing."
......
...@@ -50,6 +50,7 @@ def create_homo_minibatch(): ...@@ -50,6 +50,7 @@ def create_homo_minibatch():
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
node_features=node_features, node_features=node_features,
edge_features=edge_features, edge_features=edge_features,
input_nodes=torch.tensor([10, 11, 12, 13]),
) )
...@@ -103,6 +104,10 @@ def create_hetero_minibatch(): ...@@ -103,6 +104,10 @@ def create_hetero_minibatch():
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
node_features=node_features, node_features=node_features,
edge_features=edge_features, edge_features=edge_features,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
) )
...@@ -286,7 +291,7 @@ def test_dgl_minibatch_representation(): ...@@ -286,7 +291,7 @@ def test_dgl_minibatch_representation():
node_features={'x': tensor([7, 6, 2, 2])}, node_features={'x': tensor([7, 6, 2, 2])},
negative_node_pairs=(tensor([0, 1, 2]), tensor([6, 0, 0])), negative_node_pairs=(tensor([0, 1, 2]), tensor([6, 0, 0])),
labels=tensor([0., 1., 2.]), labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]), input_nodes=None,
edge_features=[{'x': tensor([[8], edge_features=[{'x': tensor([[8],
[1], [1],
[6]])}, [6]])},
...@@ -354,6 +359,25 @@ def check_dgl_blocks_homo(minibatch, blocks): ...@@ -354,6 +359,25 @@ def check_dgl_blocks_homo(minibatch, blocks):
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0]) assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
def test_to_dgl_node_classification_without_feature():
# Arrange
minibatch = create_homo_minibatch()
minibatch.node_features = None
minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15])
# Act
dgl_minibatch = minibatch.to_dgl()
# Assert
assert len(dgl_minibatch.blocks) == 2
assert dgl_minibatch.node_features is None
assert minibatch.edge_features is dgl_minibatch.edge_features
assert dgl_minibatch.labels is None
assert minibatch.input_nodes is dgl_minibatch.input_nodes
assert minibatch.seed_nodes is dgl_minibatch.output_nodes
check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)
def test_to_dgl_node_classification_homo(): def test_to_dgl_node_classification_homo():
# Arrange # Arrange
minibatch = create_homo_minibatch() minibatch = create_homo_minibatch()
...@@ -367,7 +391,8 @@ def test_to_dgl_node_classification_homo(): ...@@ -367,7 +391,8 @@ def test_to_dgl_node_classification_homo():
assert minibatch.node_features is dgl_minibatch.node_features assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels assert minibatch.labels is dgl_minibatch.labels
assert minibatch.seed_nodes is dgl_minibatch.output_nodes assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks) check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)
...@@ -382,7 +407,8 @@ def test_to_dgl_node_classification_hetero(): ...@@ -382,7 +407,8 @@ def test_to_dgl_node_classification_hetero():
assert minibatch.node_features is dgl_minibatch.node_features assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels assert minibatch.labels is dgl_minibatch.labels
assert minibatch.seed_nodes is dgl_minibatch.output_nodes assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_hetero(minibatch, dgl_minibatch.blocks) check_dgl_blocks_hetero(minibatch, dgl_minibatch.blocks)
......
...@@ -71,7 +71,7 @@ def test_integration_link_prediction(): ...@@ -71,7 +71,7 @@ def test_integration_link_prediction():
[0.5503, 0.8223]])}, [0.5503, 0.8223]])},
negative_node_pairs=(tensor([0, 1, 1, 1]), tensor([0, 3, 4, 5])), negative_node_pairs=(tensor([0, 1, 1, 1]), tensor([0, 3, 4, 5])),
labels=None, labels=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]), input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=6, blocks=[Block(num_src_nodes=6,
...@@ -92,7 +92,7 @@ def test_integration_link_prediction(): ...@@ -92,7 +92,7 @@ def test_integration_link_prediction():
[0.6172, 0.7865]])}, [0.6172, 0.7865]])},
negative_node_pairs=(tensor([0, 1, 1, 2]), tensor([1, 3, 4, 1])), negative_node_pairs=(tensor([0, 1, 1, 2]), tensor([1, 3, 4, 1])),
labels=None, labels=None,
input_nodes=tensor([3, 4, 0, 5, 1]), input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=5, blocks=[Block(num_src_nodes=5,
...@@ -112,7 +112,7 @@ def test_integration_link_prediction(): ...@@ -112,7 +112,7 @@ def test_integration_link_prediction():
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
negative_node_pairs=(tensor([0, 1]), tensor([1, 2])), negative_node_pairs=(tensor([0, 1]), tensor([1, 2])),
labels=None, labels=None,
input_nodes=tensor([5, 4, 3, 0]), input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=4, blocks=[Block(num_src_nodes=4,
...@@ -193,7 +193,7 @@ def test_integration_node_classification(): ...@@ -193,7 +193,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
negative_node_pairs=None, negative_node_pairs=None,
labels=None, labels=None,
input_nodes=tensor([5, 3, 1, 2, 4, 0]), input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=6, blocks=[Block(num_src_nodes=6,
...@@ -212,7 +212,7 @@ def test_integration_node_classification(): ...@@ -212,7 +212,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
negative_node_pairs=None, negative_node_pairs=None,
labels=None, labels=None,
input_nodes=tensor([3, 4, 0]), input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=3, blocks=[Block(num_src_nodes=3,
...@@ -231,7 +231,7 @@ def test_integration_node_classification(): ...@@ -231,7 +231,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
negative_node_pairs=None, negative_node_pairs=None,
labels=None, labels=None,
input_nodes=tensor([5, 4, 0]), input_nodes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
blocks=[Block(num_src_nodes=3, blocks=[Block(num_src_nodes=3,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment