"tools/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "5219773a20ab2e14f03f116ba88038b7f486cc8f"
Commit 28f42bdc authored by rusty1s's avatar rusty1s
Browse files

cleaner

parent 416a2603
import torch
from torch.autograd import Variable
from .scatter import scatter from .scatter import scatter
from .utils import gen_output from .utils import gen_filled_tensor, gen_output
def scatter_add_(output, index, input, dim=0): def scatter_add_(output, index, input, dim=0):
...@@ -42,10 +39,7 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1): ...@@ -42,10 +39,7 @@ def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
def scatter_mean_(output, index, input, dim=0): def scatter_mean_(output, index, input, dim=0):
if torch.is_tensor(input): output_count = gen_filled_tensor(output, output.size(), fill_value=0)
output_count = output.new(output.size()).fill_(0)
else:
output_count = Variable(output.data.new(output.size()).fill_(0))
scatter('mean', dim, output, index, input, output_count) scatter('mean', dim, output, index, input, output_count)
output_count[output_count == 0] = 1 output_count[output_count == 0] = 1
output /= output_count output /= output_count
...@@ -58,10 +52,7 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0): ...@@ -58,10 +52,7 @@ def scatter_mean(index, input, dim=0, max_index=None, fill_value=0):
def scatter_max_(output, index, input, dim=0): def scatter_max_(output, index, input, dim=0):
if torch.is_tensor(input): output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
return scatter('max', dim, output, index, input, output_index) return scatter('max', dim, output, index, input, output_index)
...@@ -71,10 +62,7 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0): ...@@ -71,10 +62,7 @@ def scatter_max(index, input, dim=0, max_index=None, fill_value=0):
def scatter_min_(output, index, input, dim=0): def scatter_min_(output, index, input, dim=0):
if torch.is_tensor(input): output_index = gen_filled_tensor(index, output.size(), fill_value=-1)
output_index = index.new(output.size()).fill_(-1)
else:
output_index = Variable(index.data.new(output.size()).fill_(-1))
return scatter('min', dim, output, index, input, output_index) return scatter('min', dim, output, index, input, output_index)
......
...@@ -2,13 +2,20 @@ import torch ...@@ -2,13 +2,20 @@ import torch
from torch.autograd import Variable from torch.autograd import Variable
def gen_filled_tensor(input, size, fill_value):
if torch.is_tensor(input):
return input.new(size).fill_(fill_value)
else:
return Variable(input.data.new(size).fill_(fill_value))
def gen_output(index, input, dim, max_index, fill_value): def gen_output(index, input, dim, max_index, fill_value):
max_index = index.max() + 1 if max_index is None else max_index max_index = index.max() + 1 if max_index is None else max_index
size = list(index.size()) size = list(index.size())
if torch.is_tensor(input): if torch.is_tensor(input):
size[dim] = max_index size[dim] = max_index
return input.new(torch.Size(size)).fill_(fill_value)
else: else:
size[dim] = max_index.data[0] size[dim] = max_index.data[0]
return Variable(input.data.new(torch.Size(size)).fill_(fill_value))
return gen_filled_tensor(input, torch.Size(size), fill_value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment