"tests/vscode:/vscode.git/clone" did not exist on "ec9e5264c0865c739602f6cab0bef118892a50f3"
Commit f0fdfe20 authored by rusty1s's avatar rusty1s
Browse files

correct gradient computations

parent 43432f49
@@ -6,7 +6,7 @@ from .functions.mean import scatter_mean_, scatter_mean
from .functions.max import scatter_max_, scatter_max
from .functions.min import scatter_min_, scatter_min
__version__ = '0.2.3'
__version__ = '0.3.0'
__all__ = [
    'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
......
from .scatter import scatter
from .scatter import Scatter, scatter
from .utils import gen_output
class ScatterDiv(Scatter):
    def __init__(self, dim):
        super(ScatterDiv, self).__init__('div', dim)

    def save_for_backward_step(self, *data):
        output, index, input = data
        self.save_for_backward(output, index, input)

    def backward_step(self, *data):
        grad, output, index, input = data
        return (grad / output.data).gather(self.dim, index.data) * input.data


def scatter_div_(output, index, input, dim=0):
    r"""
    |
@@ -53,7 +66,7 @@ def scatter_div_(output, index, input, dim=0):
        0.5000 0.2500 0.1667 1.0000 1.0000 1.0000
        [torch.FloatTensor of size 2x6]
    """
    return scatter('div', dim, output, index, input)
    return scatter(ScatterDiv, 'div', dim, output, index, input)
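For orientation, a minimal usage sketch of the in-place call above (toy values rather than the docstring's example; assumes the functions are re-exported at the package root as in __init__.py):

import torch
from torch_scatter import scatter_div_

output = torch.ones(4)                        # pre-filled with the neutral element 1
index = torch.LongTensor([0, 0, 1, 2])
input = torch.FloatTensor([2., 4., 5., 10.])
scatter_div_(output, index, input)            # dim=0 by default
# output is now [1 / (2 * 4), 1 / 5, 1 / 10, 1] = [0.1250, 0.2000, 0.1000, 1.0000]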
def scatter_div(index, input, dim=0, size=None, fill_value=1):
......
from itertools import chain
from .._ext import ffi
def scatter(name, dim, *data):
    # data = output, index, input, additional data
    a, b, c = data[:3]
    # Assert index dimension is valid.
    assert dim >= 0 and dim < b.dim(), 'Index dimension is out of bounds'
    # Assert same dimensionality across all inputs.
    assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
                                'input tensor')
    assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
                                'output tensor')
    # Assert same tensor length across index and input.
    assert b.numel() == c.numel(), ('Index tensor must have same size as '
                                    'input tensor')
    # Assert same tensor sizes across input and output apart from `dim`.
    for d in chain(range(dim), range(dim + 1, a.dim())):
        assert a.size(d) == c.size(d), (
            'Input tensor must have same size as output tensor apart from the '
            'specified dimension')
    typename = type(data[0]).__name__.replace('Tensor', '')
    cuda = 'cuda_' if data[0].is_cuda else ''
    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
    func(dim, *data)
    if len(data) <= 3:
        return data[0]
    return (data[0], ) + tuple(data[3:])


def index_backward(dim, index, grad, arg):  # pragma: no cover
    typename = type(grad).__name__.replace('Tensor', '')
    cuda = 'cuda_' if grad.is_cuda else ''
    func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
    output = grad.new(index.size()).fill_(0)
    func(dim, output, index, grad, arg)
    return output
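index_backward is the kernel that ScatterMax and ScatterMin (below) call in their backward_step: it routes the gradient of each output slot back to the single input position recorded in arg (the argmax or argmin) and leaves every other position at zero. A rough pure-PyTorch, CPU-only reference of that behaviour, written here only for illustration and not part of the package:

import torch


def index_backward_reference(dim, index, grad, arg):
    # The input gradient has the shape of the index/input tensor.
    grad_input = grad.new(index.size()).fill_(0)
    # Position of every input element along `dim`.
    pos = torch.arange(0, index.size(dim)).long()
    shape = [1] * index.dim()
    shape[dim] = -1
    pos = pos.view(*shape).expand_as(index)
    # For each input element, look up the gradient and the winning position
    # of the output slot it was scattered to.
    grad_at = grad.gather(dim, index)
    arg_at = arg.gather(dim, index)
    mask = arg_at == pos
    grad_input[mask] = grad_at[mask]
    return grad_input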
from .scatter import scatter
from .scatter import Scatter, scatter
from .ffi import index_backward
from .utils import gen_filled_tensor, gen_output
class ScatterMax(Scatter):
    def __init__(self, dim):
        super(ScatterMax, self).__init__('max', dim)

    def save_for_backward_step(self, *data):
        output, index, input, arg = data
        self.save_for_backward(index, arg)

    def backward_step(self, *data):
        grad, index, arg = data
        return index_backward(self.dim, index.data, grad, arg.data)


def scatter_max_(output, index, input, dim=0):
    r"""
    |
@@ -61,7 +75,7 @@ def scatter_max_(output, index, input, dim=0):
        )
    """
    arg = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('max', dim, output, index, input, arg)
    return scatter(ScatterMax, 'max', dim, output, index, input, arg)
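Since the max and min variants also produce the argmax/argmin tensor, the in-place call returns a tuple. A toy usage sketch (illustrative values; assumes the package-root re-export):

import torch
from torch_scatter import scatter_max_

output = torch.zeros(3)
index = torch.LongTensor([0, 0, 1, 1, 2])
input = torch.FloatTensor([2., 5., 4., 3., 1.])
out, arg = scatter_max_(output, index, input)
# out == [5., 4., 1.]; arg records, per output slot, the input position
# along `dim` that produced the maximum (here [1, 2, 4]).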
def scatter_max(index, input, dim=0, size=None, fill_value=0):
......
from __future__ import division
from .scatter import scatter
from .scatter import Scatter, scatter
from .utils import gen_filled_tensor, gen_output
class ScatterMean(Scatter):
    def __init__(self, dim):
        super(ScatterMean, self).__init__('mean', dim)

    def save_for_backward_step(self, *data):
        output, index, input, count = data
        self.save_for_backward(index)

    def backward_step(self, *data):
        grad, index = data
        return grad.gather(self.dim, index.data)


def scatter_mean_(output, index, input, dim=0):
    r"""
    |
@@ -56,10 +69,12 @@ def scatter_mean_(output, index, input, dim=0):
        1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
        [torch.FloatTensor of size 2x6]
    """
    init = gen_filled_tensor(output, output.size(), fill_value=0)
    count = gen_filled_tensor(output, output.size(), fill_value=0)
    scatter('mean', dim, output, index, input, count)
    scatter(ScatterMean, 'mean', dim, init, index, input, count)
    count[count == 0] = 1
    output /= count
    init /= count
    output += init
    return output
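The corrected in-place mean no longer divides the whole output by the counts (which also scaled the pre-existing fill values); the scattered values are summed and averaged in a separate buffer first and only then added on top of the output. A rough plain-PyTorch equivalent of that two-buffer scheme for the 1-D, dim=0 case (illustrative only; the real work is done by the FFI 'mean' kernel above):

import torch

output = torch.zeros(3)                          # pre-existing output / fill values
index = torch.LongTensor([0, 0, 1, 1, 1])
input = torch.FloatTensor([2., 4., 3., 6., 9.])

init = torch.zeros(output.size())                # per-slot sums of the scattered values
count = torch.zeros(output.size())               # per-slot element counts
init.index_add_(0, index, input)
count.index_add_(0, index, torch.ones(input.size()))
count[count == 0] = 1                            # empty slots keep their old value
output += init / count                           # add the means on top of the output
# output == [3., 6., 0.]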
......
from .scatter import scatter
from .scatter import Scatter, scatter
from .ffi import index_backward
from .utils import gen_filled_tensor, gen_output
class ScatterMin(Scatter):
    def __init__(self, dim):
        super(ScatterMin, self).__init__('min', dim)

    def save_for_backward_step(self, *data):
        output, index, input, arg = data
        self.save_for_backward(index, arg)

    def backward_step(self, *data):
        grad, index, arg = data
        return index_backward(self.dim, index.data, grad, arg.data)


def scatter_min_(output, index, input, dim=0):
    r"""
    |
@@ -61,7 +75,7 @@ def scatter_min_(output, index, input, dim=0):
        )
    """
    arg = gen_filled_tensor(index, output.size(), fill_value=-1)
    return scatter('min', dim, output, index, input, arg)
    return scatter(ScatterMin, 'min', dim, output, index, input, arg)
def scatter_min(index, input, dim=0, size=None, fill_value=0):
......
from .scatter import scatter
from .scatter import Scatter, scatter
from .utils import gen_output
class ScatterMul(Scatter):
    def __init__(self, dim):
        super(ScatterMul, self).__init__('mul', dim)

    def save_for_backward_step(self, *data):
        output, index, input = data
        self.save_for_backward(output, index, input)

    def backward_step(self, *data):
        grad, output, index, input = data
        return (grad * output.data).gather(self.dim, index.data) / input.data


def scatter_mul_(output, index, input, dim=0):
    r"""
    |
@@ -52,7 +65,7 @@ def scatter_mul_(output, index, input, dim=0):
        6 4 8 1 1 1
        [torch.FloatTensor of size 2x6]
    """
    return scatter('mul', dim, output, index, input)
    return scatter(ScatterMul, 'mul', dim, output, index, input)
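The backward_step above is just the product rule: with out[j] = init[j] * prod(input[i] for index[i] == j), one gets d out[j] / d input[i] = out[j] / input[i], which is exactly (grad * output).gather(dim, index) / input. A small finite-difference sanity check of that formula in plain PyTorch (toy 1-D data, dim=0, fill value 1; written only for illustration):

import torch

index = torch.LongTensor([0, 0, 1])
x = torch.FloatTensor([2., 3., 4.])
grad_out = torch.FloatTensor([1., 1.])


def forward(x):
    out = torch.ones(2)                  # two output slots, filled with 1
    for i in range(x.size(0)):
        out[index[i]] *= x[i]
    return out


out = forward(x)
analytic = (grad_out * out).gather(0, index) / x     # -> [3., 2., 1.]

eps = 1e-3
numeric = torch.zeros(3)
for i in range(3):
    xp, xm = x.clone(), x.clone()
    xp[i] += eps
    xm[i] -= eps
    numeric[i] = (grad_out * (forward(xp) - forward(xm))).sum() / (2 * eps)
# `numeric` agrees with `analytic` up to O(eps^2).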
def scatter_mul(index, input, dim=0, size=None, fill_value=1):
......
from itertools import chain
import torch
from torch.autograd import Function
from .._ext import ffi
def has_arg(name):
    return name in ['max', 'min']


def _scatter(name, dim, *data):
    a, b, c = data[:3]
    # Assert index dimension is valid.
    assert dim >= 0 and dim < a.dim(), 'Index dimension is out of bounds'
    # Assert same dimensionality across all inputs.
    assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
                                'input tensor')
    assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
                                'output tensor')
    # Assert same tensor length across index and input.
    assert b.numel() == c.numel(), ('Index tensor must have same size as '
                                    'input tensor')
    # Assert same tensor sizes across input and output apart from `dim`.
    for d in chain(range(dim), range(dim + 1, a.dim())):
        assert a.size(d) == c.size(d), (
            'Input tensor must have same size as output tensor apart from the '
            'specified dimension')
from .ffi import scatter as ffi_scatter
    typename = type(data[0]).__name__.replace('Tensor', '')
    cuda = 'cuda_' if data[0].is_cuda else ''
    func = getattr(ffi, 'scatter_{}_{}{}'.format(name, cuda, typename))
    func(dim, *data)
    return (data[0], data[3]) if has_arg(name) else data[0]


def index_backward(dim, index, grad, arg):  # pragma: no cover
    typename = type(grad).__name__.replace('Tensor', '')
    cuda = 'cuda_' if grad.is_cuda else ''
    func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
    output = grad.new(index.size()).fill_(0)
    func(dim, output, index, grad, arg)
    return output
class _Scatter(Function):
class Scatter(Function):
    def __init__(self, name, dim):
        super(_Scatter, self).__init__()
        super(Scatter, self).__init__()
        self.name = name
        self.dim = dim

    def save_for_backward_step(self, *data):
        raise NotImplementedError

    def forward(self, *data):
        assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
        self.mark_dirty(data[0])  # Mark output as dirty.
        self.len = len(data)  # Save number of arguments for backward step.
        _scatter(self.name, self.dim, *data)
        output = ffi_scatter(self.name, self.dim, *data)
        self.save_for_backward_step(*data)
        # `scatter_min` and `scatter_max` additionally return the `argmax`
        # respectively `argmin`. Therefore, we need to save the `arg` for the
        # backward pass.
        if has_arg(self.name):
            self.save_for_backward(data[1], data[3])
            return data[0], data[3]
        else:
            self.save_for_backward(data[1])
            return data[0]
        return output

    def backward(self, *data):  # pragma: no cover
        grad_output = grad_input = None
@@ -78,24 +30,53 @@ class _Scatter(Function):
        if self.needs_input_grad[0]:
            grad_output = data[0]
        # Different grad computation of `input` if `scatter_max` or
        # `scatter_min` was used.
        if self.needs_input_grad[2] and not has_arg(self.name):
            index, = self.saved_variables
            grad_input = data[0].gather(self.dim, index.data)
        # Call grad computation of `input` for the specific scatter operation.
        if self.needs_input_grad[2]:
            grad_input = self.backward_step(data[0], *self.saved_variables)
        if self.needs_input_grad[2] and has_arg(self.name):
            index, arg = self.saved_variables
            data = (index.data, data[0], arg.data)
            grad_input = index_backward(self.dim, *data)
        # Return and fill with empty grads for non-differentiable passed
        # arguments in forward pass.
        # Return and fill with empty grads for non-differentiable arguments.
        return (grad_output, None, grad_input) + (None, ) * (self.len - 3)

    def backward_step(self, *data):
        raise NotImplementedError
def scatter(name, dim, *data):
def scatter(Clx, name, dim, *data):
    if torch.is_tensor(data[0]):
        return _scatter(name, dim, *data)
        return ffi_scatter(name, dim, *data)
    else:
        return _Scatter(name, dim)(*data)
        return Clx(dim)(*data)
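The refactor turns the former monolithic _Scatter function into a small framework: subclasses of Scatter only declare what to stash in save_for_backward_step and how to map the incoming gradient to an input gradient in backward_step, while the free scatter() helper sends plain tensors straight to the FFI kernel and routes Variables through the autograd Function. A hedged sketch of what a subclass looks like under this scheme (a hypothetical 'add' case, building on the Scatter class above; the package's real add implementation may differ):

class ScatterAdd(Scatter):
    def __init__(self, dim):
        super(ScatterAdd, self).__init__('add', dim)

    def save_for_backward_step(self, *data):
        output, index, input = data
        self.save_for_backward(index)    # the gradient only needs the index

    def backward_step(self, *data):
        grad, index = data
        # d out[j] / d input[i] = 1 whenever index[i] == j, so the input
        # gradient is the output gradient gathered along `dim`.
        return grad.gather(self.dim, index.data)

# Dispatched like the forward functions above, e.g.:
# return scatter(ScatterAdd, 'add', dim, output, index, input)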
# def index_backward(dim, index, grad, arg): # pragma: no cover
# typename = type(grad).__name__.replace('Tensor', '')
# cuda = 'cuda_' if grad.is_cuda else ''
# func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
# output = grad.new(index.size()).fill_(0)
# func(dim, output, index, grad, arg)
# return output
# def _scatter_backward(name, dim, saved, *data):
# # saved = (index, ), (index, arg) or (index, count)
# print(name)
# print(len(data))
# print(len(saved))
# print(saved[1].size())
# # data = (grad, )
# # index, = seved
# if has_arg(name):
# return index_backward(dim, saved[0].data, data[0], saved[1].data)
# if has_count(name):
# return (data[0] / saved[1]).gather(dim, saved[0].data)
# # Different grad computation of `input` if `scatter_max` or
# # `scatter_min` was used.
# # if self.needs_input_grad[2] and not has_arg(self.name):
# # index, = self.saved_variables
# # grad_input = data[0].gather(self.dim, index.data)
# # if self.needs_input_grad[2] and has_arg(self.name):
# # index, arg = self.saved_variables
# # data = (index.data, data[0], arg.data)
# grad_input = index_backward(self.dim, *data)