from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn


def _sigmoid(x):
    # Apply sigmoid in place (note: this modifies x), then clamp away from 0 and 1
    # so downstream log() calls (e.g. in a focal loss) stay finite.
    y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
    return y


def _gather_feat(feat, ind, mask=None):
    # ind holds flat spatial indices of shape (B, K); expand it to (B, K, C)
    # so gather() keeps the full channel dimension for each index.
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    # Gather the feature vectors at the requested indices; if a mask is given,
    # keep only the masked entries, flattened to (num_kept, C).
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _tranpose_and_gather_feat(feat, ind):
    # Reshape (B, C, H, W) -> (B, H*W, C), then gather the feature vectors
    # at the flat spatial indices in ind.
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat


def flip_tensor(x):
    # Flip horizontally along the width axis (dim 3 of an NCHW tensor).
    return torch.flip(x, [3])
    # Equivalent numpy-based version:
    # tmp = x.detach().cpu().numpy()[..., ::-1].copy()
    # return torch.from_numpy(tmp).to(x.device)


def flip_lr(x, flip_idx):
    # Flip horizontally, then swap the channel pairs listed in flip_idx
    # (pairs of left/right indices, e.g. mirrored keypoint channels).
    tmp = x.detach().cpu().numpy()[..., ::-1].copy()
    shape = tmp.shape
    for e in flip_idx:
        tmp[:, e[0], ...], tmp[:, e[1], ...] = \
            tmp[:, e[1], ...].copy(), tmp[:, e[0], ...].copy()
    return torch.from_numpy(tmp.reshape(shape)).to(x.device)


def flip_lr_off(x, flip_idx):
    # Flip horizontally, negate the horizontal component of each keypoint offset
    # (channels are laid out as 17 keypoints x (x, y)), then swap the left/right
    # keypoint pairs listed in flip_idx.
    tmp = x.detach().cpu().numpy()[..., ::-1].copy()
    shape = tmp.shape
    tmp = tmp.reshape(tmp.shape[0], 17, 2,
                      tmp.shape[2], tmp.shape[3])
    tmp[:, :, 0, :, :] *= -1
    for e in flip_idx:
        tmp[:, e[0], ...], tmp[:, e[1], ...] = \
            tmp[:, e[1], ...].copy(), tmp[:, e[0], ...].copy()
    return torch.from_numpy(tmp.reshape(shape)).to(x.device)
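

# The sketch below is not part of the original module; it is a minimal,
# hypothetical example showing how _tranpose_and_gather_feat extracts the
# feature vectors at the top-K positions of a score map. The shapes and the
# names feat, scores, ind are illustrative assumptions only.
if __name__ == "__main__":
    feat = torch.randn(2, 4, 8, 8)       # (batch, channels, H, W) feature map
    scores = torch.rand(2, 8 * 8)        # one score per flattened spatial location
    _, ind = scores.topk(3, dim=1)       # flat indices of the 3 highest scores
    gathered = _tranpose_and_gather_feat(feat, ind)
    print(gathered.shape)                # torch.Size([2, 3, 4])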