Commit 8abdc5f0 authored by rusty1s's avatar rusty1s
Browse files

test max coalesce

parent afccb12e
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from torch_sparse import coalesce from torch_sparse import coalesce
def test_coalesce(): def test_coalesce_add():
row = torch.tensor([1, 0, 1, 0, 2, 1]) row = torch.tensor([1, 0, 1, 0, 2, 1])
col = torch.tensor([0, 1, 1, 1, 0, 0]) col = torch.tensor([0, 1, 1, 1, 0, 0])
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
...@@ -11,3 +11,14 @@ def test_coalesce(): ...@@ -11,3 +11,14 @@ def test_coalesce():
index, value = coalesce(index, value, m=3, n=2) index, value = coalesce(index, value, m=3, n=2)
assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]] assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
assert value.tolist() == [[6, 8], [7, 9], [3, 4], [5, 6]] assert value.tolist() == [[6, 8], [7, 9], [3, 4], [5, 6]]
def test_coalesce_max():
row = torch.tensor([1, 0, 1, 0, 2, 1])
col = torch.tensor([0, 1, 1, 1, 0, 0])
index = torch.stack([row, col], dim=0)
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
index, value = coalesce(index, value, m=3, n=2, op='max')
assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
assert value.tolist() == [[4, 5], [6, 7], [3, 4], [5, 6]]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment