# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
import math
import warnings

import torch
import torch.nn as nn
import torch.fft
from torch.utils.checkpoint import checkpoint

from torch_harmonics import InverseRealSHT


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.", stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    
    Parameters
    ----------
    tensor: torch.Tensor
        an n-dimensional `torch.Tensor`
    mean: float
        the mean of the normal distribution, by default 0.0
    std: float
        the standard deviation of the normal distribution, by default 1.0
    a: float
        the minimum cutoff value, by default -2.0
    b: float
        the maximum cutoff value, by default 2.0

    Examples
    --------
    >>> w = torch.empty(3, 5)
    >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor
    drop_prob : float, optional
        Probability of dropping a path, by default 0.0
    training : bool, optional
        Whether the model is in training mode, by default False

    Returns
    -------
    torch.Tensor
        Output tensor
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2d ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    
    This module implements stochastic depth regularization by randomly dropping
    entire residual paths during training, which helps with regularization and
    training of very deep networks.
    
    Parameters
    ----------
    drop_prob : float, optional
        Probability of dropping a path, by default None
    """
    
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):

        return drop_path(x, self.drop_prob, self.training)
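
# A minimal usage sketch (illustrative only, not part of the module API):
#
#   dp = DropPath(drop_prob=0.1)
#   dp.train()
#   y = dp(torch.randn(4, 8, 16, 16))  # per sample: zeroed with prob 0.1, survivors rescaled by 1/0.9
#   dp.eval()
#   y = dp(torch.randn(4, 8, 16, 16))  # identity in eval mode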


class PatchEmbed(nn.Module):
    """
    Patch embedding layer for vision transformers.
    
    This module splits input images into patches and projects them to a
    higher dimensional embedding space using convolutional layers.
    
    Parameters
    ----------
    img_size : tuple, optional
        Input image size (height, width), by default (224, 224)
    patch_size : tuple, optional
        Patch size (height, width), by default (16, 16)
    in_chans : int, optional
        Number of input channels, by default 3
    embed_dim : int, optional
        Embedding dimension, by default 768
    """
    
    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super(PatchEmbed, self).__init__()
        self.red_img_size = ((img_size[0] // patch_size[0]), (img_size[1] // patch_size[1]))
        num_patches = self.red_img_size[0] * self.red_img_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
        self.proj.weight.is_shared_mp = ["spatial"]
        self.proj.bias.is_shared_mp = ["spatial"]

    def forward(self, x):

        # gather input
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # new: B, C, H*W
        x = self.proj(x).flatten(2)
        return x
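
# A minimal usage sketch (illustrative only, not part of the module API):
#
#   patch = PatchEmbed(img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768)
#   tokens = patch(torch.randn(1, 3, 224, 224))  # shape: (1, 768, 196), i.e. embed_dim x 14*14 patches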


class MLP(nn.Module):
    """
    Multi-layer perceptron with optional checkpointing.
    
    This module implements a feed-forward network with two linear layers
    and an activation function, with optional dropout and gradient checkpointing.
    
    Parameters
    ----------
    in_features : int
        Number of input features
    hidden_features : int, optional
        Number of hidden features, by default None (same as in_features)
    out_features : int, optional
        Number of output features, by default None (same as in_features)
    act_layer : nn.Module, optional
        Activation layer, by default nn.ReLU
    output_bias : bool, optional
        Whether to use bias in output layer, by default False
    drop_rate : float, optional
        Dropout rate, by default 0.0
    checkpointing : bool, optional
        Whether to use gradient checkpointing, by default False
    gain : float, optional
        Gain factor for weight initialization, by default 1.0
    """
    
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
        super(MLP, self).__init__()
        self.checkpointing = checkpointing
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # First dense layer
        fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
        # initialize the weights correctly
        scale = math.sqrt(2.0 / in_features)
        nn.init.normal_(fc1.weight, mean=0.0, std=scale)
        if fc1.bias is not None:
            nn.init.constant_(fc1.bias, 0.0)

        # activation
        act = act_layer()

        # output layer
        fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
        # gain factor for the output determines the scaling of the output init
        scale = math.sqrt(gain / hidden_features)
        nn.init.normal_(fc2.weight, mean=0.0, std=scale)
        if fc2.bias is not None:
            nn.init.constant_(fc2.bias, 0.0)

        if drop_rate > 0.0:
            drop = nn.Dropout2d(drop_rate)
            self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
        else:
            self.fwd = nn.Sequential(fc1, act, fc2)

    @torch.jit.ignore
    def checkpoint_forward(self, x):

        return checkpoint(self.fwd, x)

    def forward(self, x):

        if self.checkpointing:
            return self.checkpoint_forward(x)
        else:
            return self.fwd(x)
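
# A minimal usage sketch (illustrative only): since both layers are 1x1 convolutions,
# the MLP acts channel-wise and preserves the spatial layout (B, C, H, W):
#
#   mlp = MLP(in_features=32, hidden_features=64, drop_rate=0.1)
#   y = mlp(torch.randn(2, 32, 12, 24))  # shape: (2, 32, 12, 24)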


class RealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT.
    
    This module provides a wrapper around PyTorch's real FFT2D that mimics
    the interface of spherical harmonic transforms for consistency.
    
    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    lmax : int, optional
        Maximum spherical harmonic degree, by default None (same as nlat)
    mmax : int, optional
        Maximum spherical harmonic order, by default None (nlon//2 + 1)
    """
    
    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(RealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):

        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
        y = torch.cat((y[..., : math.ceil(self.lmax / 2), : self.mmax], y[..., -math.floor(self.lmax / 2) :, : self.mmax]), dim=-2)
        return y
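
# A minimal usage sketch (illustrative only); with the default lmax=nlat and
# mmax=nlon//2+1 no modes are truncated:
#
#   fft = RealFFT2(nlat=32, nlon=64)
#   coeffs = fft(torch.randn(2, 3, 32, 64))  # complex coefficients, shape: (2, 3, 32, 33)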


class InverseRealFFT2(nn.Module):
    """
    Helper routine to wrap inverse FFT similarly to the SHT.
    
    This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
    the interface of inverse spherical harmonic transforms for consistency.
    
    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    lmax : int, optional
        Maximum spherical harmonic degree, by default None (same as nlat)
    mmax : int, optional
        Maximum spherical harmonic order, by default None (nlon//2 + 1)
    """
    
    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(InverseRealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):

        return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
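
# A minimal roundtrip sketch (illustrative only); with untruncated modes the
# forward/inverse pair reconstructs the input up to floating-point error:
#
#   fft = RealFFT2(nlat=32, nlon=64)
#   ifft = InverseRealFFT2(nlat=32, nlon=64)
#   x = torch.randn(2, 3, 32, 64)
#   err = (ifft(fft(x)) - x).abs().max()  # on the order of 1e-6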


class LayerNorm(nn.Module):
    """
    Wrapper class that moves the channel dimension to the end.
    
    This module provides a layer normalization that works with channel-first
    tensors by temporarily transposing the channel dimension to the end,
    applying normalization, and then transposing back.
    
    Parameters
    ----------
    in_channels : int
        Number of input channels
    eps : float, optional
        Epsilon for numerical stability, by default 1e-05
    elementwise_affine : bool, optional
        Whether to use learnable affine parameters, by default True
    bias : bool, optional
        Whether to use bias, by default True
    device : torch.device, optional
        Device to place the module on, by default None
    dtype : torch.dtype, optional
        Data type for the module, by default None
    """
    
    def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
        super().__init__()

        self.channel_dim = -3

        self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=eps, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)

    def forward(self, x):

        return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)
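
# A minimal usage sketch (illustrative only); normalization runs over the channel
# dimension, which is expected at position -3:
#
#   ln = LayerNorm(in_channels=16)
#   y = ln(torch.randn(2, 16, 8, 8))  # shape: (2, 16, 8, 8)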


class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
    using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
    domain via the RealFFT2 and InverseRealFFT2 wrappers.
    
    Parameters
    ----------
    forward_transform : nn.Module
        Forward transform (SHT or FFT)
    inverse_transform : nn.Module
        Inverse transform (ISHT or IFFT)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    gain : float, optional
        Gain factor for weight initialization, by default 2.0
    operator_type : str, optional
        Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
    lr_scale_exponent : int, optional
        Learning rate scaling exponent, by default 0
    bias : bool, optional
        Whether to use bias, by default False
    """
    
    def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
        super().__init__()

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)

        # remember factorization details
        self.operator_type = operator_type

        assert self.inverse_transform.lmax == self.modes_lat
        assert self.inverse_transform.mmax == self.modes_lon

        weight_shape = [out_channels, in_channels]

        if self.operator_type == "diagonal":
            weight_shape += [self.modes_lat, self.modes_lon]
            self.contract_func = "...ilm,oilm->...olm"
        elif self.operator_type == "block-diagonal":
            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
            self.contract_func = "...ilm,oilnm->...oln"
        elif self.operator_type == "driscoll-healy":
            weight_shape += [self.modes_lat]
            self.contract_func = "...ilm,oil->...olm"
        else:
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")

        # form weight tensors
        scale = math.sqrt(gain / in_channels)
        self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):

        dtype = x.dtype
        x = x.float()
        residual = x

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = torch.einsum(self.contract_func, x, self.weight)

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias
        x = x.type(dtype)

        return x, residual
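
# A minimal usage sketch (illustrative only), here on the periodic domain via the
# FFT wrappers above; on the sphere one would pass RealSHT/InverseRealSHT from
# torch_harmonics instead:
#
#   fwd = RealFFT2(nlat=32, nlon=64)
#   inv = InverseRealFFT2(nlat=32, nlon=64)
#   conv = SpectralConvS2(fwd, inv, in_channels=8, out_channels=8)
#   y, residual = conv(torch.randn(2, 8, 32, 64))  # y: (2, 8, 32, 64)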


class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for position embeddings.
    
    This class defines the interface for position embedding modules
    that add positional information to input tensors.
    
    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """
    
    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__()

        self.img_shape = img_shape
        self.num_chans = num_chans

    def forward(self, x: torch.Tensor):

        return x + self.position_embeddings


class SequencePositionEmbedding(PositionEmbedding):
    """
    Standard sequence-based position embedding.
    
    This module implements sinusoidal position embeddings similar to those
    used in the original Transformer paper, adapted for 2D spatial data.
    
    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """
    
    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        with torch.no_grad():
            # alternating custom position embeddings
            pos = torch.arange(self.img_shape[0] * self.img_shape[1]).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)
            k = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)
            denom = torch.pow(10000, 2 * k / self.num_chans)

            pos_embed = torch.where(k % 2 == 0, torch.sin(pos / denom), torch.cos(pos / denom))

        # register tensor
        self.register_buffer("position_embeddings", pos_embed.float())
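
# A minimal usage sketch (illustrative only); the sinusoidal embedding is simply
# added to the input:
#
#   emb = SequencePositionEmbedding(img_shape=(32, 64), num_chans=4)
#   y = emb(torch.randn(2, 4, 32, 64))  # shape: (2, 4, 32, 64)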


class SpectralPositionEmbedding(PositionEmbedding):
    """
    Spectral position embeddings for spherical transformers.
    
    This module creates position embeddings in the spectral domain using
    spherical harmonic functions, which are particularly suitable for
    spherical data processing.
    
    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """
    
    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # compute maximum required frequency and prepare isht
        lmax = math.floor(math.sqrt(self.num_chans)) + 1
        isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=lmax, mmax=lmax, grid=grid)

        # fill position embedding
        with torch.no_grad():
            pos_embed_freq = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)

            for i in range(self.num_chans):
                l = math.floor(math.sqrt(i))
                m = i - l**2 - l

                if m < 0:
                    pos_embed_freq[0, i, l, -m] = 1.0j
                else:
                    pos_embed_freq[0, i, l, m] = 1.0

        # compute spatial position embeddings
        pos_embed = isht(pos_embed_freq)

        # normalization
        pos_embed = pos_embed / torch.amax(pos_embed.abs(), dim=(-1, -2), keepdim=True)

        # register tensor
        self.register_buffer("position_embeddings", pos_embed)
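
# A minimal usage sketch (illustrative only); channel i carries the (normalized)
# spherical harmonic of degree l = floor(sqrt(i)) evaluated on the grid:
#
#   emb = SpectralPositionEmbedding(img_shape=(60, 120), grid="equiangular", num_chans=8)
#   y = emb(torch.randn(2, 8, 60, 120))  # shape: (2, 8, 60, 120)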


class LearnablePositionEmbedding(PositionEmbedding):
    """
    Learnable position embeddings for spherical transformers.
    
    This module provides learnable position embeddings that can be either
    latitude-only or full latitude-longitude embeddings.
    
    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    embed_type : str, optional
        Embedding type ("lat" or "latlon"), by default "lat"
    """
    
    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        if embed_type == "latlon":
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], self.img_shape[1]))
        elif embed_type == "lat":
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.img_shape[0], 1))
        else:
            raise ValueError(f"Unknown learnable position embedding type {embed_type}")
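
# A minimal usage sketch (illustrative only); "lat" embeddings have shape
# (1, num_chans, nlat, 1) and broadcast over the longitude dimension:
#
#   emb = LearnablePositionEmbedding(img_shape=(32, 64), num_chans=4, embed_type="lat")
#   y = emb(torch.randn(2, 4, 32, 64))  # shape: (2, 4, 32, 64)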

# class SpiralPositionEmbedding(PositionEmbedding):
#     """
#     Returns position embeddings on the torus
#     """

#     def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):

#         super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

#         with torch.no_grad():

#             # alternating custom position embeddings
#             lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
#             lats = lats.reshape(-1, 1)
#             lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
#             lons = lons.reshape(1, -1)

#             # channel index
#             k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
#             pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))

#         # register tensor
#         self.register_buffer("position_embeddings", pos_embed.float())