# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
from functools import partial

import torch
import torch.nn as nn

from torch_harmonics import RealSHT, InverseRealSHT

from torch_harmonics.examples.models._layers import (
    MLP,
    SpectralConvS2,
    SequencePositionEmbedding,
    SpectralPositionEmbedding,
    LearnablePositionEmbedding,
    DropPath,
)


class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    The block applies, in order: a spectral convolution, a normalization layer, an optional inner skip
    connection, an optional MLP, drop path (stochastic depth) and an optional outer skip connection.

    Parameters
    ----------
    forward_transform : nn.Module
        Forward spectral transform, passed to ``SpectralConvS2``.
    inverse_transform : nn.Module
        Inverse spectral transform, passed to ``SpectralConvS2``. Its ``nlat``/``nlon`` attributes define the
        normalized shape when ``norm_layer="layer_norm"``.
    input_dim : int
        Number of input channels.
    output_dim : int
        Number of channels produced by the spectral convolution.
    mlp_ratio : float, optional
        Ratio of the MLP hidden dimension to ``output_dim``, by default 2.0
    drop_rate : float, optional
        Dropout rate used inside the MLP, by default 0.0
    drop_path : float, optional
        Drop path (stochastic depth) rate, by default 0.0
    act_layer : type, optional
        Activation layer class used in the MLP, by default nn.GELU
    norm_layer : str, optional
        "layer_norm", "instance_norm" or "none", by default "none"
    inner_skip : str, optional
        Skip connection around the spectral convolution: "linear", "identity" or "none", by default "none"
    outer_skip : str, optional
        Skip connection around the whole block: "linear", "identity" or "none", by default "identity"
    use_mlp : bool, optional
        Whether to apply an MLP after the spectral convolution, by default True
    bias : bool, optional
        Whether the spectral convolution uses a bias, by default False

    Raises
    ------
    ValueError
        If ``inner_skip`` or ``outer_skip`` is not a known skip connection type.
    NotImplementedError
        If ``norm_layer`` is not a known normalization type.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="none",
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
        bias=False,
    ):
        super().__init__()

        # validate the string options up front so a typo fails fast, before the
        # spectral convolution is constructed
        if inner_skip not in ("linear", "identity", "none"):
            raise ValueError(f"Unknown skip connection type {inner_skip}")
        if norm_layer not in ("layer_norm", "instance_norm", "none"):
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
        if outer_skip not in ("linear", "identity", "none"):
            raise ValueError(f"Unknown skip connection type {outer_skip}")

        # initialization gain of the spectral convolution: 2.0 unless the activation is the identity,
        # halved when an inner skip branch is added on top of it
        gain_factor = 1.0 if act_layer == nn.Identity else 2.0
        if inner_skip in ("linear", "identity"):
            gain_factor /= 2.0

        self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)

        # inner skip connection; for "none" no attribute is registered and forward() skips it via hasattr
        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()

        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        else:
            self.norm = nn.Identity()

        # dropout (stochastic depth)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # initialization gain of the MLP / outer linear skip: halved whenever an outer skip branch
        # is summed with the MLP output (mirrors the inner-skip rule above)
        gain_factor = 1.0
        if outer_skip in ("linear", "identity"):
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
            )

        # outer skip connection; for "none" no attribute is registered and forward() skips it via hasattr
        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()

    def forward(self, x):
        """
        Forward pass through the SFNO block.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Output tensor after processing through the block
        """
        # the spectral convolution returns the transformed tensor and the residual used by both skips
        x, residual = self.global_conv(x)

        x = self.norm(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x


class SphericalFourierNeuralOperator(nn.Module):
    """
    SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
    as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
    and approximately equivariant architecture.

    Parameters
    ----------
    img_size : tuple, optional
        Spatial shape (nlat, nlon) of the input, by default (128, 256)
    grid : str, optional
        Input grid type, by default "equiangular"
    grid_internal : str, optional
        Internal grid type for computations, by default "legendre-gauss"
    scale_factor : int, optional
        Factor by which the input grid is downsampled for the internal representation, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of SFNO blocks in the network, by default 4
    activation_function : str, optional
        Activation function to use ("relu", "gelu" or "identity"), by default "gelu"
    encoder_layers : int, optional
        Number of 1x1 convolution layers in the encoder, by default 1
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of the MLP hidden dimension to the embedding dimension, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Maximum drop path rate; the per-block rate increases linearly from 0 to this value, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    residual_prediction : bool, optional
        Whether to add a single large skip connection, by default False
    pos_embed : str, optional
        Type of positional embedding to use ("sequence", "spectral", "learnable lat", "learnable latlon"
        or "none"), by default "none"
    bias : bool, optional
        Whether to use a bias, by default False

    Raises
    ------
    ValueError
        If ``activation_function`` or ``pos_embed`` is not recognised.

    Example:
    --------
    >>> model = SphericalFourierNeuralOperator(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    -----------
    .. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        encoder_layers=1,
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        hard_thresholding_fraction=1.0,
        residual_prediction=False,
        pos_embed="none",
        bias=False,
    ):
        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.residual_prediction = residual_prediction

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # per-block drop path rates, increasing linearly from 0 to drop_path_rate
        dpr = torch.linspace(0, drop_path_rate, self.num_layers).tolist()

        # positional embedding on the internal (downsampled) grid
        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
        else:
            raise ValueError(f"Unknown position embedding type {pos_embed}")

        # encoder: `encoder_layers` pointwise convolutions lifting in_chans to embed_dim
        self.encoder = self._make_pointwise_mlp(
            self.in_chans, int(self.embed_dim * mlp_ratio), self.embed_dim, encoder_layers, self.activation_function, bias
        )

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we subtract one mode here
        modes_lon = (self.w // 2 + 1) - 1
        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)

        # transforms between the input grid and the internal grid, and on the internal grid
        self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):
            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            # the first block downsamples from the input grid, the last one upsamples back to it
            block = SphericalFourierNeuralOperatorBlock(
                self.trans_down if first_layer else self.trans,
                self.itrans_up if last_layer else self.itrans,
                self.embed_dim,
                self.embed_dim,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=self.normalization_layer,
                use_mlp=use_mlp,
                bias=bias,
            )

            self.blocks.append(block)

        # decoder: a single pointwise convolution projecting embed_dim to out_chans
        num_decoder_layers = 1
        self.decoder = self._make_pointwise_mlp(
            self.embed_dim, int(self.embed_dim * mlp_ratio), self.out_chans, num_decoder_layers, self.activation_function, bias
        )

    @staticmethod
    def _make_pointwise_mlp(in_dim, hidden_dim, out_dim, num_layers, act_layer, bias):
        """
        Build a stack of 1x1 convolutions used as encoder or decoder.

        Parameters
        ----------
        in_dim : int
            Number of input channels.
        hidden_dim : int
            Number of channels of each hidden layer.
        out_dim : int
            Number of output channels.
        num_layers : int
            Total number of convolutions; values below 1 behave like 1.
        act_layer : type
            Activation layer class inserted after every hidden convolution.
        bias : bool
            Whether the final convolution uses a bias (hidden convolutions always do).

        Returns
        -------
        nn.Sequential
            The assembled stack.
        """
        layers = []
        current_dim = in_dim
        for _ in range(num_layers - 1):
            fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
            # initialize the weights correctly: gain 2 since an activation follows
            nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(2.0 / current_dim))
            nn.init.constant_(fc.bias, 0.0)
            layers.append(fc)
            layers.append(act_layer())
            current_dim = hidden_dim

        # output projection: no activation follows, hence unit gain
        fc = nn.Conv2d(current_dim, out_dim, 1, bias=bias)
        nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(1.0 / current_dim))
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0.0)
        layers.append(fc)
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def no_weight_decay(self):
        """
        Return a set of parameter names that should not be decayed.

        Returns
        -------
        set
            Set of parameter names to exclude from weight decay
        """
        return {"pos_embed.pos_embed"}

    def forward_features(self, x):
        """
        Forward pass through the feature extraction layers.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        torch.Tensor
            Features after processing through the network
        """
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        """
        Forward pass through the complete SFNO model.

        Parameters
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, out_chans, height, width)
        """
        if self.residual_prediction:
            residual = x

        x = self.encoder(x)

        # pos_embed is always set in __init__ (nn.Identity for "none"); the guard only protects
        # against it being cleared externally
        if self.pos_embed is not None:
            x = self.pos_embed(x)

        x = self.forward_features(x)

        x = self.decoder(x)

        if self.residual_prediction:
            x = x + residual

        return x