Commit 28f42bdc authored by rusty1s's avatar rusty1s
Browse files

cleaner

parent 416a2603
import torch
from torch.autograd import Variable
from .scatter import scatter
from .utils import gen_output
from .utils import gen_filled_tensor, gen_output
def scatter_add_(output, index, input, dim=0):
......@@ -42,10 +39,7 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
def scatter_mean_(output, index, input, dim=0):
if torch.is_tensor(input):
output_count = output.new(output.size()).fill_(0)
else:
output_count = Variable(output.data.new(output.size()).fill_(0))
output_count = gen_filled_tensor(output, output.size(), fill_value=0)
scatter('mean', dim, output, index, input, output_count)
output_count[output_count == 0] = 1
output /= output_count
......@@ -58,10 +52,7 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def scatter_max_(output, index, input, dim=0):
if torch.is_tensor(input):
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('max', dim, output, index, input, output_index)
......@@ -71,10 +62,7 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def scatter_min_(output, index, input, dim=0):
if torch.is_tensor(input):
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
return scatter('min', dim, output, index, input, output_index)
......
......@@ -2,13 +2,20 @@ import torch
from torch.autograd import Variable
def gen_filled_tensor(input, size, fill_value):
if torch.is_tensor(input):
return input.new(size).fill_(fill_value)
else:
return Variable(input.data.new(size).fill_(fill_value))
def gen_output(index, input, dim, max_index, fill_value):
max_index = index.max() + 1 if max_index is None else max_index
size = list(index.size())
if torch.is_tensor(input):
size[dim] = max_index
return input.new(torch.Size(size)).fill_(fill_value)
else:
size[dim] = max_index.data[0]
return Variable(input.data.new(torch.Size(size)).fill_(fill_value))
return gen_filled_tensor(input, torch.Size(size), fill_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment