Commit 0466dd06 authored by rusty1s's avatar rusty1s
Browse files

max test

parent 1c4ef780
import pytest
import torch
from torch.autograd import Variable
from torch_scatter import scatter_max_, scatter_max
from .utils import tensor_strs, Tensor
@pytest.mark.parametrize('str', tensor_strs)
def test_scatter_mean(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = Tensor(str, input)
index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
expected_output_index = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
_, output_index = scatter_max_(output, index, input, dim=1)
assert output.tolist() == expected_output
assert output_index.tolist() == expected_output_index
output, output_index = scatter_max(index, input, dim=1)
assert output.tolist() == expected_output
assert output_index.tolist() == expected_output_index
output = Variable(output).fill_(0)
index = Variable(index)
input = Variable(input, requires_grad=True)
_, output_index = scatter_max_(output, index, input, dim=1)
grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
grad_output = Tensor(str, grad_output)
output.backward(grad_output)
assert index.data.tolist() == input.grad.data.tolist()
...@@ -62,7 +62,10 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0): ...@@ -62,7 +62,10 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def scatter_max_(output, index, input, dim=0): def scatter_max_(output, index, input, dim=0):
output_index = index.new(output.size()).fill_(-1) if torch.is_tensor(input):
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
scatter('max', dim, output, index, input, output_index) scatter('max', dim, output, index, input, output_index)
return output, output_index return output, output_index
...@@ -73,7 +76,10 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0): ...@@ -73,7 +76,10 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def scatter_min_(output, index, input, dim=0): def scatter_min_(output, index, input, dim=0):
output_index = index.new(output.size()).fill_(-1) if torch.is_tensor(input):
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
scatter('min', dim, output, index, input, output_index) scatter('min', dim, output, index, input, output_index)
return output, output_index return output, output_index
......
...@@ -33,6 +33,7 @@ class _Scatter(Function): ...@@ -33,6 +33,7 @@ class _Scatter(Function):
if self.needs_input_grad[0]: if self.needs_input_grad[0]:
grad_output = data[0] grad_output = data[0]
if self.needs_input_grad[2]: if self.needs_input_grad[2]:
# TODO: max and min
grad_input = data[0].gather(self.dim, index.data) grad_input = data[0].gather(self.dim, index.data)
return (grad_output, None, grad_input) + (None, ) * (self.len - 3) return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment