Commit efaa27c1 authored by rusty1s's avatar rusty1s
Browse files

better permute impl

parent 1e46a92c
......@@ -14,20 +14,15 @@ def test_sort_cpu():
def test_permute_cpu():
edge_index = torch.LongTensor([
[0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
[1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
])
row = torch.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
col = torch.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
node_rid = torch.LongTensor([2, 1, 3, 0])
edge_rid = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edge_index = permute(edge_index, 4, node_rid, edge_rid)
expected_edge_index = [
[3, 3, 1, 1, 1, 0, 0, 2, 2, 2],
[1, 2, 0, 2, 3, 1, 2, 0, 1, 3],
]
assert edge_index.tolist() == expected_edge_index
row, col = permute(row, col, 4, node_rid, edge_rid)
expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
assert row.tolist() == expected_row
assert col.tolist() == expected_col
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
......@@ -43,13 +38,12 @@ def test_sort_gpu(): # pragma: no cover
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
def test_permute_gpu(): # pragma: no cover
edge_index = torch.cuda.LongTensor([
[0, 1, 0, 2, 1, 2, 1, 3, 2, 3],
[1, 0, 2, 0, 2, 1, 3, 1, 3, 2],
])
row = torch.cuda.LongTensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
col = torch.cuda.LongTensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
node_rid = torch.cuda.LongTensor([2, 1, 3, 0])
edge_index = permute(edge_index, 4, node_rid)
edge_rid = torch.cuda.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
row, col = permute(row, col, 4, node_rid, edge_rid)
expected_row = [3, 3, 1, 1, 1, 0, 0, 2, 2, 2]
assert edge_index[0].cpu().tolist() == expected_row
expected_col = [1, 2, 0, 2, 3, 1, 2, 0, 1, 3]
assert row.cpu().tolist() == expected_row
assert col.cpu().tolist() == expected_col
......@@ -7,17 +7,17 @@ def sort(row, col):
return row, col
def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
num_edges = edge_index.size(1)
def permute(row, col, num_nodes, node_rid=None, edge_rid=None):
num_edges = row.size(0)
# Randomly reorder row and column indices.
if edge_rid is None:
edge_rid = torch.randperm(num_edges).type_as(edge_index)
row, col = edge_index[:, edge_rid]
edge_rid = torch.randperm(num_edges).type_as(row)
row, col = row[edge_rid], col[edge_rid]
# Randomly change row indices to new values.
if node_rid is None:
node_rid = torch.randperm(num_nodes).type_as(edge_index)
node_rid = torch.randperm(num_nodes).type_as(row)
row = node_rid[row]
# Sort row and column indices based on changed values.
......@@ -26,4 +26,4 @@ def permute(edge_index, num_nodes, node_rid=None, edge_rid=None):
# Revert previous row value changes to old indices.
row = node_rid.sort()[1][row]
return torch.stack([row, col], dim=0)
return row, col
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment