Commit 416a2603 authored by rusty1s's avatar rusty1s
Browse files

backward index impl

parent 880c8102
......@@ -27,10 +27,19 @@ def test_scatter_mean(str):
output = Variable(output).fill_(0)
index = Variable(index)
input = Variable(input, requires_grad=True)
_, output_index = scatter_max_(output, index, input, dim=1)
scatter_max_(output, index, input, dim=1)
grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
grad_output = [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]]
grad_output = Tensor(str, grad_output)
output.backward(grad_output)
assert index.data.tolist() == input.grad.data.tolist()
# assert index.data.tolist() == input.grad.data.tolist()
# output = Variable(torch.FloatTensor([0, 0, 0, 0, 0]))
index = Variable(torch.LongTensor([3, 4, 4, 2, 1]))
input = Variable(torch.FloatTensor([1, 2, 3, 4, 5]), requires_grad=True)
output, output_index = scatter_max(index, input)
# print(output_index)
output.backward(torch.FloatTensor([10, 20, 30, 40]))
print(input.grad)
......@@ -6,8 +6,7 @@ from .utils import gen_output
def scatter_add_(output, index, input, dim=0):
    """In-place scatter-add: accumulate `input` into `output` at the
    positions given by `index` along dimension `dim`.

    Returns the result of `scatter` (the mutated `output`).
    """
    # Forward `scatter`'s return value directly; the previous version
    # discarded it and returned `output` separately, leaving a second,
    # unreachable return statement behind.
    return scatter('add', dim, output, index, input)
def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
......@@ -16,8 +15,7 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
def scatter_sub_(output, index, input, dim=0):
    """In-place scatter-subtract: subtract `input` from `output` at the
    positions given by `index` along dimension `dim`.

    Returns the result of `scatter` (the mutated `output`).
    """
    # Single return of `scatter`'s result; the duplicated
    # call-then-return-output form left dead code behind.
    return scatter('sub', dim, output, index, input)
def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
......@@ -26,8 +24,7 @@ def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
def scatter_mul_(output, index, input, dim=0):
    """In-place scatter-multiply: multiply `output` by `input` at the
    positions given by `index` along dimension `dim`.

    Returns the result of `scatter` (the mutated `output`).
    """
    # Single return of `scatter`'s result; the duplicated
    # call-then-return-output form left dead code behind.
    return scatter('mul', dim, output, index, input)
def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
......@@ -36,8 +33,7 @@ def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
def scatter_div_(output, index, input, dim=0):
    """In-place scatter-divide: divide `output` by `input` at the
    positions given by `index` along dimension `dim`.

    Returns the result of `scatter` (the mutated `output`).
    """
    # Single return of `scatter`'s result; the duplicated
    # call-then-return-output form left dead code behind.
    return scatter('div', dim, output, index, input)
def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
......@@ -66,8 +62,7 @@ def scatter_max_(output, index, input, dim=0):
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
scatter('max', dim, output, index, input, output_index)
return output, output_index
return scatter('max', dim, output, index, input, output_index)
def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
......@@ -80,8 +75,7 @@ def scatter_min_(output, index, input, dim=0):
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
scatter('min', dim, output, index, input, output_index)
return output, output_index
return scatter('min', dim, output, index, input, output_index)
def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
......
......@@ -6,6 +6,10 @@ from torch.autograd import Function
from .._ext import ffi
def _has_output_index(name):
return name in ['max', 'min']
def _scatter(name, dim, *data):
a, b, c = data[:3]
......@@ -31,6 +35,15 @@ def _scatter(name, dim, *data):
typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(dim, *data)
return (data[0], data[3]) if _has_output_index(name) else data[0]
def _index_backward(dim, index, grad, grad_index):
    """Route `grad` back to the input positions recorded in `grad_index`.

    Dispatches to the C implementation matching the tensor type of
    `grad` and returns a freshly allocated, zero-filled tensor shaped
    like `index` holding the routed gradient values.
    """
    dtype_name = type(grad).__name__.replace('Tensor', '')
    backward_fn = getattr(ffi, 'index_backward_{}'.format(dtype_name))
    grad_input = grad.new(index.size()).fill_(0)
    backward_fn(dim, grad_input, index, grad, grad_index)
    return grad_input
class _Scatter(Function):
......@@ -44,21 +57,31 @@ class _Scatter(Function):
self.mark_dirty(data[0]) # Mark output as dirty.
self.len = len(data) # Save number of arguments for backward step
self.save_for_backward(data[1]) # Save index for backward step.
_scatter(self.name, self.dim, *data)
if _has_output_index(self.name):
self.save_for_backward(data[1], data[3])
return data[0], data[3]
else:
self.save_for_backward(data[1])
return data[0]
def backward(self, *data):
    """Backward pass of the scatter Function.

    The gradient w.r.t. the `output` argument is `grad_output` passed
    straight through.  The gradient w.r.t. `input` is gathered from
    `grad_output`: for ops without an output index it is gathered at
    `index`; for 'max'/'min' the saved argmax/argmin positions
    (`grad_index`) select which input element receives each gradient
    via `_index_backward`.  All remaining forward arguments get None.
    """
    grad_output = grad_input = None
    if self.needs_input_grad[0]:
        grad_output = data[0]
    if self.needs_input_grad[2]:
        if _has_output_index(self.name):
            # Two tensors were saved for max/min: unpacking a single
            # value here would raise, so branch before unpacking.
            index, grad_index = self.saved_variables
            args = (index.data, data[0], grad_index.data)
            grad_input = _index_backward(self.dim, *args)
        else:
            index, = self.saved_variables
            grad_input = data[0].gather(self.dim, index.data)
    # Pad with None for the trailing forward arguments (dim, etc.).
    return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
......
......@@ -11,4 +11,4 @@ def gen_output(index, input, dim, max_index, fill_value):
return input.new(torch.Size(size)).fill_(fill_value)
else:
size[dim] = max_index.data[0]
return Variable(input.new(torch.Size(size)).fill_(fill_value))
return Variable(input.data.new(torch.Size(size)).fill_(fill_value))
......@@ -3,6 +3,7 @@
#include "THTensorDimApply4.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
#define index_backward TH_CONCAT_2(index_backward_, Real)
/* Abort via THError("Invalid index") when `idx` is outside [0, size).
 * The caller's counter buffer `free` is released first so the error
 * path does not leak the allocation made by the DIM_APPLY macros. */
inline void assertIndexInBoundaries(int idx, int size, int64_t *free) {
  if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); }
......
......@@ -53,3 +53,11 @@ void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, TH
void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *output_index);
void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *output_index);
void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *output_index);
void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *grad_index);
void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *grad_index);
void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *grad_index);
void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *grad_index);
void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *grad_index);
void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *grad_index);
void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *grad_index);
......@@ -3,75 +3,72 @@
#else
/* In-place scatter-add along `dim`: output[index[i]] += input[i].
 * Every index is bounds-checked against the output size. */
void scatter_(add)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
  TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
      output_data[index_data[i]] += input_data[i];
    })
}
/* In-place scatter-subtract along `dim`: output[index[i]] -= input[i].
 * Every index is bounds-checked against the output size. */
void scatter_(sub)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
  TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
      output_data[index_data[i]] -= input_data[i];
    })
}
/* In-place scatter-multiply along `dim`: output[index[i]] *= input[i].
 * Every index is bounds-checked against the output size. */
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
  TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
      output_data[index_data[i]] *= input_data[i];
    })
}
/* In-place scatter-divide along `dim`: output[index[i]] /= input[i].
 * Every index is bounds-checked against the output size. */
void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
  TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
      output_data[index_data[i]] /= input_data[i];
    })
}
/* Scatter-mean accumulation step along `dim`: sums input values into
 * `output` and counts contributions per slot in `output_count`.
 * (The division by the count presumably happens in the caller —
 * NOTE(review): confirm against the Python wrapper.) */
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *output_count) {
  TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, output_count, dim,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
      output_data[index_data[i]] += input_data[i];
      output_count_data[index_data[i]]++;
    })
}
/* Scatter-max along `dim`: keeps the maximum input value per output
 * slot and records in `output_index` the source position `i` that won.
 * `>=` means the last of equal maxima wins. */
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) {
  TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
      if (input_data[i] >= output_data[index_data[i]]) {
        output_data[index_data[i]] = input_data[i];
        output_index_data[index_data[i]] = i;
      }
    })
}
/* Scatter-min along `dim`: keeps the minimum input value per output
 * slot and records in `output_index` the source position `i` that won.
 * `<=` means the last of equal minima wins. */
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *output_index) {
  TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, output_index, dim,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      assertIndexInBoundaries(index_data[i], output_size, TH_TENSOR_DIM_APPLY_counter);
      if (input_data[i] <= output_data[index_data[i]]) {
        output_data[index_data[i]] = input_data[i];
        output_index_data[index_data[i]] = i;
      }
    })
}
/* Backward for scatter max/min: for each source position `i`, copy the
 * incoming gradient `grad` into `output` only when `grad_index` records
 * `i` as the element that produced the forward result at index[i];
 * all other positions are left untouched. */
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) {
  TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim,
    for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
      if (grad_index_data[index_data[i]] == i) output_data[index_data[i]] = grad_data[i];
    })
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment