Commit d15d07b1 authored by rusty1s's avatar rusty1s
Browse files

added max and min tests

parent 1f019c20
...@@ -38,5 +38,23 @@ ...@@ -38,5 +38,23 @@
"dim": 1, "dim": 1,
"fill_value": 0, "fill_value": 0,
"expected": [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]] "expected": [[0, 0, 4, 3, 1.5, 0], [1, 4, 2, 0, 0, 0]]
},
{
"name": "max",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 0,
"expected": [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]],
"expected_arg": [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
},
{
"name": "min",
"index": [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]],
"input": [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
"dim": 1,
"fill_value": 9,
"expected": [[9, 9, 4, 3, 1, 0], [0, 4, 1, 9, 9, 9]],
"expected_arg": [[-1, -1, 3, 4, 2, 1], [0, 4, 2, -1, -1, -1]]
} }
] ]
...@@ -33,7 +33,7 @@ def test_backward_cpu(tensor, i): ...@@ -33,7 +33,7 @@ def test_backward_cpu(tensor, i):
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data)))) @pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
def test_backward_gpu(tensor, i): def test_backward_gpu(tensor, i): # pragma: no cover
name = data[i]['name'] name = data[i]['name']
index = V(torch.LongTensor(data[i]['index'])).cuda() index = V(torch.LongTensor(data[i]['index'])).cuda()
input = V(Tensor(tensor, data[i]['input']), requires_grad=True).cuda() input = V(Tensor(tensor, data[i]['input']), requires_grad=True).cuda()
......
...@@ -24,17 +24,25 @@ def test_forward_cpu(tensor, i): ...@@ -24,17 +24,25 @@ def test_forward_cpu(tensor, i):
output = expected.new(expected.size()).fill_(fill_value) output = expected.new(expected.size()).fill_(fill_value)
func = getattr(torch_scatter, 'scatter_{}_'.format(name)) func = getattr(torch_scatter, 'scatter_{}_'.format(name))
func(output, index, input, dim) result = func(output, index, input, dim)
assert output.tolist() == expected.tolist() assert output.tolist() == expected.tolist()
if 'expected_arg' in data[i]:
expected_arg = torch.LongTensor(data[i]['expected_arg'])
assert result[1].tolist() == expected_arg.tolist()
func = getattr(torch_scatter, 'scatter_{}'.format(name)) func = getattr(torch_scatter, 'scatter_{}'.format(name))
output = func(index, input, dim, fill_value=fill_value) result = func(index, input, dim, fill_value=fill_value)
assert output.tolist() == expected.tolist() if 'expected_arg' not in data[i]:
assert result.tolist() == expected.tolist()
else:
expected_arg = torch.LongTensor(data[i]['expected_arg'])
assert result[0].tolist() == expected.tolist()
assert result[1].tolist() == expected_arg.tolist()
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data)))) @pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
def test_forward_gpu(tensor, i): def test_forward_gpu(tensor, i): # pragma: no cover
name = data[i]['name'] name = data[i]['name']
index = torch.LongTensor(data[i]['index']).cuda() index = torch.LongTensor(data[i]['index']).cuda()
input = Tensor(tensor, data[i]['input']).cuda() input = Tensor(tensor, data[i]['input']).cuda()
...@@ -44,9 +52,17 @@ def test_forward_gpu(tensor, i): ...@@ -44,9 +52,17 @@ def test_forward_gpu(tensor, i):
output = expected.new(expected.size()).fill_(fill_value).cuda() output = expected.new(expected.size()).fill_(fill_value).cuda()
func = getattr(torch_scatter, 'scatter_{}_'.format(name)) func = getattr(torch_scatter, 'scatter_{}_'.format(name))
func(output, index, input, dim) result = func(output, index, input, dim)
assert output.cpu().tolist() == expected.tolist() assert output.cpu().tolist() == expected.tolist()
if 'expected_arg' in data[i]:
expected_arg = torch.LongTensor(data[i]['expected_arg'])
assert result[1].cpu().tolist() == expected_arg.tolist()
func = getattr(torch_scatter, 'scatter_{}'.format(name)) func = getattr(torch_scatter, 'scatter_{}'.format(name))
output = func(index, input, dim, fill_value=fill_value) result = func(index, input, dim, fill_value=fill_value)
assert output.cpu().tolist() == expected.tolist() if 'expected_arg' not in data[i]:
assert result.cpu().tolist() == expected.tolist()
else:
expected_arg = torch.LongTensor(data[i]['expected_arg'])
assert result[0].cpu().tolist() == expected.tolist()
assert result[1].cpu().tolist() == expected_arg.tolist()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment