import os
from typing import Tuple

import torch
from torch.utils.cpp_extension import load

VERBOSE = False


def _resolve(name):
    return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)


# Prefer a prebuilt torch_integrated_conv_cpu extension if one is installed;
# otherwise fall back to JIT-compiling the C++ source that sits next to this file.
try:
    import torch_integrated_conv_cpu
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_integrated_conv_cpu')
    torch_integrated_conv_cpu = load(
        name='torch_integrated_conv_cpu',
        sources=[
            _resolve('integrated_conv_cpu.cpp'),
        ],
        verbose=VERBOSE,
    )


# Same for the CUDA extension; only attempt JIT compilation if CUDA is available,
# otherwise leave the module as None and fail lazily at dispatch time.
try:
    import torch_integrated_conv_cuda
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_integrated_conv_cuda')
    torch_integrated_conv_cuda = None
    if torch.cuda.is_available():
        torch_integrated_conv_cuda = load(
            name='torch_integrated_conv_cuda',
            sources=[
                _resolve('integrated_conv_cuda.cpp'),
                _resolve('integrated_conv_cuda_kernel.cu'),
            ],
            verbose=VERBOSE,
        )



def _integrated_conv_forward_dispatcher(input: torch.Tensor,
                                        pos_add: torch.Tensor,
                                        pos_mul: torch.Tensor) -> torch.Tensor:
    """Dispatches the forward computation to the CUDA or CPU extension."""
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_integrated_conv_cuda.integrated_conv_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous())
    else:
        return torch_integrated_conv_cpu.integrated_conv_cpu(
            input, pos_add, pos_mul)

def _integrated_conv_backward_dispatcher(input: torch.Tensor,
                                         pos_add: torch.Tensor,
                                         pos_mul: torch.Tensor,
                                         grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatches the backward computation to the CUDA or CPU extension,
    returning (grad_input, grad_pos_add, grad_pos_mul)."""
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(),
            grad_output.contiguous()))
    else:
        return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
            input, pos_add, pos_mul, grad_output))



class IntegratedConvFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, pos_add: torch.Tensor, pos_mul: torch.Tensor) -> torch.Tensor:
        output = _integrated_conv_forward_dispatcher(input, pos_add, pos_mul)
        ctx.save_for_backward(input, pos_add, pos_mul)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (input, pos_add, pos_mul) = ctx.saved_tensors
        grad_input, grad_pos_add, grad_pos_mul = _integrated_conv_backward_dispatcher(
            input, pos_add, pos_mul, grad_output)
        return grad_input, grad_pos_add, grad_pos_mul


def integrated_conv(input, pos_add, pos_mul):
    """Integrated convolution.
    Args:
       input:   The input of shape (N, 2*C, W) for 1-d convolution or (N, 2*C, H, W)
               for 2-d convolution, where
               N is the batch size, C is the number of output channels, and H and W are
               the input image's height and width respectively.  The input channels are
               of two types, "src" and "dest", indicating whether they relate to the
               source or destination image position; all C "src" channels come first,
               followed by the C "dest" channels.
       pos_add:  Positional encoding: the additive part of the convolution kernel.
               This is of shape (C, kW) for 1-d
               convolution or (C, kH, kW) for 2-d convolution,
               where C is the number of channels and kH and kW are the kernel height and
               kernel width.  Kernel height and width must be odd (we assume zero padding
               so the output size is the same as the input size).
       pos_mul:  Positional encoding: the multiplicative part of the convolution kernel.
               This is of shape (C, kW)
               for 1-d convolution or (C, kH, kW) for 2-d convolution, where C
               is the number of channels and kH and kW are the kernel height and
               kernel width.
    Return: output, of shape (N, C, W) for 1-d convolution or (N, C, H, W) for
               2-d convolution.  In the 2-d case the output satisfies:

              output[n, c, h, w] = \sum_{kh=0}^{kH-1} \sum_{kw=0}^{kW-1}
                pos_mul[c, kh, kw] * relu(input[n, c + C, h, w] + input_padded[n, c, h + kh, w + kw] + pos_add[c, kh, kw])

              where input_padded is torch.nn.functional.pad(input, (kW//2, kW//2, kH//2, kH//2)),
              i.e. zero-padding (this is done implicitly by the implementation).
              (Technically this operation is more closely related to cross-correlation
              than to convolution.)
    """
    if input.ndim == 3:
        assert pos_add.ndim == 2 and pos_mul.ndim == 2
        # Only the 2-dimensional case is handled natively; the 1-dimensional case is
        # treated as 2-d with a height of 1.  (Unsqueezing with -1 instead of -2 would
        # work equally well, since height and width behave the same.)
        return integrated_conv(input.unsqueeze(-2),
                               pos_add.unsqueeze(-2), pos_mul.unsqueeze(-2)).squeeze(-2)
    assert input.ndim == 4 and pos_add.ndim == 3 and pos_mul.ndim == 3
    assert input.shape[1] // 2 == pos_add.shape[0] == pos_mul.shape[0]
    return IntegratedConvFunction.apply(input, pos_add, pos_mul)
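

# What follows is a minimal usage sketch, not part of the original module: it assumes
# the CPU extension compiled successfully and that the shapes documented above are
# correct, and simply checks output and gradient shapes on random 2-d data.
if __name__ == '__main__':
    N, C, H, W = 2, 4, 8, 10   # batch size, channels, image height/width
    kH, kW = 3, 5              # kernel height/width; both must be odd
    input = torch.randn(N, 2 * C, H, W, requires_grad=True)
    pos_add = torch.randn(C, kH, kW, requires_grad=True)
    pos_mul = torch.randn(C, kH, kW, requires_grad=True)

    output = integrated_conv(input, pos_add, pos_mul)
    assert output.shape == (N, C, H, W)

    # Gradients flow back through IntegratedConvFunction.backward().
    output.sum().backward()
    assert input.grad.shape == input.shape
    assert pos_add.grad.shape == pos_add.shape
    assert pos_mul.grad.shape == pos_mul.shape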