"examples/contrib/mm-imdb/run_mmimdb.py" did not exist on "6b1ff250842f52136d5159bb67a26b50ba01485d"
discounted_cumsum.py 2.52 KB
Newer Older
1
2
3
4
import os

import torch
from torch.utils.cpp_extension import load


anton's avatar
anton committed
5
6
7
8
# JIT-compile (or reuse a cached build of) the CPU extension at import time.
# The source path is resolved relative to this file rather than the current
# working directory, so the module can be imported from any directory.
# NOTE(review): assumes discounted_cumsum_cpu.cpp sits next to this file — confirm.
torch_discounted_cumsum_cpu = load(
    name='torch_discounted_cumsum_cpu',
    sources=[
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'discounted_cumsum_cpu.cpp'),
    ],
)

anton's avatar
anton committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# The CUDA extension is only built when torch reports a usable GPU; otherwise it
# stays None so CUDA inputs can be rejected with a clear error at call time.
# Source paths are resolved relative to this file, not the current working directory.
# NOTE(review): assumes the .cpp/.cu sources sit next to this file — confirm.
torch_discounted_cumsum_cuda = None
if torch.cuda.is_available():
    torch_discounted_cumsum_cuda = load(
        name='torch_discounted_cumsum_cuda',
        sources=[
            os.path.join(os.path.dirname(os.path.abspath(__file__)), fname)
            for fname in ('discounted_cumsum_cuda.cpp', 'discounted_cumsum_cuda_kernel.cu')
        ],
        verbose=True,
    )


def _discounted_cumsum_dispatch(input, gamma, cpu_op, cuda_op):
    """Check that *input* is a tensor and route it to the native kernel for its device.

    *cpu_op* / *cuda_op* name the function to call on the CPU / CUDA extension
    module respectively. CUDA inputs are made contiguous before the call; CPU
    inputs are passed through unchanged.

    Raises:
        ValueError: if *input* is not a torch.Tensor.
        EnvironmentError: if *input* is on CUDA but the CUDA extension is None.
    """
    if not torch.is_tensor(input):
        raise ValueError('Input must be a torch.Tensor')
    if not input.is_cuda:
        return getattr(torch_discounted_cumsum_cpu, cpu_op)(input, gamma)
    # The CUDA extension module is left as None when it was not built at import time.
    if torch_discounted_cumsum_cuda is None:
        raise EnvironmentError('Failed to load native CUDA module')
    return getattr(torch_discounted_cumsum_cuda, cuda_op)(input.contiguous(), gamma)


def _discounted_cumsum_left_dispatcher(input, gamma):
    """Run the native left discounted cumsum kernel on *input*'s device."""
    return _discounted_cumsum_dispatch(
        input, gamma, 'discounted_cumsum_left_cpu', 'discounted_cumsum_left_cuda')


def _discounted_cumsum_right_dispatcher(input, gamma):
    """Run the native right discounted cumsum kernel on *input*'s device."""
    return _discounted_cumsum_dispatch(
        input, gamma, 'discounted_cumsum_right_cpu', 'discounted_cumsum_right_cuda')


class DiscountedCumSumLeftFunction(torch.autograd.Function):
    """Autograd wrapper around the native left discounted cumulative sum.

    Forward runs the left-direction kernel. Backward applies the right-direction
    kernel to ``grad_output`` with the same ``gamma``. ``gamma`` is treated as a
    constant: no gradient is returned for it.
    """

    @staticmethod
    def forward(ctx, input, gamma):
        output = _discounted_cumsum_left_dispatcher(input, gamma)
        # gamma is not a tensor input, so keep it as a plain ctx attribute.
        # The previous torch.tensor(gamma) -> .item() round trip stored it in the
        # default dtype (float32 unless changed), so backward used a rounded gamma
        # that differed from the one used here. It also relied on the deprecated
        # ctx.saved_variables accessor.
        ctx.gamma = gamma
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        # Skip the kernel launch when autograd does not need d(loss)/d(input).
        if ctx.needs_input_grad[0]:
            grad_input = _discounted_cumsum_right_dispatcher(grad_output, ctx.gamma)
        return grad_input, None


class DiscountedCumSumRightFunction(torch.autograd.Function):
    """Autograd wrapper around the native right discounted cumulative sum.

    Forward runs the right-direction kernel. Backward applies the left-direction
    kernel to ``grad_output`` with the same ``gamma``. ``gamma`` is treated as a
    constant: no gradient is returned for it.
    """

    @staticmethod
    def forward(ctx, input, gamma):
        output = _discounted_cumsum_right_dispatcher(input, gamma)
        # gamma is not a tensor input, so keep it as a plain ctx attribute.
        # The previous torch.tensor(gamma) -> .item() round trip stored it in the
        # default dtype (float32 unless changed), so backward used a rounded gamma
        # that differed from the one used here. It also relied on the deprecated
        # ctx.saved_variables accessor.
        ctx.gamma = gamma
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        # Skip the kernel launch when autograd does not need d(loss)/d(input).
        if ctx.needs_input_grad[0]:
            grad_input = _discounted_cumsum_left_dispatcher(grad_output, ctx.gamma)
        return grad_input, None
68
69


anton's avatar
anton committed
70
def discounted_cumsum_left(input, gamma):
    """Compute the left discounted cumulative sum of ``input``, differentiable w.r.t. ``input``.

    Args:
        input: torch.Tensor on CPU or CUDA. CUDA inputs require the native CUDA
            extension to have been built. NOTE(review): shape/dtype requirements
            are set by the native kernel, which is not visible here — confirm
            against the .cpp source.
        gamma: discount factor, passed unchanged to the native kernel.

    Returns:
        The native kernel's output. Gradients flow to ``input`` only, not ``gamma``.

    Raises:
        ValueError: if ``input`` is not a torch.Tensor.
        EnvironmentError: if ``input`` is on CUDA but the CUDA extension is unavailable.
    """
    return DiscountedCumSumLeftFunction.apply(input, gamma)
anton's avatar
anton committed
72
73


anton's avatar
anton committed
74
def discounted_cumsum_right(input, gamma):
    """Compute the right discounted cumulative sum of ``input``, differentiable w.r.t. ``input``.

    Args:
        input: torch.Tensor on CPU or CUDA. CUDA inputs require the native CUDA
            extension to have been built. NOTE(review): shape/dtype requirements
            are set by the native kernel, which is not visible here — confirm
            against the .cpp source.
        gamma: discount factor, passed unchanged to the native kernel.

    Returns:
        The native kernel's output. Gradients flow to ``input`` only, not ``gamma``.

    Raises:
        ValueError: if ``input`` is not a torch.Tensor.
        EnvironmentError: if ``input`` is on CUDA but the CUDA extension is unavailable.
    """
    return DiscountedCumSumRightFunction.apply(input, gamma)