sfno.py 16.8 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
Boris Bonev's avatar
Boris Bonev committed
31
import math
from functools import partial

import torch
import torch.nn as nn

from torch_harmonics import RealSHT, InverseRealSHT

from torch_harmonics.examples.models._layers import MLP, DropPath, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding

42

Boris Bonev's avatar
Boris Bonev committed
43
44
45
class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    Parameters
    ----------
    forward_transform : torch.nn.Module
        Forward transform to use for the block
    inverse_transform : torch.nn.Module
        Inverse transform to use for the block
    input_dim : int
        Input dimension
    output_dim : int
        Output dimension
    mlp_ratio : float, optional
        MLP expansion ratio, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    act_layer : torch.nn.Module, optional
        Activation function to use, by default nn.GELU
    norm_layer : str, optional
        Type of normalization to use ("layer_norm", "instance_norm", "none"), by default "none"
    inner_skip : str, optional
        Type of inner skip connection ("linear", "identity", "none"), by default "none"
    outer_skip : str, optional
        Type of outer skip connection ("linear", "identity", "none"), by default "identity"
    use_mlp : bool, optional
        Whether to use MLP layers, by default True
    bias : bool, optional
        Whether to use bias, by default False

    Returns
    -------
    torch.Tensor
        Output tensor
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="none",
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
        bias=False,
    ):
        super().__init__()

        # gain of 2 compensates the variance reduction of a non-linear activation
        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        # when an inner skip branch is present, split the gain between the two branches
        if inner_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif norm_layer == "none":
            self.norm = nn.Identity()
        else:
            # bugfix: the original message referenced self.norm_layer, which is never set on
            # this module and would raise AttributeError instead of the intended error
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")

        # stochastic depth (drop path); identity when disabled
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # bugfix: the gain for the MLP/outer-skip branch must depend on outer_skip;
        # the original tested inner_skip == "identity" here (copy-paste error)
        gain_factor = 1.0
        if outer_skip == "linear" or outer_skip == "identity":
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

    def forward(self, x):
        """
        Forward pass through the SFNO block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after processing through the block
        """
        # global spectral convolution; also returns the residual used by the skip branches
        x, residual = self.global_conv(x)

        x = self.norm(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x

186

Boris Bonev's avatar
Boris Bonev committed
187
class SphericalFourierNeuralOperator(nn.Module):
    """
    SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
    as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
    and approximately equivariant architecture.

    Parameters
    ----------
    img_size : tuple, optional
        Shape of the input channels, by default (128, 256)
    grid : str, optional
        Input grid type, by default "equiangular"
    grid_internal : str, optional
        Internal grid type for computations, by default "legendre-gauss"
    scale_factor : int, optional
        Scale factor to use, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 4
    activation_function : str, optional
        Activation function to use, by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of MLP to use, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    residual_prediction : bool, optional
        Whether to add a single large skip connection, by default False.
        NOTE(review): this requires out_chans == in_chans — confirm at call sites.
    pos_embed : str, optional
        Type of positional embedding to use, by default "none"
    bias : bool, optional
        Whether to use a bias, by default False

    Example:
    ----------
    >>> model = SphericalFourierNeuralOperator(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    ----------
    .. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        encoder_layers=1,
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        hard_thresholding_fraction=1.0,
        residual_prediction=False,
        pos_embed="none",
        bias=False,
    ):

        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.residual_prediction = residual_prediction

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # per-block stochastic-depth rates, linearly increasing with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # position embedding
        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
        else:
            raise ValueError(f"Unknown position embedding type {pos_embed}")

        # construct the encoder
        # bugfix: the number of encoder layers was hard-coded to 1, silently ignoring
        # the encoder_layers argument (backward-compatible: the default is 1)
        self.encoder = self._make_pointwise_mlp(self.encoder_layers, self.in_chans, int(self.embed_dim * mlp_ratio), self.embed_dim, bias)

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we substract one mode here
        modes_lon = (self.w // 2 + 1) - 1
        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)

        # transforms between the input/output grid and the internal (downsampled) grid
        self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            # the first block downsamples from the input grid, the last upsamples back to it
            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            block = SphericalFourierNeuralOperatorBlock(
                self.trans_down if first_layer else self.trans,
                self.itrans_up if last_layer else self.itrans,
                self.embed_dim,
                self.embed_dim,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=self.normalization_layer,
                use_mlp=use_mlp,
                bias=bias,
            )

            self.blocks.append(block)

        # construct the decoder (mirrors the encoder; uses the same layer count)
        self.decoder = self._make_pointwise_mlp(self.encoder_layers, self.embed_dim, int(self.embed_dim * mlp_ratio), self.out_chans, bias)

    def _make_pointwise_mlp(self, num_layers, input_dim, hidden_dim, output_dim, bias):
        """
        Build a pointwise (1x1 convolution) MLP used for the encoder and decoder.

        Hidden layers use He-style initialization (gain 2 for the activation); the final
        projection uses gain 1 and the requested bias setting. Biases are zero-initialized.

        Parameters
        ----------
        num_layers : int
            Total number of conv layers (num_layers - 1 hidden layers + final projection)
        input_dim : int
            Number of input channels
        hidden_dim : int
            Number of channels in the hidden layers
        output_dim : int
            Number of output channels
        bias : bool
            Whether the final projection uses a bias

        Returns
        -------
        torch.nn.Sequential
            The assembled pointwise MLP
        """
        layers = []
        current_dim = input_dim
        for _ in range(num_layers - 1):
            fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
            # initialize the weights correctly
            scale = math.sqrt(2.0 / current_dim)
            nn.init.normal_(fc.weight, mean=0.0, std=scale)
            if fc.bias is not None:
                nn.init.constant_(fc.bias, 0.0)
            layers.append(fc)
            layers.append(self.activation_function())
            current_dim = hidden_dim
        fc = nn.Conv2d(current_dim, output_dim, 1, bias=bias)
        scale = math.sqrt(1.0 / current_dim)
        nn.init.normal_(fc.weight, mean=0.0, std=scale)
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0.0)
        layers.append(fc)
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def no_weight_decay(self):
        """
        Return a set of parameter names that should not be decayed.

        Returns
        -------
        set
            Set of parameter names to exclude from weight decay
        """
        return {"pos_embed.pos_embed"}

    def forward_features(self, x):
        """
        Forward pass through the feature extraction layers.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Features after processing through the network
        """
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        """
        Forward pass through the complete SFNO model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, out_chans, height, width)
        """
        if self.residual_prediction:
            residual = x

        x = self.encoder(x)

        # pos_embed is always a module (nn.Identity when disabled), so call it directly;
        # the original guarded with "is not None", which could never be False
        x = self.pos_embed(x)

        x = self.forward_features(x)

        x = self.decoder(x)

        if self.residual_prediction:
            x = x + residual

        return x