Commit f60388a0 authored by rusty1s's avatar rusty1s
Browse files

rename

parent 3f1346dc
...@@ -8,6 +8,6 @@ def scatter_div_(output, index, input, dim=0): ...@@ -8,6 +8,6 @@ def scatter_div_(output, index, input, dim=0):
return scatter('div', dim, output, index, input) return scatter('div', dim, output, index, input)
def scatter_div(index, input, dim=0, max_index=None, fill_value=1): def scatter_div(index, input, dim=0, size=None, fill_value=1):
output = gen_output(index, input, dim, max_index, fill_value) output = gen_output(index, input, dim, size, fill_value)
scatter_div_(output, index, input, dim) scatter_div_(output, index, input, dim)
...@@ -12,6 +12,6 @@ def scatter_max_(output, index, input, dim=0): ...@@ -12,6 +12,6 @@ def scatter_max_(output, index, input, dim=0):
return scatter('max', dim, output, index, input, arg_output) return scatter('max', dim, output, index, input, arg_output)
def scatter_max(index, input, dim=0, max_index=None, fill_value=0): def scatter_max(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value) output = gen_output(index, input, dim, size, fill_value)
return scatter_max_(output, index, input, dim) return scatter_max_(output, index, input, dim)
...@@ -12,6 +12,6 @@ def scatter_mean_(output, index, input, dim=0): ...@@ -12,6 +12,6 @@ def scatter_mean_(output, index, input, dim=0):
return output return output
def scatter_mean(index, input, dim=0, max_index=None, fill_value=0): def scatter_mean(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value) output = gen_output(index, input, dim, size, fill_value)
return scatter_mean_(output, index, input, dim) return scatter_mean_(output, index, input, dim)
...@@ -9,6 +9,6 @@ def scatter_min_(output, index, input, dim=0): ...@@ -9,6 +9,6 @@ def scatter_min_(output, index, input, dim=0):
return scatter('min', dim, output, index, input, arg_output) return scatter('min', dim, output, index, input, arg_output)
def scatter_min(index, input, dim=0, max_index=None, fill_value=0): def scatter_min(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value) output = gen_output(index, input, dim, size, fill_value)
return scatter_min_(output, index, input, dim) return scatter_min_(output, index, input, dim)
...@@ -8,6 +8,6 @@ def scatter_mul_(output, index, input, dim=0): ...@@ -8,6 +8,6 @@ def scatter_mul_(output, index, input, dim=0):
return scatter('mul', dim, output, index, input) return scatter('mul', dim, output, index, input)
def scatter_mul(index, input, dim=0, max_index=None, fill_value=1): def scatter_mul(index, input, dim=0, size=None, fill_value=1):
output = gen_output(index, input, dim, max_index, fill_value) output = gen_output(index, input, dim, size, fill_value)
return scatter_mul_(output, index, input, dim) return scatter_mul_(output, index, input, dim)
...@@ -39,12 +39,12 @@ def _scatter(name, dim, *data): ...@@ -39,12 +39,12 @@ def _scatter(name, dim, *data):
return (data[0], data[3]) if has_arg_output(name) else data[0] return (data[0], data[3]) if has_arg_output(name) else data[0]
def index_backward(dim, index, grad, arg_grad): def index_backward(dim, index, grad, arg):
typename = type(grad).__name__.replace('Tensor', '') typename = type(grad).__name__.replace('Tensor', '')
cuda = 'cuda_' if grad.is_cuda else '' cuda = 'cuda_' if grad.is_cuda else ''
func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename)) func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
output = grad.new(index.size()).fill_(0) output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, arg_grad) func(dim, output, index, grad, arg)
return output return output
......
...@@ -7,6 +7,6 @@ def scatter_sub_(output, index, input, dim=0): ...@@ -7,6 +7,6 @@ def scatter_sub_(output, index, input, dim=0):
return output.scatter_add_(dim, index, -input) return output.scatter_add_(dim, index, -input)
def scatter_sub(index, input, dim=0, max_index=None, fill_value=0): def scatter_sub(index, input, dim=0, size=None, fill_value=0):
output = gen_output(index, input, dim, max_index, fill_value) output = gen_output(index, input, dim, size, fill_value)
return scatter_sub_(output, index, input, dim) return scatter_sub_(output, index, input, dim)
...@@ -9,8 +9,11 @@ def gen_filled_tensor(input, size, fill_value): ...@@ -9,8 +9,11 @@ def gen_filled_tensor(input, size, fill_value):
return Variable(input.data.new(size).fill_(fill_value)) return Variable(input.data.new(size).fill_(fill_value))
def gen_output(index, input, dim, max_index, fill_value): def gen_output(index, input, dim, dim_size, fill_value):
max_index = index.max() + 1 if max_index is None else max_index if dim_size is None:
dim_size = index.max() + 1
dim_size = dim.size if torch.is_tensor(input) else dim_size.data[0]
size = list(index.size()) size = list(index.size())
size[dim] = max_index if torch.is_tensor(input) else max_index.data[0] size[dim] = dim_size
return gen_filled_tensor(input, torch.Size(size), fill_value) return gen_filled_tensor(input, torch.Size(size), fill_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment