# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
import math
import warnings

import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint

from torch_harmonics import InverseRealSHT


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2d ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
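
# A minimal usage sketch (illustrative): DropPath is typically applied to the
# residual branch of a block, so that entire per-sample branches are dropped:
#
#     drop = DropPath(drop_prob=0.1)
#     y = x + drop(block(x))  # per-sample stochastic depth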


class PatchEmbed(nn.Module):
    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super(PatchEmbed, self).__init__()
        self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
        num_patches = self.red_img_size[0] * self.red_img_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
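        # mark these parameters as shared across spatial model-parallel ranks
        # (an annotation presumably consumed by an external distributed setup)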
        self.proj.weight.is_shared_mp = ["spatial"]
        self.proj.bias.is_shared_mp = ["spatial"]

    def forward(self, x):
        # gather input
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # new: B, C, H*W
        x = self.proj(x).flatten(2)
        return x
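
# Shape sketch (illustrative): with the defaults img_size=(224, 224),
# patch_size=(16, 16) and embed_dim=768, an input of shape (B, 3, 224, 224) is
# projected to (B, 768, 14, 14) and flattened to (B, 768, 196), i.e. one
# 768-dimensional token per 16x16 patch.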


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
        super(MLP, self).__init__()
        self.checkpointing = checkpointing
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # First dense layer
        fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
        # initialize the weights correctly
        scale = math.sqrt(2.0 / in_features)
        nn.init.normal_(fc1.weight, mean=0.0, std=scale)
        if fc1.bias is not None:
            nn.init.constant_(fc1.bias, 0.0)

        # activation
        act = act_layer()

        # output layer
        fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
        # gain factor for the output determines the scaling of the output init
        scale = math.sqrt(gain / hidden_features)
        nn.init.normal_(fc2.weight, mean=0.0, std=scale)
        if fc2.bias is not None:
            nn.init.constant_(fc2.bias, 0.0)

        if drop_rate > 0.0:
            drop = nn.Dropout2d(drop_rate)
            self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
        else:
            self.fwd = nn.Sequential(fc1, act, fc2)

    @torch.jit.ignore
    def checkpoint_forward(self, x):
        return checkpoint(self.fwd, x)

    def forward(self, x):
        if self.checkpointing:
            return self.checkpoint_forward(x)
        else:
            return self.fwd(x)
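
# Usage sketch (illustrative): the MLP acts on channel maps via 1x1 convolutions,
# so it expects inputs of shape (B, C, H, W) rather than (B, N, C):
#
#     mlp = MLP(in_features=256, hidden_features=1024, out_features=256)
#     y = mlp(torch.randn(2, 256, 32, 64))  # -> (2, 256, 32, 64)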


class RealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(RealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
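        # truncate to (lmax, mmax) modes: torch.fft.rfft2 orders the frequencies
        # along dim -2 as [0, 1, ..., -2, -1], so keep the leading ceil(lmax / 2)
        # non-negative and the trailing floor(lmax / 2) negative frequencies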
        y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
        return y


class InverseRealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(InverseRealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
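
# Roundtrip sketch (illustrative): without truncation, RealFFT2 and
# InverseRealFFT2 are exact inverses of each other:
#
#     fft = RealFFT2(nlat=32, nlon=64)
#     ifft = InverseRealFFT2(nlat=32, nlon=64)
#     x = torch.randn(2, 4, 32, 64)
#     torch.allclose(ifft(fft(x)), x, atol=1e-5)  # True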


class LayerNorm(nn.Module):
    """
    Wrapper class that moves the channel dimension to the end
    """

    def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
        super().__init__()

        self.channel_dim = -3

        self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)

    def forward(self, x):

        return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
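
# Usage sketch (illustrative): normalizes over the channel dimension of a
# (B, C, H, W) tensor by transposing channels to the last position:
#
#     norm = LayerNorm(in_channels=256)
#     y = norm(torch.randn(2, 256, 32, 64))  # same shape, per-pixel layer norm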


class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
    using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
    domain via the RealFFT2 and InverseRealFFT2 wrappers.
    """

    def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
        super().__init__()

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)

        # remember factorization details
        self.operator_type = operator_type

        assert self.inverse_transform.lmax == self.modes_lat
        assert self.inverse_transform.mmax == self.modes_lon

        weight_shape = [out_channels, in_channels]

        if self.operator_type == "diagonal":
            weight_shape += [self.modes_lat, self.modes_lon]
            self.contract_func = "...ilm,oilm->...olm"
        elif self.operator_type == "block-diagonal":
            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
            self.contract_func = "...ilm,oilnm->...oln"
        elif self.operator_type == "driscoll-healy":
            weight_shape += [self.modes_lat]
            self.contract_func = "...ilm,oil->...olm"
        else:
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")

        # form weight tensors
        scale = math.sqrt(gain / in_channels)
        self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):

        dtype = x.dtype
        x = x.float()
        residual = x

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = torch.einsum(self.contract_func, x, self.weight)

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias
        x = x.type(dtype)

        return x, residual
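
# Usage sketch (illustrative; assumes the RealSHT / InverseRealSHT pair from
# torch_harmonics, of which only InverseRealSHT is imported above):
#
#     from torch_harmonics import RealSHT
#     sht = RealSHT(nlat=32, nlon=64, grid="equiangular")
#     isht = InverseRealSHT(nlat=32, nlon=64, grid="equiangular")
#     conv = SpectralConvS2(sht, isht, in_channels=8, out_channels=8)
#     y, residual = conv(torch.randn(2, 8, 32, 64))  # both (2, 8, 32, 64)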


class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for position embeddings; subclasses provide the
    ``position_embeddings`` tensor that is added to the input.
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):

        super().__init__()

        self.img_shape = img_shape
        self.num_chans = num_chans

    def forward(self, x: torch.Tensor):

        return x + self.position_embeddings


class SequencePositionEmbedding(PositionEmbedding):
    """
    Returns standard sequence based position embedding
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):

        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        with torch.no_grad():

            # alternating custom position embeddings
            pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
            k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
            denom = torch.pow(10000, 2 * k / self.num_chans)

            pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))

        # register tensor
        self.register_buffer("position_embeddings", pos_embed.float())
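
# Note (illustrative): for channel k this produces sin(pos / 10000^(2k/C)) for
# even k and cos(pos / 10000^(2k/C)) for odd k, where pos is the flattened
# (lat, lon) pixel index, i.e. a sinusoidal embedding over the raster scan.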


class SpectralPositionEmbedding(PositionEmbedding):
    """
    Returns position embeddings for the spherical transformer
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):

        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # compute maximum required frequency and prepare isht
        lmax = math.floor(math.sqrt(self.num_chans)) + 1
        isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)

        # fill position embedding
        with torch.no_grad():
            pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)

            for i in range(self.num_chans):
                l = math.floor(math.sqrt(i))
                m = i - l**2 - l

                if m < 0:
                    pos_embed_freq[0, i, l, -m] = 1.0j
                else:
                    pos_embed_freq[0, i, l, m] = 1.0

        # compute spatial position embeddings
        pos_embed = isht(pos_embed_freq)

        # normalization
        pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)

        # register tensor
        self.register_buffer("position_embeddings", pos_embed)
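
# Note (illustrative): channel i is assigned the spherical harmonic of degree
# l = floor(sqrt(i)) and order m = i - l^2 - l, so m runs over -l, ..., l and
# the channels enumerate Y_0^0, Y_1^{-1}, Y_1^0, Y_1^1, Y_2^{-2}, ... on the sphere.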


class LearnablePositionEmbedding(PositionEmbedding):
    """
    Learnable position embeddings, either per pixel ("latlon") or per latitude ring ("lat")
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):

        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        if embed_type == "latlon":
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
        elif embed_type == "lat":
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
        else:
            raise ValueError(f"Unknown learnable position embedding type {embed_type}")
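
# Note (illustrative): with embed_type="lat" the embedding has shape
# (1, C, nlat, 1) and is broadcast across longitudes, encoding latitude only;
# "latlon" learns a full (1, C, nlat, nlon) map.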

# class SpiralPositionEmbedding(PositionEmbedding):
#     """
#     Returns position embeddings on the torus
#     """

#     def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):

#         super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

#         with torch.no_grad():

#             # alternating custom position embeddings
#             lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
#             lats = lats.reshape(-1, 1)
#             lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
#             lons = lons.reshape(1, -1)

#             # channel index
#             k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
#             pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))

#         # register tensor
#         self.register_buffer("position_embeddings", pos_embed.float())