"docs/vscode:/vscode.git/clone" did not exist on "76d492ea49342b486dfbca1dbcdfbb052fe34112"
Unverified commit 858c0b86, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Update `to_pyg_data()` to support get batch_size from `seeds`. (#7214)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 396b0ec2
......@@ -536,6 +536,11 @@ class MiniBatch:
batch_size = len(next(iter(self.node_pairs.values()))[0])
else:
batch_size = len(self.node_pairs[0])
elif self.seeds is not None:
if isinstance(self.seeds, Dict):
batch_size = len(next(iter(self.seeds.values())))
else:
batch_size = len(self.seeds)
else:
batch_size = None
pyg_data = Data(
......
......@@ -868,7 +868,7 @@ def test_dgl_link_predication_hetero(mode):
)
def test_to_pyg_data():
def test_to_pyg_data_original():
test_minibatch = create_homo_minibatch()
test_minibatch.seed_nodes = torch.tensor([0, 1])
test_minibatch.labels = torch.tensor([7, 8])
......@@ -936,3 +936,82 @@ def test_to_pyg_data():
str(e)
== "`to_pyg_data` only supports single feature homogeneous graph."
)
def test_to_pyg_data():
    """Check `MiniBatch.to_pyg_data()` when the batch size comes from `seeds`.

    Covers the homogeneous happy path, every `seeds` layout exercised here
    (1-D tensor, 2-D tensor, and dicts of either), the cases where sampled
    subgraphs, node features or labels are missing, and the rejection of
    more than one node feature.
    """
    test_minibatch = create_homo_minibatch()
    test_minibatch.seeds = torch.tensor([0, 1])
    test_minibatch.labels = torch.tensor([7, 8])

    expected_edge_index = torch.tensor(
        [[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
    )
    expected_node_features = next(iter(test_minibatch.node_features.values()))
    expected_labels = torch.tensor([7, 8])
    expected_batch_size = 2
    expected_n_id = torch.tensor([10, 11, 12, 13])

    pyg_data = test_minibatch.to_pyg_data()
    pyg_data.validate()
    assert torch.equal(pyg_data.edge_index, expected_edge_index)
    assert torch.equal(pyg_data.x, expected_node_features)
    assert torch.equal(pyg_data.y, expected_labels)
    assert pyg_data.batch_size == expected_batch_size
    assert torch.equal(pyg_data.n_id, expected_n_id)

    # Every alternative `seeds` layout must yield the same batch size. The
    # conversion has to be re-run after each reassignment: `pyg_data` is a
    # snapshot, so checking the old object would not exercise the new seeds.
    for seeds in (
        torch.tensor([[0, 1], [2, 3]]),
        {"A": torch.tensor([0, 1])},
        {"A": torch.tensor([[0, 1], [2, 3]])},
    ):
        test_minibatch.seeds = seeds
        pyg_data = test_minibatch.to_pyg_data()
        assert pyg_data.batch_size == expected_batch_size

    subgraph = test_minibatch.sampled_subgraphs[0]

    # Test with sampled_subgraphs as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=None,
        node_features={"feat": expected_node_features},
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.edge_index is None, "Edge index should be none."

    # Test with node_features as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features=None,
        labels=expected_labels,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.x is None, "Node features should be None."

    # Test with labels as None.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features={"feat": expected_node_features},
        labels=None,
    )
    pyg_data = test_minibatch.to_pyg_data()
    assert pyg_data.y is None, "Labels should be None."

    # Test with multiple features.
    test_minibatch = gb.MiniBatch(
        sampled_subgraphs=[subgraph],
        node_features={
            "feat": expected_node_features,
            "extra_feat": torch.tensor([[3], [4]]),
        },
        labels=expected_labels,
    )
    try:
        test_minibatch.to_pyg_data()
    except AssertionError as e:
        assert (
            str(e)
            == "`to_pyg_data` only supports single feature homogeneous graph."
        )
    else:
        # Raised outside the `except` clause so a missing error can never be
        # swallowed and reported as a pass.
        raise AssertionError("Multiple features case should raise an error.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment