discounted_cumsum.py 3.1 KB
Newer Older
anton's avatar
anton committed
1
2
import os

3
4
5
import torch
from torch.utils.cpp_extension import load

anton's avatar
anton committed
6
7
8
9
10
# When True, a notice is printed before falling back to JIT compilation of the
# native extensions, and the same flag is forwarded as ``verbose`` to
# ``torch.utils.cpp_extension.load``.
VERBOSE = False


def _resolve(name):
    """Return the path of ``name`` inside the directory that holds this module.

    The module's own path is passed through ``os.path.realpath`` first, so the
    directory used is the one containing the real file rather than a symlink
    to it.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(module_dir, name)
11
12


anton's avatar
anton committed
13
14
15
16
17
18
19
20
21
22
23
# Prefer an already-built ``torch_discounted_cumsum_cpu`` extension. If it
# cannot be imported, JIT-compile it from the C++ source located next to this
# file and bind the resulting module to the same name.
try:
    import torch_discounted_cumsum_cpu
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_discounted_cumsum_cpu')
    torch_discounted_cumsum_cpu = load(
        name='torch_discounted_cumsum_cpu',
        sources=[
            _resolve('discounted_cumsum_cpu.cpp'),
        ],
        verbose=VERBOSE,
    )


anton's avatar
anton committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Prefer an already-built ``torch_discounted_cumsum_cuda`` extension. If it
# cannot be imported, JIT-compile it only when CUDA is available; otherwise
# the name is bound to None and the dispatchers report the missing module.
try:
    import torch_discounted_cumsum_cuda
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_discounted_cumsum_cuda')
    if torch.cuda.is_available():
        torch_discounted_cumsum_cuda = load(
            name='torch_discounted_cumsum_cuda',
            sources=[
                _resolve('discounted_cumsum_cuda.cpp'),
                _resolve('discounted_cumsum_cuda_kernel.cu'),
            ],
            verbose=VERBOSE,
        )
    else:
        torch_discounted_cumsum_cuda = None


anton's avatar
anton committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def _discounted_cumsum_left_dispatcher(input, gamma):
    """Route ``input`` to the native left discounted-cumsum function for its device.

    CUDA tensors are made contiguous and passed to the CUDA extension; all
    other tensors are passed to the CPU extension unchanged.

    Args:
        input: A ``torch.Tensor``; its ``is_cuda`` flag selects the backend.
        gamma: Discount factor, forwarded to the native function as-is.

    Returns:
        The result of the selected native ``discounted_cumsum_left_*`` call.

    Raises:
        ValueError: If ``input`` is not a ``torch.Tensor``.
        EnvironmentError: If ``input`` is a CUDA tensor but the CUDA extension
            is not loaded (``EnvironmentError`` is an alias of ``OSError``).
    """
    if not torch.is_tensor(input):
        raise ValueError('Input must be a torch.Tensor')
    if not input.is_cuda:
        return torch_discounted_cumsum_cpu.discounted_cumsum_left_cpu(input, gamma)
    if torch_discounted_cumsum_cuda is None:
        # Plain literal: the message has no placeholders to format.
        raise EnvironmentError('Failed to load native CUDA module')
    return torch_discounted_cumsum_cuda.discounted_cumsum_left_cuda(input.contiguous(), gamma)


def _discounted_cumsum_right_dispatcher(input, gamma):
    """Route ``input`` to the native right discounted-cumsum function for its device.

    CUDA tensors are made contiguous and passed to the CUDA extension; all
    other tensors are passed to the CPU extension unchanged.

    Args:
        input: A ``torch.Tensor``; its ``is_cuda`` flag selects the backend.
        gamma: Discount factor, forwarded to the native function as-is.

    Returns:
        The result of the selected native ``discounted_cumsum_right_*`` call.

    Raises:
        ValueError: If ``input`` is not a ``torch.Tensor``.
        EnvironmentError: If ``input`` is a CUDA tensor but the CUDA extension
            is not loaded (``EnvironmentError`` is an alias of ``OSError``).
    """
    if not torch.is_tensor(input):
        raise ValueError('Input must be a torch.Tensor')
    if not input.is_cuda:
        return torch_discounted_cumsum_cpu.discounted_cumsum_right_cpu(input, gamma)
    if torch_discounted_cumsum_cuda is None:
        # Plain literal: the message has no placeholders to format.
        raise EnvironmentError('Failed to load native CUDA module')
    return torch_discounted_cumsum_cuda.discounted_cumsum_right_cuda(input.contiguous(), gamma)


class DiscountedCumSumLeftFunction(torch.autograd.Function):
    """Autograd function wrapping the left discounted cumulative sum.

    The backward pass applies the *right* discounted cumulative sum, with the
    same ``gamma``, to the incoming gradient. No gradient is returned for
    ``gamma``.
    """

    @staticmethod
    def forward(ctx, input, gamma):
        """Compute the left discounted cumsum of ``input`` and remember ``gamma``.

        Args:
            ctx: Autograd context used to carry ``gamma`` to ``backward``.
            input: ``torch.Tensor`` handed to the left dispatcher.
            gamma: Discount factor handed to the left dispatcher.

        Returns:
            The dispatcher's output.
        """
        output = _discounted_cumsum_left_dispatcher(input, gamma)
        # gamma is neither an input nor an output tensor, so it is stored on
        # ctx directly. Round-tripping it through torch.tensor(gamma).item()
        # would cast a Python float to the default dtype (float32 unless
        # changed) and make backward use a slightly different discount than
        # forward.
        ctx.gamma = gamma
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` through the right discounted cumsum.

        Returns:
            A ``(grad_input, None)`` pair; ``None`` is the gradient slot for
            ``gamma``.
        """
        grad_input = _discounted_cumsum_right_dispatcher(grad_output, ctx.gamma)
        return grad_input, None


class DiscountedCumSumRightFunction(torch.autograd.Function):
    """Autograd function wrapping the right discounted cumulative sum.

    The backward pass applies the *left* discounted cumulative sum, with the
    same ``gamma``, to the incoming gradient. No gradient is returned for
    ``gamma``.
    """

    @staticmethod
    def forward(ctx, input, gamma):
        """Compute the right discounted cumsum of ``input`` and remember ``gamma``.

        Args:
            ctx: Autograd context used to carry ``gamma`` to ``backward``.
            input: ``torch.Tensor`` handed to the right dispatcher.
            gamma: Discount factor handed to the right dispatcher.

        Returns:
            The dispatcher's output.
        """
        output = _discounted_cumsum_right_dispatcher(input, gamma)
        # gamma is neither an input nor an output tensor, so it is stored on
        # ctx directly. Round-tripping it through torch.tensor(gamma).item()
        # would cast a Python float to the default dtype (float32 unless
        # changed) and make backward use a slightly different discount than
        # forward.
        ctx.gamma = gamma
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` through the left discounted cumsum.

        Returns:
            A ``(grad_input, None)`` pair; ``None`` is the gradient slot for
            ``gamma``.
        """
        grad_input = _discounted_cumsum_left_dispatcher(grad_output, ctx.gamma)
        return grad_input, None
92
93


anton's avatar
anton committed
94
def discounted_cumsum_left(input, gamma):
    """Differentiable left discounted cumulative sum of ``input``.

    Thin wrapper over ``DiscountedCumSumLeftFunction.apply``, which runs the
    native CPU or CUDA implementation depending on where ``input`` lives.

    Args:
        input: ``torch.Tensor`` to process.
        gamma: Discount factor; no gradient is computed with respect to it.

    Returns:
        The tensor produced by the native left discounted-cumsum function.

    NOTE(review): the exact recurrence and the tensor shape the native
    function expects are defined outside this file — confirm against the
    extension sources.
    """
    return DiscountedCumSumLeftFunction.apply(input, gamma)
anton's avatar
anton committed
96
97


anton's avatar
anton committed
98
def discounted_cumsum_right(input, gamma):
    """Differentiable right discounted cumulative sum of ``input``.

    Thin wrapper over ``DiscountedCumSumRightFunction.apply``, which runs the
    native CPU or CUDA implementation depending on where ``input`` lives.

    Args:
        input: ``torch.Tensor`` to process.
        gamma: Discount factor; no gradient is computed with respect to it.

    Returns:
        The tensor produced by the native right discounted-cumsum function.

    NOTE(review): the exact recurrence and the tensor shape the native
    function expects are defined outside this file — confirm against the
    extension sources.
    """
    return DiscountedCumSumRightFunction.apply(input, gamma)