# ingroup_inds_op.py (632 bytes) -- committed by chenshi3
import torch

try:
    # Extension module shipped next to this file (presumably the compiled CUDA
    # op, judging by its name -- confirm against the build setup).
    from . import ingroup_inds_cuda
except ImportError:
    # Bind the imported name itself to None. Previously only `ingroup_indices`
    # was set here, so the alias assignment below raised NameError and this
    # fallback never took effect.
    ingroup_inds_cuda = None
    print('Can not import ingroup indices')

# Name used by the rest of this module; None when the extension is unavailable.
ingroup_indices = ingroup_inds_cuda

from torch.autograd import Function
class IngroupIndicesFunction(Function):
    """Autograd ``Function`` that delegates to ``ingroup_indices.forward``.

    The output is marked non-differentiable and ``backward`` returns ``None``,
    so no gradient flows back through this op.
    """

    @staticmethod
    def forward(ctx, group_inds):
        """Run the extension kernel on ``group_inds`` and return its output.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            group_inds: Input tensor handed to the extension unchanged.
                NOTE(review): presumably an integer tensor of per-element
                group ids -- confirm against the kernel source.

        Returns:
            A tensor shaped like ``group_inds`` that starts out as -1
            everywhere and is then passed to ``ingroup_indices.forward``
            (which is expected to fill it in place -- its return value is
            ignored here).
        """
        # zeros_like keeps the input's shape/device; subtracting 1 yields the
        # -1 initial value for every slot.
        result = torch.sub(torch.zeros_like(group_inds), 1)
        ingroup_indices.forward(group_inds, result)
        ctx.mark_non_differentiable(result)
        return result

    @staticmethod
    def backward(ctx, g):
        """Return ``None``: the single input receives no gradient."""
        return None

# Functional entry point: call ingroup_inds(group_inds) to run
# IngroupIndicesFunction through autograd's `apply`.
ingroup_inds = IngroupIndicesFunction.apply