Commit f60388a0 authored by rusty1s's avatar rusty1s
Browse files

rename

parent 3f1346dc
......@@ -8,6 +8,6 @@ def scatter_div_(output, index, input, dim=0):
return scatter('div', dim, output, index, input)
def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
output = gen_output(index, input, dim, max_index, fill_value)
def scatter_div(index, input, dim=0, size=None, fill_value=1):
output = gen_output(index, input, dim, size, fill_value)
scatter_div_(output, index, input, dim)
......@@ -12,6 +12,6 @@ def scatter_max_(output, index, input, dim=0):
return scatter('max', dim, output, index, input, arg_output)
def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value)
def scatter_max(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, size, fill_value)
return scatter_max_(output, index, input, dim)
......@@ -12,6 +12,6 @@ def scatter_mean_(output, index, input, dim=0):
return output
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value)
def scatter_mean(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, size, fill_value)
return scatter_mean_(output, index, input, dim)
......@@ -9,6 +9,6 @@ def scatter_min_(output, index, input, dim=0):
return scatter('min', dim, output, index, input, arg_output)
def scatter_min(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value)
def scatter_min(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, size, fill_value)
return scatter_min_(output, index, input, dim)
......@@ -8,6 +8,6 @@ def scatter_mul_(output, index, input, dim=0):
return scatter('mul', dim, output, index, input)
def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
output = gen_output(index, input, dim, max_index, fill_value)
def scatter_mul(index, input, dim=0, size=None, fill_value=1):
output = gen_output(index, input, dim, size, fill_value)
return scatter_mul_(output, index, input, dim)
......@@ -39,12 +39,12 @@ def _scatter(name, dim, *data):
return (data[0], data[3]) if has_arg_output(name) else data[0]
def index_backward(dim, index, grad, arg_grad):
def index_backward(dim, index, grad, arg):
typename = type(grad).__name__.replace('Tensor', '')
cuda = 'cuda_' if grad.is_cuda else ''
func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, arg_grad)
func(dim, output, index, grad, arg)
return output
......
......@@ -7,6 +7,6 @@ def scatter_sub_(output, index, input, dim=0):
return output.scatter_add_(dim, index, -input)
def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value)
def scatter_sub(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, size, fill_value)
return scatter_sub_(output, index, input, dim)
......@@ -9,8 +9,11 @@ def gen_filled_tensor(input, size, fill_value):
return Variable(input.data.new(size).fill_(fill_value))
def gen_output(index, input, dim, max_index, fill_value):
max_index = index.max() + 1 if max_index is None else max_index
def gen_output(index, input, dim, dim_size, fill_value):
if dim_size is None:
dim_size = index.max() + 1
dim_size = dim.size if torch.is_tensor(input) else dim_size.data[0]
size = list(index.size())
size[dim] = max_index if torch.is_tensor(input) else max_index.data[0]
size[dim] = dim_size
return gen_filled_tensor(input, torch.Size(size), fill_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment