import os

import torch
from typing import Tuple
from torch.utils.cpp_extension import load

VERBOSE = False


def _resolve(name):
    return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)


try:
    import torch_learned_nonlin_cpu
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_learned_nonlin_cpu')
    torch_learned_nonlin_cpu = load(
        name='torch_learned_nonlin_cpu',
        sources=[
            _resolve('learned_nonlin_cpu.cpp'),
        ],
        verbose=VERBOSE,
    )


try:
    import torch_learned_nonlin_cuda
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_learned_nonlin_cuda')
    torch_learned_nonlin_cuda = None
    if torch.cuda.is_available():
        torch_learned_nonlin_cuda = load(
            name='torch_learned_nonlin_cuda',
            sources=[
                _resolve('learned_nonlin_cuda.cpp'),
                _resolve('learned_nonlin_cuda_kernel.cu'),
            ],
            verbose=VERBOSE,
        )



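# Note: both dispatchers below are called (by LearnedNonlinFunction further
# down) with `input` already reshaped to 3 dimensions with the channel as the
# middle dimension (see _reshape_as_3dim()); `params` has shape (C, N + 1) as
# documented in learned_nonlin() at the bottom of this file.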
def _learned_nonlin_forward_dispatcher(input: torch.Tensor,
                                       params: torch.Tensor) -> torch.Tensor:
    if input.is_cuda:
        if torch_learned_nonlin_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_learned_nonlin_cuda.learned_nonlin_cuda(
            input, params.contiguous())
    else:
        return torch_learned_nonlin_cpu.learned_nonlin_cpu(
            input, params)

def _learned_nonlin_backward_dispatcher(input: torch.Tensor,
                                        params: torch.Tensor,
                                        grad_output) -> Tuple[torch.Tensor, torch.Tensor]:
    if input.is_cuda:
        if torch_learned_nonlin_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return tuple(torch_learned_nonlin_cuda.learned_nonlin_backward_cuda(
            input, params,
            grad_output))
    else:
        return tuple(torch_learned_nonlin_cpu.learned_nonlin_backward_cpu(
            input, params, grad_output))



def _reshape_as_3dim(x: torch.Tensor, dim: int):
    """
    Returns x reshaped so that dimension 'dim' is the middle of 3 dimensions,
    combining dimensions and unsqueezing as needed.  For example (writing
    the behavior of this function as
        input_shape, dim -> output_shape,
    it will do:
              (3), 0 -> (1, 3, 1)
        (2, 5, 9), 1 -> (2, 5, 9)
        (2, 5, 9), 2 -> (10, 9, 1)
     (3, 4, 5, 6), 2 -> (12, 5, 6)
    The idea is to normalize the shape so the channel dimension is the middle
    of 3, so the implementation can deal with a fixed layout.

     Args:
        x:  tensor to be reshaped
      dim:  Dimension of x that is to be the middle of 3 dimensions in the result.
            If negative, interpreted as an offset from x.ndim.
    """
    if dim < 0:
        dim += x.ndim
    orig_shape = list(x.shape)

    # `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
    # is dimension/axis 1.  We do this not by transposing, but by combining
    # adjacent dims.
    a, b = 1, 1
    for i in range(0, dim):
        a *= orig_shape[i]
    for i in range(dim + 1, len(orig_shape)):
        b *= orig_shape[i]
    new_shape = (a, orig_shape[dim], b)
    return x.reshape(new_shape)  # `reshape` will make a contiguous copy if needed.
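
# A hypothetical sanity check (not part of this module's API) illustrating the
# shape mappings documented above; something like this could be run from a
# test script:
#
#   assert _reshape_as_3dim(torch.zeros(3), 0).shape == (1, 3, 1)
#   assert _reshape_as_3dim(torch.zeros(2, 5, 9), 1).shape == (2, 5, 9)
#   assert _reshape_as_3dim(torch.zeros(2, 5, 9), 2).shape == (10, 9, 1)
#   assert _reshape_as_3dim(torch.zeros(3, 4, 5, 6), 2).shape == (12, 5, 6)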


class LearnedNonlinFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, params: torch.Tensor, dim: int) -> torch.Tensor:
        if dim < 0:
            dim += input.ndim
        assert dim >= 0 and dim < input.ndim
        assert params.ndim == 2 and params.shape[1] % 2 == 1
        assert params.shape[0] == input.shape[dim]

        ctx.dim = dim
        ctx.save_for_backward(input, params)
        output = _learned_nonlin_forward_dispatcher(_reshape_as_3dim(input, dim),
                                                    params)
        # Reshape back to the original layout of `input` (the dispatcher
        # operates on the normalized 3-dimensional layout).
        return output.reshape(input.shape)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
        (input, params) = ctx.saved_tensors
        orig_shape = input.shape
        # We re-do the reshaping in the backward, rather than save the reshaped
        # input, so that if this reshaping results in a copy it is not retained
        # (this saves memory at the expense of a little extra work in such
        # situations).
        grad_input, grad_params = _learned_nonlin_backward_dispatcher(
            _reshape_as_3dim(input, ctx.dim), params,
            _reshape_as_3dim(grad_output, ctx.dim))
        return grad_input.reshape(orig_shape), grad_params, None
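

# A minimal usage sketch (assuming one of the native modules loaded above is
# available and supports double precision): the custom backward can be checked
# against numerical gradients with torch.autograd.gradcheck, e.g. from a test
# script:
#
#   input = torch.randn(2, 4, 5, dtype=torch.double, requires_grad=True)
#   params = torch.randn(4, 9, dtype=torch.double, requires_grad=True)
#   torch.autograd.gradcheck(
#       lambda i, p: LearnedNonlinFunction.apply(i, p, 1), (input, params))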


def learned_nonlin(input, params, dim):
    """Learned nonlinearity.
    Args:
       input:   The input, to be transformed pointwise; may be of any shape.

       params:  The parameters of the learned nonlinearity.  Interpreted
                as of shape (C, N + 1), where C is the channel and N, which
                must be a power of 2 greater than 1, is the number of linear regions in the
                piecewise linear function.  The first element is the log
                of the distance between the discontinuities, and the
                remaining elements are the derivatives of the function
                in the linear pieces.  We can explain what this function
                is as follows:
                     Let the row of `params` for a particular channel be
                interpreted as (l, d0, d1, d2 ... ).  Let K = N/2, and L = exp(l).
                Then the discontinuities in the function are at:
                    L * ( -K+1, -K+2, .., -1, 0, 1, .. K-1 )
                and the values d0, d1 .. are interpreted as the slopes of the
                function in the intervals, respectively:
                    (-inf, L*(-K+1)], [L*(-K+1), L*(-K+2)], ...
                and we use these together with the assumption that the
                function's value at x=0 is 0, to compute the function's value.

                In terms of concrete calculations, we do it as follows:
                Firstly, we can get rid of the factor of L by treating the l
                parameter as a scale on the input and output, i.e.:
                  x = input * exp(-l)
                ... do the calculation y = f(x) without a scale, interpreting the
                discontinuities as being at integer values -K+1, -K+2, ... K-1,
                and then:
                  output = y * exp(l)

                The core computation requires computing the y-values at the
                discontinuities at -K+1, -K+2 and so on.  Each one equals
                the sign of the x-value (negative for x < 0) times the sum
                of the derivatives 'd' for the regions between that
                point and zero.  If we number these as offsets o_0, o_1 and
                so on up to o_{N-2}, then the formula is:

                     for o_n with n < K,  o_n = -sum(k = n+1..K-1) d_k
                     for o_n with n >= K, o_n = sum(k = K..n) d_k

                  e.g. if K=3 and  (d0, d1, d2, d3, d4, d5) = (1, 2, 1, 2, 1, 1), then:

                  o_0 = -(d1+d2) = -3    # x=-2 maps to y=-3
                  o_1 = -(d2) = -1       # x=-1 maps to y=-1
                  o_2 = () = 0           # x=0 maps to y=0
                  o_3 = (d3) = 2         # x=1 maps to y=2
                  o_4 = (d3 + d4) = 3    # x=2 maps to y=3

        dim:  The dimension of `input` that corresponds to the channel.  It is
              recommended that the channel should not be the fastest-varying
              dimension (the one with stride=1), because this will make
              the data loads and stores be non-coalesced and the kernels
              will be quite slow.

    Return:  output, of the same shape as `input`.
    """
    return LearnedNonlinFunction.apply(input, params, dim)
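

# The computation described in the docstring above can also be written directly
# with ordinary PyTorch ops.  The function below is a slow, pure-PyTorch
# reference sketch of that computation, intended only to illustrate the
# parameter interpretation; the name _learned_nonlin_ref and the function
# itself are illustrative additions, not part of the extension's API (the
# native C++/CUDA kernels are the actual implementation).  If the kernels
# implement the scheme documented above, learned_nonlin(input, params, dim)
# and _learned_nonlin_ref(input, params, dim) should agree closely.
def _learned_nonlin_ref(input: torch.Tensor, params: torch.Tensor,
                        dim: int) -> torch.Tensor:
    if dim < 0:
        dim += input.ndim
    C, N = params.shape[0], params.shape[1] - 1
    K = N // 2
    # Broadcastable shape that puts the channel on dimension `dim`.
    shape = [1] * input.ndim
    shape[dim] = C
    scale = torch.exp(params[:, 0]).reshape(shape)  # L = exp(l), per channel
    x = input / scale  # remove the scale from the input, i.e. x = input * exp(-l)
    y = torch.zeros_like(x)
    for k in range(N):
        d_k = params[:, k + 1].reshape(shape)  # slope of linear region k
        # Region k covers [k - K, k - K + 1]; the first and last regions
        # extend to -inf and +inf respectively.
        lo = float('-inf') if k == 0 else float(k - K)
        hi = float('inf') if k == N - 1 else float(k - K + 1)
        zero = min(max(0.0, lo), hi)  # the point of [lo, hi] closest to 0
        # Signed length of the part of the interval between 0 and x that
        # falls inside region k, times the slope d_k; summing over k
        # integrates the derivative from 0 to x, giving y(x) since y(0) = 0.
        y = y + d_k * (x.clamp(min=lo, max=hi) - zero)
    return y * scale  # re-apply the scale: output = y * exp(l)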