"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c8d86e9f0a2791eb0d08b2692803cab5ea7a35e2"
Commit 1e46a92c authored by rusty1s

better sort impl

parent e0a293c3
@@ -4,15 +4,13 @@ from torch_cluster.functions.utils.permute import sort, permute

 def test_sort_cpu():
-    edge_index = torch.LongTensor([
-        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
-        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
-    ])
-    expected_edge_index = [
-        [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
-        [1, 2, 0, 2, 3, 0, 1, 3, 1, 2],
-    ]
-    assert sort(edge_index).tolist() == expected_edge_index
+    row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
+    col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
+    row, col = sort(row, col)
+    expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
+    expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
+    assert row.tolist() == expected_row
+    assert col.tolist() == expected_col


 def test_permute_cpu():
@@ -34,12 +32,13 @@ def test_permute_cpu():

 @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
 def test_sort_gpu():  # pragma: no cover
-    edge_index = torch.cuda.LongTensor([
-        [0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
-        [1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
-    ])
+    row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
+    col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
+    row, col = sort(row, col)
     expected_row = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
-    assert sort(edge_index)[0].cpu().tolist() == expected_row
+    expected_col = [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
+    assert row.cpu().tolist() == expected_row
+    assert col.cpu().tolist() == expected_col


 @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
 import torch


-def sort(edge_index):
-    row, col = edge_index
+def sort(row, col):
     row, perm = row.sort()
     col = col[perm]
-    return torch.stack([row, col], dim=0)
+    return row, col


 def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
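Note: a minimal standalone sketch of what the reworked sort(row, col) does, using the same toy data as the CPU test above; the exact order of col entries among equal row values depends on how the underlying sort breaks ties:

import torch

# Argsort the source nodes and gather the target nodes with the same
# permutation, so edges end up grouped by source node.
row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
row, perm = row.sort()
col = col[perm]
print(row.tolist())  # [0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
print(col.tolist())  # e.g. [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]; tie order depends on the sort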
@@ -22,7 +21,7 @@ def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
     row = node_rid[row]

     # Sort row and column indices based on changed values.
-    row, col = sort(torch.stack([row, col], dim=0))
+    row, col = sort(row, col)

     # Revert previous row value changes to old indices.
     row = node_rid.sort()[1][row]
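Note: the node_rid.sort()[1] line above relies on the argsort of a permutation being its inverse. A small sketch of that property, with made-up node_rid and row values (not taken from the test data):

import torch

# If node i is relabeled as node_rid[i], then node_rid.sort()[1]
# maps the new labels back to the old ones.
node_rid = torch.LongTensor([2, 0, 3, 1])   # hypothetical relabeling
row = torch.LongTensor([0, 1, 2, 3, 1])     # hypothetical row indices
relabeled = node_rid[row]                   # -> [2, 0, 3, 1, 0]
reverted = node_rid.sort()[1][relabeled]    # -> [0, 1, 2, 3, 1]
assert reverted.tolist() == row.tolist()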
@@ -20,4 +20,3 @@ void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col
 }
 #endif