import os

import torch
from typing import Tuple
from torch.utils.cpp_extension import load

VERBOSE = False


def _resolve(name):
    return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)


# Prefer an installed, pre-built extension module; if it is not importable,
# fall back to JIT-compiling the bundled C++ source via torch.utils.cpp_extension.
try:
    import torch_integrated_conv_cpu
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_integrated_conv_cpu')
    torch_integrated_conv_cpu = load(
        name='torch_integrated_conv_cpu',
        sources=[
            _resolve('integrated_conv_cpu.cpp'),
        ],
        verbose=VERBOSE,
    )


# Same fallback strategy for the CUDA extension, but only attempt the JIT
# build when a CUDA device is actually available.  Otherwise the module is
# left as None and the dispatchers below raise if they ever receive a CUDA
# tensor.
try:
    import torch_integrated_conv_cuda
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_integrated_conv_cuda')
    torch_integrated_conv_cuda = None
    if torch.cuda.is_available():
        torch_integrated_conv_cuda = load(
            name='torch_integrated_conv_cuda',
            sources=[
                _resolve('integrated_conv_cuda.cpp'),
                _resolve('integrated_conv_cuda_kernel.cu'),
            ],
            verbose=VERBOSE,
        )



def _integrated_conv_forward_dispatcher(input: torch.Tensor,
                                        pos_add: torch.Tensor,
                                        pos_mul: torch.Tensor) -> torch.Tensor:
    """Route the integrated-conv forward pass to the native CPU or CUDA kernel.

    The backend is selected by the device of `input`.

    Args:
      input:   input activations (device decides CPU vs. CUDA path).
      pos_add: additive positional-encoding kernel.
      pos_mul: multiplicative positional-encoding kernel.
    Returns:
      The output tensor computed by the native extension.
    Raises:
      EnvironmentError: if `input` is a CUDA tensor but the CUDA extension
        could not be loaded (e.g. no CUDA available at JIT-compile time).
    """
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            # Fix: was an f-string with no placeholders (f'...').
            raise EnvironmentError('Failed to load native CUDA module')
        # The CUDA kernel is invoked on contiguous copies of its inputs.
        return torch_integrated_conv_cuda.integrated_conv_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous())
    else:
        return torch_integrated_conv_cpu.integrated_conv_cpu(
            input, pos_add, pos_mul)

def _integrated_conv_backward_dispatcher(input: torch.Tensor,
                                         pos_add: torch.Tensor,
                                         pos_mul: torch.Tensor,
                                         grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Route the integrated-conv backward pass to the native CPU or CUDA kernel.

    Args:
      input, pos_add, pos_mul: the tensors that were given to the forward pass.
      grad_output: gradient of the loss w.r.t. the forward output.
    Returns:
      A 3-tuple (grad_input, grad_pos_add, grad_pos_mul).
    Raises:
      EnvironmentError: if `input` is a CUDA tensor but the CUDA extension
        could not be loaded.
    """
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            # Fix: was an f-string with no placeholders (f'...').
            raise EnvironmentError('Failed to load native CUDA module')
        # Actually it's not a hard requirement that these things be contiguous.
        return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(),
            grad_output))
    else:
        return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
            input, pos_add, pos_mul, grad_output))
class IntegratedConvFunction(torch.autograd.Function):
    """Autograd glue for the native integrated-convolution kernels.

    Saves the forward inputs for the backward pass and delegates the actual
    numerical work to the dispatcher functions above.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, pos_add: torch.Tensor, pos_mul: torch.Tensor) -> torch.Tensor:
        # All three inputs are needed to compute all three gradients.
        ctx.save_for_backward(input, pos_add, pos_mul)
        return _integrated_conv_forward_dispatcher(input, pos_add, pos_mul)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        input, pos_add, pos_mul = ctx.saved_tensors
        # The dispatcher already returns (grad_input, grad_pos_add, grad_pos_mul).
        return _integrated_conv_backward_dispatcher(input, pos_add, pos_mul,
                                                    grad_output)


def integrated_conv(input, pos_add, pos_mul):
    r"""Integrated convolution.
    Args:
       input:   The input of shape (N, 2*C, W) for 1-d convolution or (N, 2*C, H, W)
               for 2-d convolution, where
               N is the batch size, C is the number of output channels, and H and W are
               the input image's height and width respectively.  The input channels are
               of two types, "src" and "dest" respectively, meaning whether they relate
               to the source or destination image position; all the "src" channels come
               first, then the "dest" channels.
       pos_add:  Positional encoding: the additive part of the convolution kernel.
               This is of shape (C, kW) for 1-d
               convolution or (C, kH, kW) for 2-d convolution,
               where C is the number of channels and kH and kW are the kernel height and
               kernel width.  Kernel height and width must be odd (we assume zero padding
               so the output size is the same as the input size).
       pos_mul:  Positional encoding: the multiplicative part of the convolution kernel.
               This is of shape (C, kW)
               for 1-d convolution or (C, kH, kW) for 2-d convolution, where C
               is the number of channels and kH and kW are the kernel height and
               kernel width.
    Return: output, of shape (N, C, W) for 1-d convolution or (N, C, H, W) for
               2-d convolution.  In the 2-d case the output will be satisfy:

              output[n, c, h, w] = \sum_{kh=0}^{kH-1} \sum_{kw=0}^{kW-1}
                pos_mul[c, kh, kw] * relu(input[n, c, h, w] + input_padded[n,c,h+kh,w+kw] + pos_add[c, kh, kw])

              where input_padded is torch.pad(input, (kW//2, kW//2, kH//2, kH//2)),
              meaning zero-padding (this is done implicitly by the implementation).
              (Technically this is more closely related to cross-correlation than to
              convolution).
    """
    if input.ndim == 3:
        assert pos_add.ndim == 2 and pos_mul.ndim == 2
        # For now we choose to handle only the 2-dimensional case directly.  The
        # 1-dimensional one is treated as a special case of the 2-dimensional one.
        # Actually we could unsqueeze with -2 or -1 here, as the height and width
        # behave the same.
        return integrated_conv(input.unsqueeze(-2),
                               pos_add.unsqueeze(-2), pos_mul.unsqueeze(-2)).squeeze(-2)
    assert input.ndim == 4 and pos_add.ndim == 3 and pos_mul.ndim == 3
    # The input channels split evenly into "src" and "dest" halves, so the
    # channel count must be exactly 2*C.  (The previous check used floor
    # division, which silently accepted an odd channel count.)
    assert input.shape[1] == 2 * pos_add.shape[0] == 2 * pos_mul.shape[0]
    # Both positional-encoding kernels describe the same (C, kH, kW) window.
    assert pos_add.shape == pos_mul.shape
    return IntegratedConvFunction.apply(input, pos_add, pos_mul)