"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "90b0ac57b0d8d8f996126deb8bba6b7dc75b4327"
Unverified Commit 2df85862 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Correct `to_pyg_data` (#7124)

parent 13cbad32
...@@ -500,7 +500,7 @@ class MiniBatch: ...@@ -500,7 +500,7 @@ class MiniBatch:
col_nodes = torch.cat(col_nodes) col_nodes = torch.cat(col_nodes)
row_nodes = torch.cat(row_nodes) row_nodes = torch.cat(row_nodes)
edge_index = torch.unique( edge_index = torch.unique(
torch.stack((col_nodes, row_nodes)), dim=1 torch.stack((row_nodes, col_nodes)), dim=1
) )
if self.node_features is None: if self.node_features is None:
......
...@@ -881,7 +881,7 @@ def test_to_pyg_data(): ...@@ -881,7 +881,7 @@ def test_to_pyg_data():
original_column_node_ids=torch.tensor([10, 11]), original_column_node_ids=torch.tensor([10, 11]),
) )
expected_edge_index = torch.tensor( expected_edge_index = torch.tensor(
[[0, 0, 1, 1, 1, 2, 2, 3], [0, 1, 0, 1, 2, 1, 2, 2]] [[0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 0, 1, 2, 1, 2, 3]]
) )
expected_node_features = torch.tensor([[1], [2], [3], [4]]) expected_node_features = torch.tensor([[1], [2], [3], [4]])
expected_labels = torch.tensor([0, 1]) expected_labels = torch.tensor([0, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment