"vscode:/vscode.git/clone" did not exist on "8fd3a74322befbd13bb461e4cb9e1a57f6e9ed96"
Commit 1e46a92c authored by rusty1s's avatar rusty1s
Browse files

better sort impl

parent e0a293c3
......@@ -4,15 +4,13 @@ from torch_cluster.functions.utils.permute import sort, permute
def test_sort_cpu():
edge_index = torch.LongTensor([
[0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
[1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
])
expected_edge_index = [
[0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
[1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
]
assert sort(edge_index).tolist() == expected_edge_index
row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
row, col = sort(row, col)
expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
assert row.tolist() == expected_row
assert col.tolist() == expected_col
def test_permute_cpu():
......@@ -34,12 +32,13 @@ def test_permute_cpu():
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_sort_gpu(): # pragma: no cover
edge_index = torch.cuda.LongTensor([
[0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
[1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
])
row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
row, col = sort(row, col)
expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
assert sort(edge_index)[0].cpu().tolist() == expected_row
expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
assert row.cpu().tolist() == expected_row
assert col.cpu().tolist() == expected_col
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
......
import torch
def sort(edge_index):
row, col = edge_index
def sort(row, col):
row, perm = row.sort()
col = col[perm]
return torch.stack([row, col], dim=0)
return row, col
def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
......@@ -22,7 +21,7 @@ def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
row = node_rid[row]
# Sort row and column indices based on changed values.
row, col = sort(torch.stack([row, col], dim=0))
row, col = sort(row, col)
# Revert previous row value changes to old indices.
row = node_rid.sort()[1][row]
......
......@@ -20,4 +20,3 @@ void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment