utils.py 570 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch
from torch.autograd import Variable


rusty1s's avatar
cleaner  
rusty1s committed
5
6
7
8
9
10
11
def gen_filled_tensor(input, size, fill_value):
    """Create a new tensor of shape ``size`` with every entry set to ``fill_value``.

    The result has the same tensor type as ``input``. If ``input`` is not a
    plain tensor (an autograd ``Variable`` wrapper), the storage is allocated
    from ``input.data`` and the result is wrapped back into a ``Variable``.

    Args:
        input: Tensor or ``Variable`` used as the type template.
        size: Shape of the tensor to create (e.g. a ``torch.Size``).
        fill_value: Value written into every element.

    Returns:
        A freshly allocated tensor, or ``Variable`` when ``input`` is one.
    """
    if not torch.is_tensor(input):
        # Variable wrapper: build from the underlying data, then re-wrap.
        filled = input.data.new(size).fill_(fill_value)
        return Variable(filled)
    filled = input.new(size)
    return filled.fill_(fill_value)


rusty1s's avatar
rename  
rusty1s committed
12
13
14
15
16
def gen_output(index, input, dim, dim_size, fill_value):
    """Allocate an output tensor shaped like ``index``, resized along ``dim``.

    Every dimension of the output matches ``index.size()`` except dimension
    ``dim``, whose length is ``dim_size``. Every entry is initialised to
    ``fill_value``, and the tensor type follows ``input`` (see
    ``gen_filled_tensor``).

    Args:
        index: Tensor (or ``Variable``) whose shape defines the output shape
            apart from dimension ``dim``.
        input: Tensor or ``Variable`` used as the type template.
        dim (int): Dimension whose length is replaced by ``dim_size``.
        dim_size (int or None): Output length along ``dim``. If ``None``, it
            is inferred as ``index.max() + 1``.
        fill_value: Initial value for every output entry.

    Returns:
        The filled output tensor (a ``Variable`` when ``input`` is one).
    """
    if dim_size is None:
        max_index = index.max()
        if not torch.is_tensor(input):
            # For Variable inputs, max() yields a one-element wrapper;
            # unwrap it to a Python number as the original code intended.
            max_index = max_index.data[0]
        # Previously this read ``dim.size`` (``dim`` is an int), which raised
        # AttributeError and discarded the inferred size. ``int()`` also turns
        # a 0-dim tensor (newer PyTorch max()) into a plain Python int.
        dim_size = int(max_index) + 1

    size = list(index.size())
    size[dim] = dim_size
    return gen_filled_tensor(input, torch.Size(size), fill_value)