layers.py 14.7 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
Boris Bonev's avatar
Boris Bonev committed
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch
import torch.nn as nn
import torch.fft
35
from torch.utils.checkpoint import checkpoint
Boris Bonev's avatar
Boris Bonev committed
36
37
38
import math

from torch_harmonics import *
39
40
from .contractions import *
from .activations import *
Boris Bonev's avatar
Boris Bonev committed
41
42
43
44

# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
45
46
# tl.set_backend("pytorch")
# use_opt_einsum("optimal")
Boris Bonev's avatar
Boris Bonev committed
47
48
49
50
51
52
53
54
from tltorch.factorized_tensors.core import FactorizedTensor

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Samples are produced by drawing uniformly between the normal CDF values of
    the two cutoffs, mapping through the inverse error function, rescaling to
    ``mean``/``std`` and finally clamping to ``[a, b]``. All tensor operations
    run under ``torch.no_grad()``.

    Args:
        tensor: tensor to overwrite in place
        mean: mean of the underlying normal distribution
        std: standard deviation of the underlying normal distribution
        a: lower cutoff
        b: upper cutoff

    Returns:
        The same ``tensor`` object that was passed in.

    A warning is emitted when ``mean`` lies more than two standard deviations
    outside ``[a, b]``.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    # Imported locally: the warning branch below needs ``warnings`` and no
    # explicit module-level import of it is visible in this file.
    import warnings

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
Boris Bonev's avatar
Boris Bonev committed
83
84


Boris Bonev's avatar
Boris Bonev committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Initialize ``tensor`` in place from a truncated normal distribution.

    The values are effectively drawn from
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to the interval
    :math:`[a, b]`. The sampling scheme works best when
    :math:`a \leq \text{mean} \leq b`. The work is delegated to
    ``_no_grad_trunc_normal_`` and its return value is passed through.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled


@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) -> torch.Tensor:
    """Stochastic depth ("drop path"): randomly zero whole samples of ``x``.

    When ``training`` is set and ``drop_prob`` is non-zero, one random 0/1 mask
    value is drawn per entry along the first dimension of ``x`` and broadcast
    over all remaining dimensions; the input is divided by the keep probability
    ``1 - drop_prob`` before masking. Otherwise ``x`` is returned unchanged.

    The name follows the 'drop path' convention rather than 'DropConnect', see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    """
    if training and drop_prob != 0.:
        survive = 1. - drop_prob
        # one mask entry per sample, size-1 everywhere else so it broadcasts
        # against tensors of any rank (not just 2d ConvNets)
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # uniform [0, 1) shifted by the keep probability, floored to {0, 1}
        mask = torch.floor(survive + torch.rand(mask_shape, dtype=x.dtype, device=x.device))
        return x.div(survive) * mask
    return x


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).

    Module wrapper around ``drop_path`` that supplies the stored drop
    probability and the module's current ``training`` flag.

    Args:
        drop_prob: probability of dropping a sample; ``None`` (the default) is
            treated as ``0.0``, i.e. no dropping.
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # ``drop_path`` declares ``drop_prob: float``; storing ``None`` here
        # would be forwarded as-is and rejected, so normalise it to a float.
        self.drop_prob = 0.0 if drop_prob is None else float(drop_prob)

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
Boris Bonev's avatar
Boris Bonev committed
133
134
135
136
137
138

class MLP(nn.Module):
    """Two-layer perceptron built from 1x1 ``nn.Conv2d`` layers.

    The layers are assembled into ``self.fwd`` as ``fc1 -> act -> fc2``; when
    ``drop_rate > 0`` a single ``nn.Dropout2d`` instance is inserted after the
    activation and again after ``fc2``.

    Args:
        in_features: number of input channels
        hidden_features: number of hidden channels (falls back to ``in_features``)
        out_features: number of output channels (falls back to ``in_features``)
        act_layer: callable returning the activation module
        output_bias: whether the output layer has a bias
        drop_rate: dropout probability; dropout is only added when positive
        checkpointing: if set, ``forward`` runs ``self.fwd`` through
            ``torch.utils.checkpoint.checkpoint``
        gain: numerator of the output-layer init variance ``gain / hidden_features``
    """
    def __init__(self,
                 in_features,
                 hidden_features = None,
                 out_features = None,
                 act_layer = nn.ReLU,
                 output_bias = False,
                 drop_rate = 0.,
                 checkpointing = False,
                 gain = 1.0):
        super().__init__()
        self.checkpointing = checkpointing
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        def _pointwise(c_in, c_out, use_bias, std):
            # 1x1 convolution with normal weights of the given std and zero bias
            conv = nn.Conv2d(c_in, c_out, 1, bias=use_bias)
            nn.init.normal_(conv.weight, mean=0., std=std)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0.0)
            return conv

        # first dense layer, weight variance 2 / in_features
        fc1 = _pointwise(in_features, hidden_features, True, math.sqrt(2.0 / in_features))

        # activation
        act = act_layer()

        # output layer; the gain factor determines the scaling of the output init
        fc2 = _pointwise(hidden_features, out_features, output_bias, math.sqrt(gain / hidden_features))

        stages = [fc1, act, fc2]
        if drop_rate > 0.:
            drop = nn.Dropout2d(drop_rate)
            # the same dropout module is used at both positions
            stages = [fc1, act, drop, fc2, drop]
        self.fwd = nn.Sequential(*stages)

    @torch.jit.ignore
    def checkpoint_forward(self, x):
        return checkpoint(self.fwd, x)

    def forward(self, x):
        return self.checkpoint_forward(x) if self.checkpointing else self.fwd(x)

class RealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT

    ``forward`` applies ``torch.fft.rfft2`` (orthonormal) over the last two
    dimensions and then truncates the result: along the second-to-last
    dimension it keeps the first ``ceil(lmax / 2)`` rows and the last
    ``floor(lmax / 2)`` rows, and along the last dimension the first ``mmax``
    entries.

    Args:
        nlat: stored as ``self.nlat``; default for ``lmax``
        nlon: stored as ``self.nlon``; default ``mmax`` is ``nlon // 2 + 1``
        lmax: total number of rows kept along the second-to-last dimension
        mmax: number of entries kept along the last dimension
    """
    def __init__(self,
                 nlat,
                 nlon,
                 lmax = None,
                 mmax = None):
        super(RealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")

        n_head = math.ceil(self.lmax / 2)
        n_tail = math.floor(self.lmax / 2)
        # Use an explicit start index for the trailing rows: a ``-n_tail:``
        # slice would select *every* row when n_tail == 0 (i.e. lmax == 1).
        tail_start = max(y.shape[-2] - n_tail, 0)

        y = torch.cat((y[..., :n_head, :self.mmax], y[..., tail_start:, :self.mmax]), dim=-2)
        return y

class InverseRealFFT2(nn.Module):
    """
    Inverse counterpart of the FFT wrapper, mirroring the SHT interface.

    ``forward`` applies an orthonormal ``torch.fft.irfft2`` over the last two
    dimensions with the output size fixed to ``(nlat, nlon)``. ``lmax`` and
    ``mmax`` are stored as attributes (defaulting to ``nlat`` and
    ``nlon // 2 + 1``) but are not used by ``forward``.
    """
    def __init__(self,
                 nlat,
                 nlon,
                 lmax = None,
                 mmax = None):
        super().__init__()

        self.nlat, self.nlon = nlat, nlon
        self.lmax = lmax or nlat
        self.mmax = mmax or nlon // 2 + 1

    def forward(self, x):
        grid_shape = (self.nlat, self.nlon)
        return torch.fft.irfft2(x, s=grid_shape, dim=(-2, -1), norm="ortho")
Boris Bonev's avatar
Boris Bonev committed
223

Boris Bonev's avatar
Boris Bonev committed
224
225
226
227
228
229
class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
    using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
    domain via the RealFFT2 and InverseRealFFT2 wrappers.

    Args:
        forward_transform: callable transform; its ``nlat``/``nlon`` attributes are read
        inverse_transform: callable transform; its ``nlat``/``nlon``/``lmax``/``mmax``
            attributes are read
        in_channels: number of input channels
        out_channels: number of output channels
        gain: numerator of the weight init variance ``gain / in_channels``
        operator_type: one of ``"diagonal"``, ``"block-diagonal"``, ``"driscoll-healy"``
        lr_scale_exponent: accepted for interface compatibility; not used in this class
        bias: if set, a learnable bias of shape ``(1, out_channels, 1, 1)`` is added

    Raises:
        NotImplementedError: if ``operator_type`` is not one of the supported values.

    ``forward`` returns a tuple ``(x, residual)``.
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 in_channels,
                 out_channels,
                 gain = 2.,
                 operator_type = "driscoll-healy",
                 lr_scale_exponent = 0,
                 bias = False):
        super(SpectralConvS2, self).__init__()

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        # the residual has to be re-gridded whenever input and output grids differ
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
                        or (self.forward_transform.nlon != self.inverse_transform.nlon)

        # remember factorization details
        self.operator_type = operator_type

        # NOTE(review): these hold by construction given the two assignments above
        assert self.inverse_transform.lmax == self.modes_lat
        assert self.inverse_transform.mmax == self.modes_lon

        weight_shape = [out_channels, in_channels]

        if self.operator_type == "diagonal":
            weight_shape += [self.modes_lat, self.modes_lon]
            from .contractions import contract_diagonal as _contract
        elif self.operator_type == "block-diagonal":
            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
            from .contractions import contract_blockdiag as _contract
        elif self.operator_type == "driscoll-healy":
            weight_shape += [self.modes_lat]
            from .contractions import contract_dhconv as _contract
        else:
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")

        # form weight tensors
        # per-latitudinal-mode init scale; the first mode gets an extra sqrt(2)
        scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat)
        scale[0] *=  math.sqrt(2)
        # The real view of the weights has shape ``weight_shape + [2]`` with
        # modes_lat at index 2. Pad the scale with singleton dims so that it
        # broadcasts along that axis for every operator type.
        scale = scale.reshape([self.modes_lat] + [1] * (len(weight_shape) - 2))
        self.weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))

        # get the right contraction function
        self._contract = _contract

        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):
        """Apply the spectral convolution.

        Returns:
            ``(x, residual)`` where ``x`` is the convolved output cast back to
            the input dtype, and ``residual`` is the float-cast input, or the
            input passed through ``forward_transform`` then ``inverse_transform``
            when ``self.scale_residual`` is set.
        """

        dtype = x.dtype
        x = x.float()
        residual = x

        # transforms run with CUDA autocast disabled
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        # the contraction operates on the real view of the complex coefficients
        x = torch.view_as_real(x)
        x = self._contract(x, self.weight)
        x = torch.view_as_complex(x)

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias
        x = x.type(dtype)

        return x, residual

class FactorizedSpectralConvS2(nn.Module):
    """
    Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized

    Args:
        forward_transform: callable transform; its ``nlat``/``nlon`` attributes are read
        inverse_transform: callable transform; its ``nlat``/``nlon``/``lmax``/``mmax``
            attributes are read
        in_channels: number of input channels
        out_channels: number of output channels
        gain: numerator of the weight init variance ``gain / in_channels``
        operator_type: one of ``"diagonal"``, ``"block-diagonal"``, ``"driscoll-healy"``
        rank: forwarded to ``FactorizedTensor.new``
        factorization: factorization name; ``None`` means ``"Dense"``. A ``"Complex"``
            prefix is added when missing.
        separable: if set, the weight has no ``in_channels`` dimension
        implementation: forwarded to ``get_contract_fun``
        decomposition_kwargs: extra keyword arguments for ``FactorizedTensor.new``;
            ``None`` means no extra arguments
        bias: if set, a learnable bias of shape ``(1, out_channels, 1, 1)`` is added

    Raises:
        NotImplementedError: if ``operator_type`` is not one of the supported values.

    ``forward`` returns a tuple ``(x, residual)``.
    """

    def __init__(self,
                 forward_transform,
                 inverse_transform,
                 in_channels,
                 out_channels,
                 gain = 2.,
                 operator_type = "driscoll-healy",
                 rank = 0.2,
                 factorization = None,
                 separable = False,
                 implementation = "factorized",
                 decomposition_kwargs=None,
                 bias = False):
        # zero-argument super(): naming another class here makes construction
        # fail because this class does not inherit from it
        super().__init__()

        # avoid a shared mutable default for the keyword dictionary
        if decomposition_kwargs is None:
            decomposition_kwargs = {}

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        # the residual has to be re-gridded whenever input and output grids differ
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
                        or (self.forward_transform.nlon != self.inverse_transform.nlon)

        # Make sure we are using a Complex Factorized Tensor
        if factorization is None:
            factorization = "Dense" # No factorization
        if not factorization.lower().startswith("complex"):
            factorization = f"Complex{factorization}"

        # remember factorization details
        self.operator_type = operator_type
        self.rank = rank
        self.factorization = factorization
        self.separable = separable

        # NOTE(review): these hold by construction given the assignments above
        assert self.inverse_transform.lmax == self.modes_lat
        assert self.inverse_transform.mmax == self.modes_lon

        weight_shape = [out_channels]

        if not self.separable:
            weight_shape += [in_channels]

        if self.operator_type == "diagonal":
            weight_shape += [self.modes_lat, self.modes_lon]
        elif self.operator_type == "block-diagonal":
            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
        elif self.operator_type == "driscoll-healy":
            weight_shape += [self.modes_lat]
        else:
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")

        # form weight tensors
        self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
                                           fixed_rank_modes=False, **decomposition_kwargs)

        # initialization of weights
        scale = math.sqrt(gain / in_channels)
        self.weight.normal_(0, scale)

        # get the right contraction function
        from .factorizations import get_contract_fun
        self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)

        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):
        """Apply the factorized spectral convolution.

        Returns:
            ``(x, residual)`` where ``x`` is the convolved output cast back to
            the input dtype, and ``residual`` is the float-cast input, or the
            input passed through ``forward_transform`` then ``inverse_transform``
            when ``self.scale_residual`` is set.
        """

        dtype = x.dtype
        x = x.float()
        residual = x

        # transforms run with CUDA autocast disabled
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias
        x = x.type(dtype)

        return x, residual