"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "7d7ae0a1b0df87ce8ac123cd8b97ade6b15bac2f"
Commit efaa27c1 authored by rusty1s's avatar rusty1s
Browse files

better permute impl

parent 1e46a92c
...@@ -14,20 +14,15 @@ def test_sort_cpu(): ...@@ -14,20 +14,15 @@ def test_sort_cpu():
def test_permute_cpu(): def test_permute_cpu():
edge_index = torch.LongTensor([ row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
[0, 1, 0, 2, 1, 2, 1, 3, 2, 3], col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
[1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
])
node_rid = torch.LongTensor([2, 1, 3, 0]) node_rid = torch.LongTensor([2, 1, 3, 0])
edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
row, col = permute(row, col, 4, node_rid, edge_rid)
edge_index = permute(edge_index, 4, node_rid, edge_rid) expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
expected_edge_index = [ expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
[3, 3, 1, 1, 1, 0, 0, 2, 2, 2], assert row.tolist() == expected_row
[1, 2, 0, 2, 3, 1, 2, 0, 1, 3], assert col.tolist() == expected_col
]
assert edge_index.tolist() == expected_edge_index
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
...@@ -43,13 +38,12 @@ def test_sort_gpu(): # pragma: no cover ...@@ -43,13 +38,12 @@ def test_sort_gpu(): # pragma: no cover
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_permute_gpu(): # pragma: no cover def test_permute_gpu(): # pragma: no cover
edge_index = torch.cuda.LongTensor([ row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
[0, 1, 0, 2, 1, 2, 1, 3, 2, 3], col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
[1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
])
node_rid = torch.cuda.LongTensor([2, 1, 3, 0]) node_rid = torch.cuda.LongTensor([2, 1, 3, 0])
edge_rid = torch.cuda.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edge_index = permute(edge_index, 4, node_rid) row, col = permute(row, col, 4, node_rid, edge_rid)
expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2] expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
assert edge_index[0].cpu().tolist() == expected_row assert row.cpu().tolist() == expected_row
assert col.cpu().tolist() == expected_col
...@@ -7,17 +7,17 @@ def sort(row, col): ...@@ -7,17 +7,17 @@ def sort(row, col):
return row, col return row, col
def permute(edge_index, num_nodes, node_rid=None, edge_rid=None): def permute(row, col, num_nodes, node_rid=None, edge_rid=None):
num_edges = edge_index.size(1) num_edges = row.size(0)
# Randomly reorder row and column indices. # Randomly reorder row and column indices.
if edge_rid is None: if edge_rid is None:
edge_rid = torch.randperm(num_edges).type_as(edge_index) edge_rid = torch.randperm(num_edges).type_as(row)
row, col = edge_index[:, edge_rid] row, col = row[edge_rid], col[edge_rid]
# Randomly change row indices to new values. # Randomly change row indices to new values.
if node_rid is None: if node_rid is None:
node_rid = torch.randperm(num_nodes).type_as(edge_index) node_rid = torch.randperm(num_nodes).type_as(row)
row = node_rid[row] row = node_rid[row]
# Sort row and column indices based on changed values. # Sort row and column indices based on changed values.
...@@ -26,4 +26,4 @@ def permute(edge_index, num_nodes, node_rid=None, edge_rid=None): ...@@ -26,4 +26,4 @@ def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
# Revert previous row value changes to old indices. # Revert previous row value changes to old indices.
row = node_rid.sort()[1][row] row = node_rid.sort()[1][row]
return torch.stack([row, col], dim=0) return row, col
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment