# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math

34
35
36
37
38
39
import torch
import torch.nn as nn
import torch.amp as amp

from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
40
from torch_harmonics import ResampleS2
41

Boris Bonev's avatar
Boris Bonev committed
42
from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
43
44
45

from functools import partial

Boris Bonev's avatar
Boris Bonev committed
46
47
48
49
50
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
    theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}

    return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
51
52

class DiscreteContinuousEncoder(nn.Module):
apaaris's avatar
apaaris committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    r"""
    Discrete-continuous encoder for spherical neural operators.
    
    This module performs downsampling using discrete-continuous convolutions on the sphere,
    reducing the spatial resolution while maintaining the spectral properties of the data.
    
    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (721, 1440)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (480, 960)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    """
82
83
    def __init__(
        self,
84
        in_shape=(721, 1440),
85
86
87
88
89
        out_shape=(480, 960),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
Boris Bonev's avatar
Boris Bonev committed
90
91
        kernel_shape=(3, 3),
        basis_type="morlet",
92
93
94
95
96
97
98
99
100
        groups=1,
        bias=False,
    ):
        super().__init__()

        # set up local convolution
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
101
            in_shape=in_shape,
102
103
            out_shape=out_shape,
            kernel_shape=kernel_shape,
104
            basis_type=basis_type,
105
106
107
108
            grid_in=grid_in,
            grid_out=grid_out,
            groups=groups,
            bias=bias,
Boris Bonev's avatar
Boris Bonev committed
109
            theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
110
111
112
        )

    def forward(self, x):
apaaris's avatar
apaaris committed
113
114
115
116
117
118
119
120
121
122
123
124
125
        """
        Forward pass of the discrete-continuous encoder.
        
        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)
            
        Returns
        -------
        torch.Tensor
            Encoded tensor with reduced spatial resolution
        """
126
127
128
129
130
131
132
133
134
135
136
        dtype = x.dtype

        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
            x = self.conv(x)
            x = x.to(dtype=dtype)

        return x


class DiscreteContinuousDecoder(nn.Module):
apaaris's avatar
apaaris committed
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
    r"""
    Discrete-continuous decoder for spherical neural operators.
    
    This module performs upsampling using either spherical harmonic transforms or resampling,
    followed by discrete-continuous convolutions to restore spatial resolution.
    
    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (480, 960)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (721, 1440)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    upsample_sht : bool, optional
        Whether to use SHT for upsampling, by default False
    """
168
169
    def __init__(
        self,
170
        in_shape=(480, 960),
171
172
173
174
175
        out_shape=(721, 1440),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
Boris Bonev's avatar
Boris Bonev committed
176
177
        kernel_shape=(3, 3),
        basis_type="morlet",
178
179
        groups=1,
        bias=False,
Boris Bonev's avatar
Boris Bonev committed
180
        upsample_sht=False,
181
182
183
    ):
        super().__init__()

184
185
186
187
188
189
190
        # set up upsampling
        if upsample_sht:
            self.sht = RealSHT(*in_shape, grid=grid_in).float()
            self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
            self.upsample = nn.Sequential(self.sht, self.isht)
        else:
            self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
191
192

        # set up DISCO convolution
193
        self.conv = DiscreteContinuousConvS2(
194
195
196
197
198
            inp_chans,
            out_chans,
            in_shape=out_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
199
            basis_type=basis_type,
200
201
202
203
            grid_in=grid_out,
            grid_out=grid_out,
            groups=groups,
            bias=False,
Boris Bonev's avatar
Boris Bonev committed
204
            theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
205
206
207
        )

    def forward(self, x):
apaaris's avatar
apaaris committed
208
209
210
211
212
213
214
215
216
217
218
219
220
        """
        Forward pass of the discrete-continuous decoder.
        
        Parameters
        -----------
        x : torch.Tensor
            Input tensor with shape (batch, channels, nlat, nlon)
            
        Returns
        -------
        torch.Tensor
            Decoded tensor with restored spatial resolution
        """
221
222
223
224
        dtype = x.dtype

        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
225
            x = self.upsample(x)
226
            x = self.conv(x)
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
            x = x.to(dtype=dtype)

        return x


class SphericalNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        conv_type="local",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
247
        act_layer=nn.GELU,
Boris Bonev's avatar
Boris Bonev committed
248
        norm_layer="none",
249
250
        inner_skip="none",
        outer_skip="identity",
251
        use_mlp=True,
Boris Bonev's avatar
Boris Bonev committed
252
253
254
        disco_kernel_shape=(3, 3),
        disco_basis_type="morlet",
        bias=False,
255
256
257
258
259
260
261
262
263
264
265
266
267
    ):
        super().__init__()

        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        if inner_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        # convolution layer
        if conv_type == "local":
Boris Bonev's avatar
Boris Bonev committed
268
            theta_cutoff = 2.0 * _compute_cutoff_radius(forward_transform.nlat, disco_kernel_shape, disco_basis_type)
269
270
271
272
273
274
            self.local_conv = DiscreteContinuousConvS2(
                input_dim,
                output_dim,
                in_shape=(forward_transform.nlat, forward_transform.nlon),
                out_shape=(inverse_transform.nlat, inverse_transform.nlon),
                kernel_shape=disco_kernel_shape,
275
                basis_type=disco_basis_type,
276
277
                grid_in=forward_transform.grid,
                grid_out=inverse_transform.grid,
Boris Bonev's avatar
Boris Bonev committed
278
279
                bias=bias,
                theta_cutoff=theta_cutoff,
280
281
            )
        elif conv_type == "global":
Boris Bonev's avatar
Boris Bonev committed
282
            self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
283
284
285
286
287
288
289
290
291
292
293
294
295
296
        else:
            raise ValueError(f"Unknown convolution type {conv_type}")

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

Boris Bonev's avatar
Boris Bonev committed
297
298
299
300
301
302
303
304
305
        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif norm_layer == "none":
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        gain_factor = 1.0
        if outer_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        if use_mlp == True:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim,
                out_features=input_dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop_rate=drop_rate,
                checkpointing=False,
                gain=gain_factor,
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

    def forward(self, x):

        residual = x

        if hasattr(self, "global_conv"):
            x, _ = self.global_conv(x)
        elif hasattr(self, "local_conv"):
            x = self.local_conv(x)

Boris Bonev's avatar
Boris Bonev committed
346
        x = self.norm(x)
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x


Boris Bonev's avatar
Boris Bonev committed
362
class LocalSphericalNeuralOperator(nn.Module):
apaaris's avatar
apaaris committed
363
    r"""
364
365
366
367
    LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
    operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical
    Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
    as well as in the encoder and decoders.
apaaris's avatar
apaaris committed
368
    
369
    Parameters
370
    -----------
apaaris's avatar
apaaris committed
371
372
373
374
375
376
    img_size : tuple, optional
        Input image size (nlat, nlon), by default (128, 256)
    grid : str, optional
        Grid type for input/output, by default "equiangular"
    grid_internal : str, optional
        Grid type for internal processing, by default "legendre-gauss"
377
    scale_factor : int, optional
apaaris's avatar
apaaris committed
378
        Scale factor for resolution changes, by default 3
379
380
381
382
383
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
apaaris's avatar
apaaris committed
384
        Embedding dimension, by default 256
385
    num_layers : int, optional
apaaris's avatar
apaaris committed
386
        Number of layers, by default 4
387
    activation_function : str, optional
apaaris's avatar
apaaris committed
388
389
390
391
392
393
394
395
396
397
398
        Activation function name, by default "gelu"
    kernel_shape : tuple, optional
        Kernel shape for convolutions, by default (3, 3)
    encoder_kernel_shape : tuple, optional
        Kernel shape for encoder, by default (3, 3)
    filter_basis_type : str, optional
        Filter basis type, by default "morlet"
    use_mlp : bool, optional
        Whether to use MLP layers, by default True
    mlp_ratio : float, optional
        MLP expansion ratio, by default 2.0
399
400
401
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
apaaris's avatar
apaaris committed
402
        Drop path rate, by default 0.0
403
404
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
Boris Bonev's avatar
Boris Bonev committed
405
    sfno_block_frequency : int, optional
apaaris's avatar
apaaris committed
406
        Frequency of SFNO blocks, by default 2
407
    hard_thresholding_fraction : float, optional
apaaris's avatar
apaaris committed
408
409
410
411
412
        Hard thresholding fraction, by default 1.0
    residual_prediction : bool, optional
        Whether to use residual prediction, by default False
    pos_embed : str, optional
        Position embedding type, by default "none"
413
414
    upsample_sht : bool, optional
        Use SHT upsampling if true, else linear interpolation
Boris Bonev's avatar
Boris Bonev committed
415
416
    bias : bool, optional
        Whether to use a bias, by default False
417

418
419
    Example
    -----------
420
    >>> model = LocalSphericalNeuralOperator(
421
422
423
424
425
426
427
428
429
    ...         img_shape=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
430
431
432
433
434
435
436
437
438
439
440

    References
    -----------
    .. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
        "Neural Operators with Localized Integral and Differential Kernels" (2024).
        ICML 2024, https://arxiv.org/pdf/2402.16845.

    .. [2] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.

441
442
443
444
445
446
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
447
        grid_internal="legendre-gauss",
448
        scale_factor=3,
449
450
451
452
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
453
        activation_function="gelu",
Boris Bonev's avatar
Boris Bonev committed
454
455
456
        kernel_shape=(3, 3),
        encoder_kernel_shape=(3, 3),
        filter_basis_type="morlet",
457
458
459
460
461
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
Boris Bonev's avatar
Boris Bonev committed
462
        sfno_block_frequency=2,
463
        hard_thresholding_fraction=1.0,
Boris Bonev's avatar
Boris Bonev committed
464
465
        residual_prediction=False,
        pos_embed="none",
466
        upsample_sht=False,
Boris Bonev's avatar
Boris Bonev committed
467
        bias=False,
468
469
470
471
472
    ):
        super().__init__()

        self.img_size = img_size
        self.grid = grid
473
        self.grid_internal = grid_internal
474
475
476
477
478
479
480
481
482
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.encoder_kernel_shape = encoder_kernel_shape
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
Boris Bonev's avatar
Boris Bonev committed
483
        self.residual_prediction = residual_prediction
484
485
486
487
488
489
490
491
492
493
494
495

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

496
        # compute downsampled image size. We assume that the latitude-grid includes both poles
497
        self.h = (self.img_size[0] - 1) // scale_factor + 1
498
499
500
501
502
503
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

Boris Bonev's avatar
Boris Bonev committed
504
505
506
507
508
509
510
511
512
513
        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
514
        else:
Boris Bonev's avatar
Boris Bonev committed
515
            raise ValueError(f"Unknown position embedding type {pos_embed}")
516
517

        # encoder
Boris Bonev's avatar
Boris Bonev committed
518
519
520
        self.encoder = DiscreteContinuousEncoder(
            in_shape=self.img_size,
            out_shape=(self.h, self.w),
521
            grid_in=grid,
522
            grid_out=grid_internal,
Boris Bonev's avatar
Boris Bonev committed
523
524
525
526
527
            inp_chans=self.in_chans,
            out_chans=self.embed_dim,
            kernel_shape=self.encoder_kernel_shape,
            basis_type=filter_basis_type,
            groups=1,
528
529
530
            bias=False,
        )

531
532
533
534
535
536
        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we substract one mode here
        modes_lon = (self.w // 2 + 1) - 1

        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
537

538
539
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
540
541
542
543
544
545
546
547
548

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            block = SphericalNeuralOperatorBlock(
                self.trans,
                self.itrans,
                self.embed_dim,
                self.embed_dim,
Boris Bonev's avatar
Boris Bonev committed
549
                conv_type="global" if i % sfno_block_frequency == (sfno_block_frequency-1) else "local",
550
551
552
553
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
Boris Bonev's avatar
Boris Bonev committed
554
                norm_layer=self.normalization_layer,
555
556
                use_mlp=use_mlp,
                disco_kernel_shape=kernel_shape,
Boris Bonev's avatar
Boris Bonev committed
557
                disco_basis_type=filter_basis_type,
Boris Bonev's avatar
Boris Bonev committed
558
                bias=bias,
559
560
561
562
563
564
            )

            self.blocks.append(block)

        # decoder
        self.decoder = DiscreteContinuousDecoder(
565
            in_shape=(self.h, self.w),
566
            out_shape=self.img_size,
567
            grid_in=grid_internal,
568
569
570
571
            grid_out=grid,
            inp_chans=self.embed_dim,
            out_chans=self.out_chans,
            kernel_shape=self.encoder_kernel_shape,
Boris Bonev's avatar
Boris Bonev committed
572
            basis_type=filter_basis_type,
573
574
            groups=1,
            bias=False,
Boris Bonev's avatar
Boris Bonev committed
575
            upsample_sht=upsample_sht,
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
        )

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
apaaris's avatar
apaaris committed
591
592
593
594
595
596
597
598
599
600
601
602
603
        """
        Forward pass through the complete LSNO model.
        
        Parameters
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)
            
        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, out_chans, height, width)
        """
Boris Bonev's avatar
Boris Bonev committed
604
        if self.residual_prediction:
605
606
607
608
609
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
Boris Bonev's avatar
Boris Bonev committed
610
            x = self.pos_embed(x)
611
612
613
614
615

        x = self.forward_features(x)

        x = self.decoder(x)

Boris Bonev's avatar
Boris Bonev committed
616
        if self.residual_prediction:
617
618
619
            x = x + residual

        return x