Commit 5628a6f6 authored by rusty1s's avatar rusty1s
Browse files

added mean impl

parent d951ab4d
import pytest
import torch
from torch.autograd import Variable
from torch_scatter import scatter_mean_, scatter_mean
from .utils import tensor_strs, Tensor
# @pytest.mark.parametrize('str', tensor_strs)
# def test_scatter_add(str):
def test_scatter_mean():
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = torch.FloatTensor(input)
index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0)
# expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
scatter_mean_(output, index, input, dim=1)
print(output)
# assert output.tolist() == expected_output
# output = scatter_add(index, input, dim=1)
# assert output.tolist(), expected_output
# output = Variable(output).fill_(0)
# index = Variable(index)
# input = Variable(input, requires_grad=True)
# scatter_add_(output, index, input, dim=1)
# grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
# grad_output = Tensor(str, grad_output)
# output.backward(grad_output)
# assert index.data.tolist() == input.grad.data.tolist()
......@@ -3,7 +3,8 @@ from .utils import gen_output
def scatter_add_(output, index, input, dim=0):
return scatter('add', output, index, input, dim)
scatter('add', dim, output, index, input)
return output
def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
......@@ -12,7 +13,8 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
def scatter_sub_(output, index, input, dim=0):
return scatter('sub', output, index, input, dim)
scatter('sub', dim, output, index, input)
return output
def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
......@@ -21,7 +23,8 @@ def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
def scatter_mul_(output, index, input, dim=0):
return scatter('mul', output, index, input, dim)
scatter('mul', dim, output, index, input)
return output
def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
......@@ -30,15 +33,30 @@ def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
def scatter_div_(output, index, input, dim=0):
return scatter('div', output, index, input, dim)
scatter('div', dim, output, index, input)
return output
def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
output = gen_output(index, input, dim, max_index, fill_value)
return scatter_div_(output, index, input, dim)
scatter_div_(output, index, input, dim)
def scatter_mean_(output, index, input, dim=0):
output_count = output.new(output.size()).fill_(0)
scatter('mean', dim, output, index, input, output_count)
output /= output_count
output[output != output] = 0
return output
def scatter_mean(index, input, dim=0, max_index=None, fill_value=1):
output = gen_output(index, input, dim, max_index, fill_value)
return scatter_mean_(output, index, input, dim)
__all__ = [
'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div'
'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div',
'scatter_mean_', 'scatter_mean'
]
......@@ -4,11 +4,10 @@ from torch.autograd import Function
from .._ext import ffi
def _scatter(name, output, index, input, dim):
typename = type(input).__name__.replace('Tensor', '')
def _scatter(name, dim, *data):
typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(output, index, input, dim)
return output
func(dim, *data)
class _Scatter(Function):
......@@ -17,13 +16,14 @@ class _Scatter(Function):
self.dim = dim
self.name = name
def forward(self, output, index, input):
def forward(self, *data):
assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
self.mark_dirty(output)
self.save_for_backward(index)
self.mark_dirty(data[0])
self.save_for_backward(data[1])
return _scatter(self.name, output, index, input, self.dim)
_scatter(self.name, self.dim, *data)
return data[0]
def backward(self, grad):
index, = self.saved_variables
......@@ -37,8 +37,8 @@ class _Scatter(Function):
return grad_output, None, grad_input
def scatter(name, output, index, input, dim):
if torch.is_tensor(input):
return _scatter(name, output, index, input, dim)
def scatter(name, dim, *data):
if torch.is_tensor(data[0]):
return _scatter(name, dim, *data)
else:
return _Scatter(name, dim)(output, index, input)
return _Scatter(name, dim)(*data)
......@@ -44,10 +44,10 @@
THDescBuff T3buff = _THSizeDesc(TENSOR3->size, TENSOR3->nDimension); \
THDescBuff T4buff = _THSizeDesc(TENSOR4->size, TENSOR3->nDimension); \
THError("inconsistent tensor size, expected %s %s, %s %s, %s %s and %s %s to have the same " \
"number of dimensions", #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, #TENSOR3, T3buff.str, #TENSOR4, T4.buff.str); \
"number of dimensions", #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, #TENSOR3, T3buff.str, #TENSOR4, T4buff.str); \
} \
\
SIZE_CHECK(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \
SIZE_CHECK(TENSOR1, TENSOR2, TENSOR3, TENSOR4, DIMENSION) \
\
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR1->nDimension)); \
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
......
#include <TH/TH.h>
#include "THTensorDimApply.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
inline void assertIndexInBoundaries(int idx, int size, int64_t *free) {
......
void scatter_add_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_add_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
void scatter_add_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_add_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_add_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_add_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_add_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_add_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_add_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_add_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_add_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_add_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_add_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_add_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_sub_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_sub_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
void scatter_sub_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_sub_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_sub_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_sub_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_sub_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_sub_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_sub_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_sub_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_sub_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_sub_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_sub_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_sub_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_mul_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_mul_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
void scatter_mul_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_mul_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_mul_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_mul_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_mul_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_mul_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_mul_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_mul_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_mul_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_mul_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_mul_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_mul_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_div_Float (THFloatTensor *output, THLongTensor *index, THFloatTensor *input, int dim);
void scatter_div_Double(THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, int dim);
void scatter_div_Byte (THByteTensor *output, THLongTensor *index, THByteTensor *input, int dim);
void scatter_div_Char (THCharTensor *output, THLongTensor *index, THCharTensor *input, int dim);
void scatter_div_Short (THShortTensor *output, THLongTensor *index, THShortTensor *input, int dim);
void scatter_div_Int (THIntTensor *output, THLongTensor *index, THIntTensor *input, int dim);
void scatter_div_Long (THLongTensor *output, THLongTensor *index, THLongTensor *input, int dim);
void scatter_div_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_div_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_div_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_div_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_div_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_div_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_div_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_mean_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THFloatTensor *output_count);
void scatter_mean_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THDoubleTensor *output_count);
void scatter_mean_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THByteTensor *output_count);
void scatter_mean_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THCharTensor *output_count);
void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THShortTensor *output_count);
void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *output_count);
void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_count);
......@@ -2,9 +2,9 @@
#define TH_GENERIC_FILE "generic/cpu.c"
#else
void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
void scatter_(add)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
......@@ -12,9 +12,9 @@ void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
}
void scatter_(sub)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
void scatter_(sub)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
......@@ -22,9 +22,9 @@ void scatter_(sub)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
}
void scatter_(mul)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
......@@ -32,9 +32,9 @@ void scatter_(mul)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
}
void scatter_(div)(THTensor *output, THLongTensor *index, THTensor *input, int dim) {
void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t idx;
TH_TENSOR_DIM_APPLY3(real, output, real, input, int64_t, index, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
......@@ -42,4 +42,15 @@ void scatter_(div)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
}
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *output_count) {
int64_t idx;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, output_count, dim, TH_TENSOR_DIM_APPLY4_SIZE_EQ_EXCEPT_DIM,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx] += *(input_data + i * input_stride);
output_count_data[idx]++;
})
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment