# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch
import torch.nn as nn
import torch.amp as amp

from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import ResampleS2

from torch_harmonics.examples.models._layers import MLP, DropPath, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding

from functools import partial

# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
    theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}

    return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
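
# Worked example of the heuristic above: for an ERA5-style equiangular grid with nlat = 721,
# kernel_shape = (3, 3) and basis_type = "morlet", the cutoff radius is
# (3 + 1) * 0.5 * pi / 720 = pi / 360 ≈ 0.0087 rad, i.e. roughly two latitude grid spacings.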

class DiscreteContinuousEncoder(nn.Module):
    """
    Discrete-continuous encoder for spherical neural operators.
    
    This module performs downsampling using discrete-continuous convolutions on the sphere,
    reducing the spatial resolution while maintaining the spectral properties of the data.
    
    Parameters
    ----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (721, 1440)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (480, 960)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
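
    Example
    -------
    A minimal usage sketch with small, illustrative shapes (not the defaults):

    >>> encoder = DiscreteContinuousEncoder(
    ...         in_shape=(33, 64),
    ...         out_shape=(17, 32),
    ...         inp_chans=2,
    ...         out_chans=8,)
    >>> encoder(torch.randn(1, 2, 33, 64)).shape
    torch.Size([1, 8, 17, 32])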
    """
    def __init__(
        self,
        in_shape=(721, 1440),
        out_shape=(480, 960),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
        kernel_shape=(3, 3),
        basis_type="morlet",
        groups=1,
        bias=False,
    ):
        super().__init__()

        # set up local convolution
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
            in_shape=in_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid_in,
            grid_out=grid_out,
            groups=groups,
            bias=bias,
            theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
        )

    def forward(self, x):

        dtype = x.dtype

        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
            x = self.conv(x)
            x = x.to(dtype=dtype)

        return x


class DiscreteContinuousDecoder(nn.Module):
    """
    Discrete-continuous decoder for spherical neural operators.
    
    This module performs upsampling using either spherical harmonic transforms or resampling,
    followed by discrete-continuous convolutions to restore spatial resolution.
    
    Parameters
    ----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (480, 960)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (721, 1440)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    inp_chans : int, optional
        Number of input channels, by default 2
    out_chans : int, optional
        Number of output channels, by default 2
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    groups : int, optional
        Number of groups for grouped convolution, by default 1
    bias : bool, optional
        Whether to use bias, by default False
    upsample_sht : bool, optional
        Whether to use SHT for upsampling, by default False
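
    Example
    -------
    A minimal usage sketch with small, illustrative shapes (not the defaults):

    >>> decoder = DiscreteContinuousDecoder(
    ...         in_shape=(17, 32),
    ...         out_shape=(33, 64),
    ...         inp_chans=8,
    ...         out_chans=2,)
    >>> decoder(torch.randn(1, 8, 17, 32)).shape
    torch.Size([1, 2, 33, 64])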
    """
    def __init__(
        self,
        in_shape=(480, 960),
        out_shape=(721, 1440),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
        kernel_shape=(3, 3),
        basis_type="morlet",
        groups=1,
        bias=False,
        upsample_sht=False,
    ):
        super().__init__()

        # set up upsampling
        if upsample_sht:
            self.sht = RealSHT(*in_shape, grid=grid_in).float()
            self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
            self.upsample = nn.Sequential(self.sht, self.isht)
        else:
            self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)

        # set up DISCO convolution
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
            in_shape=out_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid_out,
            grid_out=grid_out,
            groups=groups,
            bias=bias,
            theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
        )

    def forward(self, x):

        dtype = x.dtype

        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
            x = self.upsample(x)
            x = self.conv(x)
            x = x.to(dtype=dtype)

        return x


class SphericalNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single Spherical Neural Operator block. Uses either local discrete-continuous (DISCO) convolutions or global spectral convolutions based on the SHT.

    Parameters
    ----------
    forward_transform : torch.nn.Module
        Forward transform to use for the block
    inverse_transform : torch.nn.Module
        Inverse transform to use for the block
    input_dim : int
        Input dimension
    output_dim : int
        Output dimension
    conv_type : str, optional
        Type of convolution to use, by default "local"
    mlp_ratio : float, optional
        MLP expansion ratio, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    act_layer : torch.nn.Module, optional
        Activation function to use, by default nn.GELU
    norm_layer : str, optional
        Type of normalization to use, by default "none"
    inner_skip : str, optional
        Type of inner skip connection to use, by default "none"
    outer_skip : str, optional
        Type of outer skip connection to use, by default "identity"
    use_mlp : bool, optional
        Whether to use MLP layers, by default True
    disco_kernel_shape : tuple, optional
        Kernel shape for discrete-continuous convolution, by default (3, 3)
    disco_basis_type : str, optional
        Filter basis type for discrete-continuous convolution, by default "morlet"
    bias : bool, optional
        Whether to use bias, by default False

    Returns
    -------
    torch.Tensor
        Output tensor
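
    Example
    -------
    A minimal usage sketch of a single global (spectral) block on a small Legendre-Gauss
    grid; the transforms, shapes and channel counts are illustrative choices:

    >>> sht = RealSHT(17, 32, grid="legendre-gauss").float()
    >>> isht = InverseRealSHT(17, 32, lmax=sht.lmax, mmax=sht.mmax, grid="legendre-gauss").float()
    >>> block = SphericalNeuralOperatorBlock(sht, isht, 16, 16, conv_type="global")
    >>> block(torch.randn(1, 16, 17, 32)).shape
    torch.Size([1, 16, 17, 32])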
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        conv_type="local",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="none",
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
        disco_kernel_shape=(3, 3),
        disco_basis_type="morlet",
        bias=False,
    ):
        super().__init__()

        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        if inner_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        # convolution layer
        if conv_type == "local":
            theta_cutoff = 2.0 * _compute_cutoff_radius(forward_transform.nlat, disco_kernel_shape, disco_basis_type)
            self.local_conv = DiscreteContinuousConvS2(
                input_dim,
                output_dim,
                in_shape=(forward_transform.nlat, forward_transform.nlon),
                out_shape=(inverse_transform.nlat, inverse_transform.nlon),
                kernel_shape=disco_kernel_shape,
                basis_type=disco_basis_type,
                grid_in=forward_transform.grid,
                grid_out=inverse_transform.grid,
                bias=bias,
                theta_cutoff=theta_cutoff,
            )
        elif conv_type == "global":
            self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
        else:
            raise ValueError(f"Unknown convolution type {conv_type}")

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif norm_layer == "none":
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        gain_factor = 1.0
        if outer_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim,
                out_features=input_dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop_rate=drop_rate,
                checkpointing=False,
                gain=gain_factor,
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

    def forward(self, x):

        residual = x

        if hasattr(self, "global_conv"):
            x, _ = self.global_conv(x)
        elif hasattr(self, "local_conv"):
            x = self.local_conv(x)

        x = self.norm(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x


class LocalSphericalNeuralOperator(nn.Module):
    """
    LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
    operators to accurately model both local and global effects in the solution operator [1]. The architecture is
    based on the Spherical Fourier Neural Operator [2] and improves upon it with local integral operators in the
    Neural Operator blocks as well as in the encoder and decoder.
    
    Parameters
    ----------
    img_size : tuple, optional
        Input image size (nlat, nlon), by default (128, 256)
    grid : str, optional
        Grid type for input/output, by default "equiangular"
    grid_internal : str, optional
        Grid type for internal processing, by default "legendre-gauss"
    scale_factor : int, optional
        Scale factor for resolution changes, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Embedding dimension, by default 256
    num_layers : int, optional
        Number of layers, by default 4
    activation_function : str, optional
        Activation function name, by default "gelu"
    kernel_shape : tuple, optional
        Kernel shape for convolutions, by default (3, 3)
    encoder_kernel_shape : tuple, optional
        Kernel shape for encoder, by default (3, 3)
    filter_basis_type : str, optional
        Filter basis type, by default "morlet"
    use_mlp : bool, optional
        Whether to use MLP layers, by default True
    mlp_ratio : float, optional
        MLP expansion ratio, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Drop path rate, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    sfno_block_frequency : int, optional
        Frequency of SFNO blocks, by default 2
    hard_thresholding_fraction : float, optional
        Hard thresholding fraction, by default 1.0
    residual_prediction : bool, optional
        Whether to use residual prediction, by default False
    pos_embed : str, optional
        Position embedding type, by default "none"
    upsample_sht : bool, optional
        Whether to use SHT-based upsampling in the decoder; otherwise interpolation-based resampling is used, by default False
    bias : bool, optional
        Whether to use a bias, by default False

    Example
    -------
    >>> model = LocalSphericalNeuralOperator(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    ----------
    .. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
        "Neural Operators with Localized Integral and Differential Kernels" (2024).
        ICML 2024, https://arxiv.org/pdf/2402.16845.

    .. [2] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.

    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        kernel_shape=(3, 3),
        encoder_kernel_shape=(3, 3),
        filter_basis_type="morlet",
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        sfno_block_frequency=2,
        hard_thresholding_fraction=1.0,
        residual_prediction=False,
        pos_embed="none",
        upsample_sht=False,
        bias=False,
    ):
        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.encoder_kernel_shape = encoder_kernel_shape
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.residual_prediction = residual_prediction

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor
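        # e.g. with the default img_size = (128, 256) and scale_factor = 3 this gives
        # (h, w) = ((128 - 1) // 3 + 1, 256 // 3) = (43, 85)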

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
        else:
            raise ValueError(f"Unknown position embedding type {pos_embed}")

        # encoder
        self.encoder = DiscreteContinuousEncoder(
            in_shape=self.img_size,
            out_shape=(self.h, self.w),
            grid_in=grid,
            grid_out=grid_internal,
            inp_chans=self.in_chans,
            out_chans=self.embed_dim,
            kernel_shape=self.encoder_kernel_shape,
            basis_type=filter_basis_type,
            groups=1,
            bias=False,
        )

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we subtract one mode here
        modes_lon = (self.w // 2 + 1) - 1

        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
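        # e.g. with the defaults img_size = (128, 256), scale_factor = 3 (so h = 43, w = 85):
        # modes_lat = 43 and modes_lon = (85 // 2 + 1) - 1 = 42; with
        # hard_thresholding_fraction = 1.0 both are truncated to min(43, 42) = 42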

        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            block = SphericalNeuralOperatorBlock(
                self.trans,
                self.itrans,
                self.embed_dim,
                self.embed_dim,
                conv_type="global" if i % sfno_block_frequency == (sfno_block_frequency-1) else "local",
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=self.normalization_layer,
                use_mlp=use_mlp,
                disco_kernel_shape=kernel_shape,
                disco_basis_type=filter_basis_type,
                bias=bias,
            )

            self.blocks.append(block)

        # decoder
        self.decoder = DiscreteContinuousDecoder(
            in_shape=(self.h, self.w),
            out_shape=self.img_size,
            grid_in=grid_internal,
            grid_out=grid,
            inp_chans=self.embed_dim,
            out_chans=self.out_chans,
            kernel_shape=self.encoder_kernel_shape,
            basis_type=filter_basis_type,
            groups=1,
            bias=False,
            upsample_sht=upsample_sht,
        )

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):

        if self.residual_prediction:
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
            x = self.pos_embed(x)

        x = self.forward_features(x)

        x = self.decoder(x)

        if self.residual_prediction:
            x = x + residual

        return x