_layers.py 25.6 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
33
34
import abc
import math

Boris Bonev's avatar
Boris Bonev committed
35
36
37
import torch
import torch.nn as nn
import torch.fft
38
from torch.utils.checkpoint import checkpoint
Boris Bonev's avatar
Boris Bonev committed
39

Boris Bonev's avatar
Boris Bonev committed
40
41
from torch_harmonics import InverseRealSHT

Boris Bonev's avatar
Boris Bonev committed
42
43

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """
    Initialize a tensor in-place with a truncated normal distribution, without gradients.

    This is the helper behind ``trunc_normal_``. Instead of rejection sampling, it
    draws uniformly between the normal CDF values of the standardized bounds and
    maps the samples back through the inverse CDF (via ``erfinv``).

    Parameters
    -----------
    tensor : torch.Tensor
        Tensor to initialize (modified in-place)
    mean : float
        Mean of the normal distribution
    std : float
        Standard deviation of the normal distribution
    a : float
        Lower bound for truncation
    b : float
        Upper bound for truncation

    Returns
    -------
    torch.Tensor
        The same tensor object that was passed in, now holding the sampled values

    Warns
    -----
    UserWarning
        If ``mean`` lies more than 2 standard deviations outside ``[a, b]``
    """
    # Local import keeps this helper self-contained; without it the warning
    # branch below raises NameError instead of warning.
    import warnings

    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        """
        Compute the standard normal cumulative distribution function.

        Parameters
        -----------
        x : float
            Input value

        Returns
        -------
        float
            CDF value
        """
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1] (the range expected by erfinv).
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range (guards against erfinv
        # returning +/-inf at the interval ends)
        tensor.clamp_(min=a, max=b)
        return tensor
Boris Bonev's avatar
Boris Bonev committed
112
113


Boris Bonev's avatar
Boris Bonev committed
114
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fill ``tensor`` in-place with samples from a truncated normal distribution.

    The samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted
    to the interval :math:`[a, b]`. This is equivalent in distribution to redrawing
    out-of-range values until they fall within the bounds. The sampling method
    works best when :math:`a \leq \text{mean} \leq b`.

    Parameters
    ----------
    tensor : torch.Tensor
        An n-dimensional tensor, overwritten in-place
    mean : float, optional
        Mean of the underlying normal distribution, by default 0.0
    std : float, optional
        Standard deviation of the underlying normal distribution, by default 1.0
    a : float, optional
        Lower cutoff value, by default -2.0
    b : float, optional
        Upper cutoff value, by default 2.0

    Returns
    -------
    torch.Tensor
        The same tensor object that was passed in

    Examples
    --------
    >>> w = torch.empty(3, 5)
    >>> trunc_normal_(w)
    """
    # all work (including the no-grad context) happens in the private helper
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)


@torch.jit.script
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Each sample along the first dimension is either zeroed entirely (with probability
    ``drop_prob``) or kept and rescaled by ``1 / (1 - drop_prob)``. For ``drop_prob == 1``
    every sample is zeroed; no rescaling is applied, which avoids ``0 / 0 = NaN``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; the first dimension is treated as the sample dimension
    drop_prob : float, optional
        Probability of dropping a path, by default 0.0
    training : bool, optional
        Whether the model is in training mode, by default False

    Returns
    -------
    torch.Tensor
        Output tensor with the same shape as ``x`` (``x`` itself when inactive)
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # one mask value per sample, broadcast over all remaining dims
    # (works with tensors of any rank, not just 2d ConvNets)
    shape = [x.shape[0]] + [1] * (x.dim() - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    if keep_prob > 0.0:
        # rescale surviving paths so the expectation is unchanged
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path``. During training it randomly zeroes entire
    samples of its input (rescaling the survivors); in eval mode it is the identity.

    Parameters
    ----------
    drop_prob : float, optional
        Probability of dropping a path, by default None (treated as 0.0, i.e. no dropping)
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """
        Forward pass with drop path regularization.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor with potential path dropping
        """
        # drop_path is TorchScript-compiled with a ``float`` drop_prob and rejects
        # None, so the default ``None`` is mapped to "never drop".
        drop_prob = 0.0 if self.drop_prob is None else float(self.drop_prob)
        return drop_path(x, drop_prob, self.training)
Boris Bonev's avatar
Boris Bonev committed
208

Boris Bonev's avatar
Boris Bonev committed
209
210

class PatchEmbed(nn.Module):
    """
    Patch embedding layer for vision transformers.

    Cuts the input image into non-overlapping patches with a strided convolution
    (kernel size == stride == ``patch_size``) and projects every patch to an
    ``embed_dim``-dimensional vector.

    Parameters
    ----------
    img_size : tuple, optional
        Input image size (height, width), by default (224, 224)
    patch_size : tuple, optional
        Patch size (height, width), by default (16, 16)
    in_chans : int, optional
        Number of input channels, by default 3
    embed_dim : int, optional
        Embedding dimension, by default 768
    """

    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super(PatchEmbed, self).__init__()
        patches_h = img_size[0] // patch_size[0]
        patches_w = img_size[1] // patch_size[1]

        # grid of patches ("reduced image") and its total size
        self.red_img_size = (patches_h, patches_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_h * patches_w

        projection = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
        # tag the parameters as shared across the spatial model-parallel group
        projection.weight.is_shared_mp = ["spatial"]
        projection.bias.is_shared_mp = ["spatial"]
        self.proj = projection

    def forward(self, x):
        """
        Embed an image batch into a sequence of patch vectors.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch_size, channels, height, width); height and width
            must equal ``img_size``

        Returns
        -------
        torch.Tensor
            Patch embeddings of shape (batch_size, embed_dim, num_patches)
        """
        _, _, height, width = x.shape
        expected_h, expected_w = self.img_size[0], self.img_size[1]
        assert height == expected_h and width == expected_w, f"Input image size ({height}*{width}) doesn't match model ({expected_h}*{expected_w})."
        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W')
        patches = self.proj(x)
        return patches.flatten(2)


Boris Bonev's avatar
Boris Bonev committed
262
class MLP(nn.Module):
    """
    Multi-layer perceptron with optional checkpointing.

    Two pointwise (1x1) convolutions with an activation in between, so the MLP acts
    on the channel dimension of (batch, channels, height, width) tensors. Optional
    dropout is applied after the activation and after the output layer, and the
    whole stack can be run under gradient checkpointing.

    Parameters
    ----------
    in_features : int
        Number of input features (channels)
    hidden_features : int, optional
        Number of hidden features, by default None (same as in_features)
    out_features : int, optional
        Number of output features, by default None (same as in_features)
    act_layer : type, optional
        Callable returning the activation module, by default nn.ReLU
    output_bias : bool, optional
        Whether to use bias in output layer, by default False
    drop_rate : float, optional
        Dropout rate (nn.Dropout2d), by default 0.0 (no dropout)
    checkpointing : bool, optional
        Whether to use gradient checkpointing, by default False
    gain : float, optional
        Gain factor for the output layer initialization (std = sqrt(gain / hidden_features)),
        by default 1.0
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, output_bias=False, drop_rate=0.0, checkpointing=False, gain=1.0):
        super(MLP, self).__init__()
        self.checkpointing = checkpointing
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # first dense layer
        fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True)
        # initialize the weights correctly: normal with std sqrt(2 / fan_in)
        scale = math.sqrt(2.0 / in_features)
        nn.init.normal_(fc1.weight, mean=0.0, std=scale)
        if fc1.bias is not None:
            nn.init.constant_(fc1.bias, 0.0)

        # activation
        act = act_layer()

        # output layer
        fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias)
        # gain factor for the output determines the scaling of the output init
        scale = math.sqrt(gain / hidden_features)
        nn.init.normal_(fc2.weight, mean=0.0, std=scale)
        if fc2.bias is not None:
            nn.init.constant_(fc2.bias, 0.0)

        if drop_rate > 0.0:
            # the same (stateless) dropout module is applied at both positions
            drop = nn.Dropout2d(drop_rate)
            self.fwd = nn.Sequential(fc1, act, drop, fc2, drop)
        else:
            self.fwd = nn.Sequential(fc1, act, fc2)

    @torch.jit.ignore
    def checkpoint_forward(self, x):
        """
        Forward pass with gradient checkpointing.

        The non-reentrant checkpoint variant is used: the reentrant one produces
        no parameter gradients when ``x`` does not require grad, and recent
        PyTorch versions require ``use_reentrant`` to be passed explicitly.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor
        """
        return checkpoint(self.fwd, x, use_reentrant=False)

    def forward(self, x):
        """
        Forward pass of the MLP.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor
        """
        if self.checkpointing:
            return self.checkpoint_forward(x)
        return self.fwd(x)

Boris Bonev's avatar
Boris Bonev committed
356

Boris Bonev's avatar
Boris Bonev committed
357
358
class RealFFT2(nn.Module):
    """
    Helper routine to wrap FFT similarly to the SHT.

    This module provides a wrapper around PyTorch's real FFT2D that mimics
    the interface of spherical harmonic transforms for consistency. Along the
    latitude axis, the output keeps the ``ceil(lmax / 2)`` lowest non-negative
    frequencies followed by the ``floor(lmax / 2)`` highest-index (negative)
    frequencies. Along the longitude axis it keeps the first ``mmax`` modes.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    lmax : int, optional
        Number of retained latitudinal modes, by default None (same as nlat)
    mmax : int, optional
        Number of retained longitudinal modes, by default None (nlon//2 + 1)
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(RealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        """
        Forward pass: compute real FFT2D.

        Parameters
        ----------
        x : torch.Tensor
            Real input tensor; the FFT is taken over the last two dimensions

        Returns
        -------
        torch.Tensor
            Complex FFT coefficients with ``lmax`` latitudinal rows (positive
            frequencies first, then negative ones) and ``mmax`` longitudinal columns
        """
        y = torch.fft.rfft2(x, dim=(-2, -1), norm="ortho")
        n_pos = (self.lmax + 1) // 2  # == ceil(lmax / 2)
        n_neg = self.lmax // 2  # == floor(lmax / 2)
        # Compute the start index explicitly: slicing with ``-n_neg:`` would select
        # the WHOLE axis when n_neg == 0 (e.g. lmax == 1). Clamp at 0 to match the
        # previous behavior when more modes are requested than available.
        neg_start = max(y.shape[-2] - n_neg, 0)
        y = torch.cat((y[..., :n_pos, : self.mmax], y[..., neg_start:, : self.mmax]), dim=-2)
        return y

Boris Bonev's avatar
Boris Bonev committed
402

Boris Bonev's avatar
Boris Bonev committed
403
404
class InverseRealFFT2(nn.Module):
    """
    Helper routine to wrap inverse FFT similarly to the SHT.

    This module provides a wrapper around PyTorch's inverse real FFT2D that mimics
    the interface of inverse spherical harmonic transforms for consistency. A
    latitudinally truncated spectrum is expected in the layout produced by
    ``RealFFT2``: non-negative frequencies first, negative frequencies last.

    Parameters
    -----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    lmax : int, optional
        Number of retained latitudinal modes, by default None (same as nlat)
    mmax : int, optional
        Number of retained longitudinal modes, by default None (nlon//2 + 1)
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None):
        super(InverseRealFFT2, self).__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.lmax = lmax or self.nlat
        self.mmax = mmax or self.nlon // 2 + 1

    def forward(self, x):
        """
        Forward pass: compute inverse real FFT2D.

        Parameters
        ----------
        x : torch.Tensor
            Complex FFT coefficients over the last two dimensions

        Returns
        -------
        torch.Tensor
            Reconstructed real signal of spatial shape (nlat, nlon)
        """
        nfreq = x.shape[-2]
        if nfreq < self.nlat:
            # irfft2 zero-pads at the END of each axis, which would reinterpret the
            # trailing negative frequencies as high positive ones. Re-insert the
            # missing high frequencies as zeros between the two halves instead.
            n_pos = (nfreq + 1) // 2
            pad_shape = list(x.shape)
            pad_shape[-2] = self.nlat - nfreq
            zeros = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
            x = torch.cat((x[..., :n_pos, :], zeros, x[..., n_pos:, :]), dim=-2)
        return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
Boris Bonev's avatar
Boris Bonev committed
445

Boris Bonev's avatar
Boris Bonev committed
446
447
448

class LayerNorm(nn.Module):
    """
apaaris's avatar
apaaris committed
449
450
451
452
453
454
455
    Wrapper class that moves the channel dimension to the end.
    
    This module provides a layer normalization that works with channel-first
    tensors by temporarily transposing the channel dimension to the end,
    applying normalization, and then transposing back.
    
    Parameters
456
    ----------
apaaris's avatar
apaaris committed
457
458
459
460
461
462
463
464
465
466
467
468
    in_channels : int
        Number of input channels
    eps : float, optional
        Epsilon for numerical stability, by default 1e-05
    elementwise_affine : bool, optional
        Whether to use learnable affine parameters, by default True
    bias : bool, optional
        Whether to use bias, by default True
    device : torch.device, optional
        Device to place the module on, by default None
    dtype : torch.dtype, optional
        Data type for the module, by default None
Boris Bonev's avatar
Boris Bonev committed
469
    """
apaaris's avatar
apaaris committed
470
    
Boris Bonev's avatar
Boris Bonev committed
471
472
473
474
475
476
477
478
    def __init__(self, in_channels, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None):
        super().__init__()

        self.channel_dim = -3

        self.norm = nn.LayerNorm(normalized_shape=in_channels, eps=1e-6, elementwise_affine=elementwise_affine, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
apaaris's avatar
apaaris committed
479
        """
apaaris's avatar
apaaris committed
480
        Forward pass with channel dimension handling.
apaaris's avatar
apaaris committed
481
482
        
        Parameters
483
        ----------
apaaris's avatar
apaaris committed
484
        x : torch.Tensor
apaaris's avatar
apaaris committed
485
            Input tensor with channel dimension at -3
apaaris's avatar
apaaris committed
486
487
488
489
            
        Returns
        -------
        torch.Tensor
apaaris's avatar
apaaris committed
490
            Normalized tensor with same shape as input
apaaris's avatar
apaaris committed
491
        """
Boris Bonev's avatar
Boris Bonev committed
492
493
494
        return self.norm(x.transpose(self.channel_dim, -1)).transpose(-1, self.channel_dim)


Boris Bonev's avatar
Boris Bonev committed
495
496
class SpectralConvS2(nn.Module):
    """
    Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
    using the Spherical Harmonic Transforms in torch-harmonics, but supports convolutions on the periodic
    domain via the RealFFT2 and InverseRealFFT2 wrappers.

    The transforms must expose ``nlat``/``nlon``; the inverse transform must also
    expose ``lmax``/``mmax``, which set the number of spectral modes.

    Parameters
    ----------
    forward_transform : nn.Module
        Forward transform (SHT or FFT)
    inverse_transform : nn.Module
        Inverse transform (ISHT or IFFT)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    gain : float, optional
        Gain factor for weight initialization (std = sqrt(gain / in_channels)), by default 2.0
    operator_type : str, optional
        Type of spectral operator ("driscoll-healy", "diagonal", "block-diagonal"), by default "driscoll-healy"
    lr_scale_exponent : int, optional
        Learning rate scaling exponent, by default 0 (accepted but not used by this class)
    bias : bool, optional
        Whether to use bias, by default False

    Raises
    ------
    NotImplementedError
        If ``operator_type`` is not one of the supported values
    """

    def __init__(self, forward_transform, inverse_transform, in_channels, out_channels, gain=2.0, operator_type="driscoll-healy", lr_scale_exponent=0, bias=False):
        super().__init__()

        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform

        self.modes_lat = self.inverse_transform.lmax
        self.modes_lon = self.inverse_transform.mmax

        # if input and output grids differ, the residual must be resampled through the transforms
        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or (self.forward_transform.nlon != self.inverse_transform.nlon)

        # remember factorization details
        self.operator_type = operator_type

        weight_shape = [out_channels, in_channels]

        if self.operator_type == "diagonal":
            # independent weight per (l, m) mode
            weight_shape += [self.modes_lat, self.modes_lon]
            self.contract_func = "...ilm,oilm->...olm"
        elif self.operator_type == "block-diagonal":
            # per degree l, a dense (m -> n) mixing matrix
            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
            self.contract_func = "...ilm,oilnm->...oln"
        elif self.operator_type == "driscoll-healy":
            # weight depends on the degree l only
            weight_shape += [self.modes_lat]
            self.contract_func = "...ilm,oil->...olm"
        else:
            raise NotImplementedError(f"Unknown operator type {self.operator_type}")

        # form weight tensors
        scale = math.sqrt(gain / in_channels)
        self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
        if bias:
            self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, x):
        """
        Forward pass of spectral convolution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        tuple
            Tuple containing (output, residual) tensors. The output is cast back to
            the input dtype. The residual is the float32 input, or its re-synthesis
            on the output grid when input and output grids differ.
        """
        dtype = x.dtype
        x = x.float()
        residual = x

        # transforms run in full precision
        with torch.autocast(device_type="cuda", enabled=False):
            x = self.forward_transform(x)
            if self.scale_residual:
                residual = self.inverse_transform(x)

        x = torch.einsum(self.contract_func, x, self.weight)

        with torch.autocast(device_type="cuda", enabled=False):
            x = self.inverse_transform(x)

        if hasattr(self, "bias"):
            x = x + self.bias
        x = x.type(dtype)

        return x, residual
Boris Bonev's avatar
Boris Bonev committed
591

Boris Bonev's avatar
Boris Bonev committed
592
593
594

class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for position embeddings.

    Stores the grid shape and channel count. ``forward`` adds
    ``self.position_embeddings`` to its input; that attribute is not created
    here and must be provided (e.g. registered as a buffer) by subclasses.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular" (accepted for subclasses; not stored here)
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__()

        self.num_chans = num_chans
        self.img_shape = img_shape

    def forward(self, x: torch.Tensor):
        """
        Add the position embeddings to the input.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor, broadcast-compatible with ``position_embeddings``

        Returns
        -------
        torch.Tensor
            ``x`` plus the position embeddings
        """
        embeddings = self.position_embeddings
        return x + embeddings
Boris Bonev's avatar
Boris Bonev committed
631
632
633
634


class SequencePositionEmbedding(PositionEmbedding):
    """
    Standard sequence-based position embedding.

    Grid points are numbered in row-major order (``pos = row * width + col``),
    and channel ``k`` holds a sinusoid of that index:

        sin(pos / 10000^(2k / num_chans))   for even k
        cos(pos / 10000^(2k / num_chans))   for odd k

    This follows the spirit of the original Transformer embeddings, applied to
    the flattened 2D grid. Unlike the paper's formula, every channel gets its
    own exponent ``2k / num_chans`` instead of sharing one per sine/cosine
    pair. The result is stored as a non-trainable float32 buffer named
    ``position_embeddings`` with shape (1, num_chans, height, width).

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        num_points = self.img_shape[0] * self.img_shape[1]

        with torch.no_grad():
            # flattened grid index, replicated once per channel: (1, num_chans, H, W)
            positions = torch.arange(num_points).reshape(1, 1, *self.img_shape).repeat(1, self.num_chans, 1, 1)

            # channel index, broadcast over the spatial dims: (1, num_chans, 1, 1)
            chan = torch.arange(self.num_chans).reshape(1, self.num_chans, 1, 1)

            # per-channel scale 10000^(2k / num_chans)
            scale = torch.pow(10000, 2 * chan / self.num_chans)
            phase = positions / scale

            # even channels take sine, odd channels take cosine
            embedding = torch.where(chan % 2 == 0, torch.sin(phase), torch.cos(phase))

        # fixed (non-learnable) embedding
        self.register_buffer("position_embeddings", embedding.float())


class SpectralPositionEmbedding(PositionEmbedding):
    """
    Spectral position embeddings for spherical transformers.

    Each channel is a single spherical-harmonic basis function, synthesised on
    the spatial grid with an inverse real SHT. Channels are assigned to
    (degree l, order m) pairs in the sequence l = 0, 1, 2, ..., with m running
    from -l to l for each l:

        l = floor(sqrt(i)),   m = i - l * (l + 1)   for channel i.

    A non-negative order sets a real unit coefficient at (l, m). A negative
    order sets an imaginary unit coefficient at (l, |m|). Each channel is then
    divided by its peak absolute value over the grid, and the result is
    registered as the buffer ``position_embeddings``.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # the highest degree needed is floor(sqrt(num_chans - 1)), which is below this bound
        band_limit = math.floor(math.sqrt(self.num_chans)) + 1
        isht = InverseRealSHT(nlat=self.img_shape[0], nlon=self.img_shape[1], lmax=band_limit, mmax=band_limit, grid=grid)

        # one unit coefficient per channel in the spectral domain
        with torch.no_grad():
            coeffs = torch.zeros(1, self.num_chans, isht.lmax, isht.mmax, dtype=torch.complex64)

            for chan in range(self.num_chans):
                degree = math.floor(math.sqrt(chan))
                order = chan - degree * (degree + 1)

                if order >= 0:
                    coeffs[0, chan, degree, order] = 1.0
                else:
                    # negative orders are encoded as an imaginary coefficient at |m|
                    coeffs[0, chan, degree, -order] = 1.0j

        # synthesise the basis functions on the spatial grid
        spatial = isht(coeffs)

        # scale each channel so that its peak magnitude is one
        peak = torch.amax(spatial.abs(), dim=(-1, -2), keepdim=True)

        self.register_buffer("position_embeddings", spatial / peak)


class LearnablePositionEmbedding(PositionEmbedding):
    """
    Learnable position embeddings for spherical transformers.

    The embedding is a trainable ``nn.Parameter`` initialised to zeros and
    stored as ``position_embeddings``. It can take one of two forms:

    * ``"latlon"`` - one value per channel and grid point, with shape
      (1, num_chans, height, width);
    * ``"lat"`` - one value per channel and row, with shape
      (1, num_chans, height, 1). It broadcasts along the last axis when
      added in ``forward``.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of channels, by default 1
    embed_type : str, optional
        Embedding type ("lat" or "latlon"), by default "lat"

    Raises
    ------
    ValueError
        If ``embed_type`` is neither "latlon" nor "lat".
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # only the extent of the last axis differs between the two variants
        if embed_type == "latlon":
            last_extent = self.img_shape[1]
        elif embed_type == "lat":
            last_extent = 1
        else:
            raise ValueError(f"Unknown learnable position embedding type {embed_type}")

        initial = torch.zeros(1, self.num_chans, self.img_shape[0], last_extent)
        self.position_embeddings = nn.Parameter(initial)

# class SpiralPositionEmbedding(PositionEmbedding):
#     """
#     Returns position embeddings on the torus
#     """

#     def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):

#         super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

#         with torch.no_grad():

#             # alternating custom position embeddings
#             lats, _ = _precompute_latitudes(img_shape[0], grid=grid)
#             lats = lats.reshape(-1, 1)
#             lons = torch.linspace(0, 2 * math.pi, img_shape[1] + 1)[:-1]
#             lons = lons.reshape(1, -1)

#             # channel index
#             k = torch.arange(self.num_chans).reshape(1, -1, 1, 1)
#             pos_embed = torch.where(k % 2 == 0, torch.sin(k * (lons + lats)), torch.cos(k * (lons - lats)))

#         # register tensor
#         self.register_buffer("position_embeddings", pos_embed.float())