utils.py 552 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch
from torch.autograd import Variable


rusty1s's avatar
cleaner  
rusty1s committed
5
6
7
8
9
10
11
def gen_filled_tensor(input, size, fill_value):
    """Return a new tensor of shape ``size`` filled with ``fill_value``.

    The result is allocated with ``.new`` so it shares the type (and device)
    of ``input``.  When ``input`` is not a plain tensor (e.g. an old-style
    ``Variable``), the allocation is done on its ``.data`` and the result is
    wrapped back into a ``Variable``.
    """
    if not torch.is_tensor(input):
        # Wrapped input: allocate on the underlying tensor, then re-wrap.
        filled = input.data.new(size).fill_(fill_value)
        return Variable(filled)
    filled = input.new(size)
    return filled.fill_(fill_value)


rusty1s's avatar
rusty1s committed
12
13
14
def gen_output(index, input, dim, max_index, fill_value):
    """Allocate a filled output tensor shaped like ``index`` except along ``dim``.

    The output has ``index.size()`` in every dimension except ``dim``, whose
    length is ``max_index``. It is filled with ``fill_value`` and has the same
    type as ``input`` (see ``gen_filled_tensor``).

    Args:
        index: Tensor (or Variable) whose shape provides the output shape.
        input: Tensor or Variable whose type the output should match.
        dim: Dimension of the output to resize to ``max_index``.
        max_index: Length of the output along ``dim``. May be a Python int or
            a one-element tensor/Variable. If ``None``, it defaults to
            ``index.max() + 1``, so every index value fits.
        fill_value: Value used to initialise every output element.

    Returns:
        The newly allocated, filled tensor (a ``Variable`` when ``input`` is
        not a plain tensor).
    """
    if max_index is None:
        max_index = index.max() + 1

    # ``max_index`` may be a Python number, a one-element (possibly 0-dim)
    # tensor, or an old-style Variable wrapping one. Unwrap ``.data`` when
    # present and pull out the single element, so ``torch.Size`` always
    # receives a plain int. The previous code used ``.data[0]``, which breaks
    # on ints and on 0-dim tensors.
    data = getattr(max_index, 'data', max_index)
    if torch.is_tensor(data):
        max_index = data.view(-1)[0]

    size = list(index.size())
    size[dim] = int(max_index)
    return gen_filled_tensor(input, torch.Size(size), fill_value)