lsno.py 22.8 KB
Newer Older
1
2
# coding=utf-8

Boris Bonev's avatar
Boris Bonev committed
3
# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
33
import math
from functools import partial

import torch
import torch.nn as nn
import torch.amp as amp

from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import ResampleS2

from torch_harmonics.examples.models._layers import MLP, DropPath, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding

Boris Bonev's avatar
Boris Bonev committed
46
47
48
49
50
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
    theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}

    return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
51
52

class DiscreteContinuousEncoder(nn.Module):
    r"""
    Discrete-continuous (DISCO) encoder for spherical neural operators.

    Downsamples the signal on the sphere with a single discrete-continuous
    convolution, reducing the spatial resolution while respecting the
    spherical geometry of the data.

    Parameters
    ----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (721, 1440)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (480, 960)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for the convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use a bias term, by default False
    """

    def __init__(
        self,
        in_shape=(721, 1440),
        out_shape=(480, 960),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
        kernel_shape=(3, 3),
        basis_type="morlet",
        groups=1,
        bias=False,
    ):
        super().__init__()

        # cutoff radius is derived from the (finer) input grid resolution
        theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)

        # a single DISCO convolution performs the downsampling
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
            in_shape=in_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid_in,
            grid_out=grid_out,
            groups=groups,
            bias=bias,
            theta_cutoff=theta_cutoff,
        )

    def forward(self, x):
        """
        Forward pass of the discrete-continuous encoder.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Encoded tensor at the reduced spatial resolution
        """
        input_dtype = x.dtype

        # run the convolution in float32 regardless of any active autocast region
        with amp.autocast(device_type="cuda", enabled=False):
            out = self.conv(x.float())
            out = out.to(dtype=input_dtype)

        return out


class DiscreteContinuousDecoder(nn.Module):
    r"""
    Discrete-continuous (DISCO) decoder for spherical neural operators.

    Upsamples the signal on the sphere either spectrally (SHT followed by an
    inverse SHT on the finer grid) or via direct resampling, then applies a
    discrete-continuous convolution at the restored spatial resolution.

    Parameters
    ----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (480, 960)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (721, 1440)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for the convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use a bias term, by default False
    upsample_sht : bool, optional
        Whether to use SHT for upsampling, by default False
    """

    def __init__(
        self,
        in_shape=(480, 960),
        out_shape=(721, 1440),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
        kernel_shape=(3, 3),
        basis_type="morlet",
        groups=1,
        bias=False,
        upsample_sht=False,
    ):
        super().__init__()

        # set up upsampling: spectral (SHT -> inverse SHT) or direct resampling
        if upsample_sht:
            self.sht = RealSHT(*in_shape, grid=grid_in).float()
            self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
            self.upsample = nn.Sequential(self.sht, self.isht)
        else:
            self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)

        # set up DISCO convolution on the (finer) output grid; the cutoff radius is
        # still derived from the coarser input grid, matching the information content
        # of the upsampled signal
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
            in_shape=out_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid_out,
            grid_out=grid_out,
            groups=groups,
            # bugfix: was hard-coded to False, silently ignoring the `bias` argument
            # (the encoder honors it); default bias=False preserves prior behavior
            bias=bias,
            theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
        )

    def forward(self, x):
        """
        Forward pass of the discrete-continuous decoder.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Decoded tensor with restored spatial resolution
        """
        dtype = x.dtype

        # upsampling and convolution are evaluated in float32 regardless of autocast
        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
            x = self.upsample(x)
            x = self.conv(x)
            x = x.to(dtype=dtype)

        return x


class SphericalNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    Parameters
    ----------
    forward_transform : torch.nn.Module
        Forward transform to use for the block
    inverse_transform : torch.nn.Module
        Inverse transform to use for the block
    input_dim : int
        Input dimension
    output_dim : int
        Output dimension
    conv_type : str, optional
        Type of convolution to use ("local" or "global"), by default "local"
    mlp_ratio : float, optional
        MLP expansion ratio, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    act_layer : torch.nn.Module, optional
        Activation function to use, by default nn.GELU
    norm_layer : str, optional
        Type of normalization to use ("layer_norm", "instance_norm", "none"), by default "none"
    inner_skip : str, optional
        Type of inner skip connection ("linear", "identity", "none"), by default "none"
    outer_skip : str, optional
        Type of outer skip connection ("linear", "identity", "none"), by default "identity"
    use_mlp : bool, optional
        Whether to use MLP layers, by default True
    disco_kernel_shape : tuple, optional
        Kernel shape for discrete-continuous convolution, by default (3, 3)
    disco_basis_type : str, optional
        Filter basis type for discrete-continuous convolution, by default "morlet"
    bias : bool, optional
        Whether to use bias, by default False
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        conv_type="local",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="none",
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
        disco_kernel_shape=(3, 3),
        disco_basis_type="morlet",
        bias=False,
    ):
        super().__init__()

        # initialization gain: 2.0 compensates the variance reduction of a
        # ReLU/GELU-style nonlinearity; the identity activation needs none
        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        # with an inner skip present, the conv branch and the skip branch
        # each carry half the variance
        if inner_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        # convolution layer: local (DISCO) or global (spectral)
        if conv_type == "local":
            # double the heuristic cutoff radius for block-internal convolutions
            theta_cutoff = 2.0 * _compute_cutoff_radius(forward_transform.nlat, disco_kernel_shape, disco_basis_type)
            self.local_conv = DiscreteContinuousConvS2(
                input_dim,
                output_dim,
                in_shape=(forward_transform.nlat, forward_transform.nlon),
                out_shape=(inverse_transform.nlat, inverse_transform.nlon),
                kernel_shape=disco_kernel_shape,
                basis_type=disco_basis_type,
                grid_in=forward_transform.grid,
                grid_out=inverse_transform.grid,
                bias=bias,
                theta_cutoff=theta_cutoff,
            )
        elif conv_type == "global":
            self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
        else:
            raise ValueError(f"Unknown convolution type {conv_type}")

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif norm_layer == "none":
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")

        # stochastic depth
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        gain_factor = 1.0
        # bugfix: this condition previously read `outer_skip == "linear" or
        # inner_skip == "identity"` — an apparent copy-paste slip from the
        # inner-skip branch above; the OUTER skip determines whether the MLP
        # branch shares variance with a skip branch here
        if outer_skip == "linear" or outer_skip == "identity":
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim,
                out_features=input_dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop_rate=drop_rate,
                checkpointing=False,
                gain=gain_factor,
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

    def forward(self, x):
        """
        Forward pass of the block: convolution, normalization, optional inner
        skip, optional MLP, drop-path, optional outer skip.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)

        Returns
        -------
        torch.Tensor
            Output tensor
        """
        residual = x

        # exactly one of these is set, depending on conv_type at construction
        if hasattr(self, "global_conv"):
            x, _ = self.global_conv(x)
        elif hasattr(self, "local_conv"):
            x = self.local_conv(x)

        x = self.norm(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x


Boris Bonev's avatar
Boris Bonev committed
402
class LocalSphericalNeuralOperator(nn.Module):
    r"""
    LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
    operators to accurately model both types of solution operators [1]. The architecture is based on the Spherical
    Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
    as well as in the encoder and decoders.

    Parameters
    ----------
    img_size : tuple, optional
        Input image size (nlat, nlon), by default (128, 256)
    grid : str, optional
        Grid type for input/output, by default "equiangular"
    grid_internal : str, optional
        Grid type for internal processing, by default "legendre-gauss"
    scale_factor : int, optional
        Scale factor for resolution changes, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Embedding dimension, by default 256
    num_layers : int, optional
        Number of layers, by default 4
    activation_function : str, optional
        Activation function name ("relu", "gelu", "identity"), by default "gelu"
    kernel_shape : tuple, optional
        Kernel shape for convolutions, by default (3, 3)
    encoder_kernel_shape : tuple, optional
        Kernel shape for encoder and decoder, by default (3, 3)
    filter_basis_type : str, optional
        Filter basis type, by default "morlet"
    use_mlp : bool, optional
        Whether to use MLP layers, by default True
    mlp_ratio : float, optional
        MLP expansion ratio, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Drop path rate, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    sfno_block_frequency : int, optional
        Every `sfno_block_frequency`-th block uses a global (spectral) convolution, by default 2
    hard_thresholding_fraction : float, optional
        Hard thresholding fraction, by default 1.0
    residual_prediction : bool, optional
        Whether to use residual prediction, by default False
    pos_embed : str, optional
        Position embedding type, by default "none"
    upsample_sht : bool, optional
        Use SHT upsampling in the decoder if true, else spatial resampling
    bias : bool, optional
        Whether to use a bias, by default False

    Example
    ----------
    >>> model = LocalSphericalNeuralOperator(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    ----------
    .. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
        "Neural Operators with Localized Integral and Differential Kernels" (2024).
        ICML 2024, https://arxiv.org/pdf/2402.16845.

    .. [2] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        kernel_shape=(3, 3),
        encoder_kernel_shape=(3, 3),
        filter_basis_type="morlet",
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        sfno_block_frequency=2,
        hard_thresholding_fraction=1.0,
        residual_prediction=False,
        pos_embed="none",
        upsample_sht=False,
        bias=False,
    ):
        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.encoder_kernel_shape = encoder_kernel_shape
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.residual_prediction = residual_prediction

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # position embedding (nn.Identity when disabled)
        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
        else:
            raise ValueError(f"Unknown position embedding type {pos_embed}")

        # encoder: downsamples from the external grid to the internal grid
        self.encoder = DiscreteContinuousEncoder(
            in_shape=self.img_size,
            out_shape=(self.h, self.w),
            grid_in=grid,
            grid_out=grid_internal,
            inp_chans=self.in_chans,
            out_chans=self.embed_dim,
            kernel_shape=self.encoder_kernel_shape,
            basis_type=filter_basis_type,
            groups=1,
            bias=False,
        )

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we substract one mode here
        modes_lon = (self.w // 2 + 1) - 1
        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)

        # shared transforms for the spectral (global) convolutions
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        # operator blocks; every sfno_block_frequency-th block is global (spectral)
        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            block = SphericalNeuralOperatorBlock(
                self.trans,
                self.itrans,
                self.embed_dim,
                self.embed_dim,
                conv_type="global" if i % sfno_block_frequency == (sfno_block_frequency - 1) else "local",
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=self.normalization_layer,
                use_mlp=use_mlp,
                disco_kernel_shape=kernel_shape,
                disco_basis_type=filter_basis_type,
                bias=bias,
            )

            self.blocks.append(block)

        # decoder: upsamples from the internal grid back to the external grid
        self.decoder = DiscreteContinuousDecoder(
            in_shape=(self.h, self.w),
            out_shape=self.img_size,
            grid_in=grid_internal,
            grid_out=grid,
            inp_chans=self.embed_dim,
            out_chans=self.out_chans,
            kernel_shape=self.encoder_kernel_shape,
            basis_type=filter_basis_type,
            groups=1,
            bias=False,
            upsample_sht=upsample_sht,
        )

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay by compatible optimizers."""
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        """Apply positional dropout followed by all operator blocks."""
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        """
        Forward pass through the complete LSNO model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, out_chans, height, width)
        """
        if self.residual_prediction:
            residual = x

        x = self.encoder(x)

        # pos_embed is nn.Identity when disabled, so it is applied unconditionally
        # (the previous `is not None` guard was always true)
        x = self.pos_embed(x)

        x = self.forward_features(x)

        x = self.decoder(x)

        if self.residual_prediction:
            x = x + residual

        return x