"tests/vscode:/vscode.git/clone" did not exist on "c022e52923c897d0c16ebcaa9feecfbf1dfbec66"
Commit a80a8ba7 authored by rusty1s's avatar rusty1s
Browse files

first python impl

parent 7f712d86
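
The commit adds three files. The first is the package's __init__.py (the filename is inferred from the relative imports), which exposes an in-place and an out-of-place variant of each scatter reduction: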
from .scatter import scatter
from .utils import gen_output


# In-place variants (trailing underscore) write the reduction into a
# user-supplied `output`; the out-of-place variants allocate `output`
# via `gen_output` first.
def scatter_add_(output, index, input, dim=0):
    return scatter('add', output, index, input, dim)


def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_add_(output, index, input, dim)


def scatter_sub_(output, index, input, dim=0):
    return scatter('sub', output, index, input, dim)


def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_sub_(output, index, input, dim)


# Multiplicative reductions start from a fill value of 1, not 0.
def scatter_mul_(output, index, input, dim=0):
    return scatter('mul', output, index, input, dim)


def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_mul_(output, index, input, dim)


def scatter_div_(output, index, input, dim=0):
    return scatter('div', output, index, input, dim)


def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
    output = gen_output(index, input, dim, max_index, fill_value)
    return scatter_div_(output, index, input, dim)


__all__ = [
    'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
    'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div'
]
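
A minimal usage sketch of the out-of-place variant (the top-level package name, here torch_scatter, is an assumption; the semantics follow scatter_add_ above: output[index[i]] += input[i] along dim):

import torch
from torch_scatter import scatter_add  # package name assumed

index = torch.LongTensor([0, 1, 0, 1])
input = torch.FloatTensor([1, 2, 3, 4])
out = scatter_add(index, input)  # output size inferred as index.max() + 1 == 2
# out == [1 + 3, 2 + 4] == [4, 6]

The dispatch helper and autograd wrapper live in scatter.py (per the relative import above):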
import torch
from torch.autograd import Function

from .._ext import ffi


def _scatter(name, output, index, input, dim):
    # Dispatch to the compiled kernel for the tensor's scalar type,
    # e.g. `scatter_add_Float` for a FloatTensor.
    typename = type(input).__name__.replace('Tensor', '')
    func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
    func(output, index, input, dim)
    return output


class _Scatter(Function):
    def __init__(self, name, dim):
        super(_Scatter, self).__init__()
        self.dim = dim
        self.name = name

    def forward(self, output, index, input):
        assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
        self.mark_dirty(output)
        self.save_for_backward(index)
        return _scatter(self.name, output, index, input, self.dim)

    def backward(self, grad):
        index, = self.saved_variables
        grad_output = grad_input = None
        if self.needs_input_grad[0]:
            grad_output = grad
        if self.needs_input_grad[2]:
            # Route the gradient back to the scattered positions.
            grad_input = grad.gather(self.dim, index.data)
        return grad_output, None, grad_input


def scatter(name, output, index, input, dim):
    # Plain tensors skip autograd; Variables go through the Function.
    if torch.is_tensor(input):
        return _scatter(name, output, index, input, dim)
    else:
        return _Scatter(name, dim)(output, index, input)
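
The backward pass routes the incoming gradient back through the index with gather, which is exact for the add case: since output[index[i]] += input[i], the gradient of input[i] is read from grad at index[i]. A small sketch of that routing with plain tensors (dim 0):

import torch

index = torch.LongTensor([0, 1, 0, 1])
grad = torch.FloatTensor([10, 20])  # gradient w.r.t. the size-2 output
grad_input = grad.gather(0, index)  # -> [10, 20, 10, 20]

The last file is utils.py (per the relative import in __init__.py), which allocates and pre-fills the output: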
import torch
from torch.autograd import Variable


def gen_output(index, input, dim, max_index, fill_value):
    max_index = index.max() + 1 if max_index is None else max_index
    # The output has the same shape as `index`, except along `dim`,
    # where it holds one slot per target index.
    size = list(index.size())
    if torch.is_tensor(input):
        size[dim] = max_index
        return input.new(torch.Size(size)).fill_(fill_value)
    else:
        # For Variables, `index.max()` returns a one-element Variable,
        # so the Python number has to be pulled out of its data.
        size[dim] = max_index.data[0]
        return Variable(input.new(torch.Size(size)).fill_(fill_value))
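
For illustration, the shape logic on the tensor path (import path assumed, mirroring the relative import in __init__.py):

import torch
from torch_scatter.utils import gen_output  # import path assumed

index = torch.LongTensor([0, 2, 1])
input = torch.FloatTensor([5, 6, 7])
out = gen_output(index, input, dim=0, max_index=None, fill_value=0)
# index.max() + 1 == 3, so out is a FloatTensor of size [3] filled with 0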