sfno.py 15.6 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
Boris Bonev's avatar
Boris Bonev committed
31
import math
Boris Bonev's avatar
Boris Bonev committed
32
33
34

import torch
import torch.nn as nn
Boris Bonev's avatar
Boris Bonev committed
35

36
from torch_harmonics import RealSHT, InverseRealSHT
Boris Bonev's avatar
Boris Bonev committed
37

Boris Bonev's avatar
Boris Bonev committed
38
from torch_harmonics.examples.models._layers import MLP, DropPath, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
Boris Bonev's avatar
Boris Bonev committed
39

Srikumar Sastry's avatar
Srikumar Sastry committed
40
41
from functools import partial

42

Boris Bonev's avatar
Boris Bonev committed
43
44
45
class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    Parameters
    ----------
    forward_transform : torch.nn.Module
        Forward transform to use for the block
    inverse_transform : torch.nn.Module
        Inverse transform to use for the block
    input_dim : int
        Input dimension
    output_dim : int
        Output dimension
    mlp_ratio : float, optional
        MLP expansion ratio, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    act_layer : torch.nn.Module, optional
        Activation function to use, by default nn.GELU
    norm_layer : str, optional
        Type of normalization to use ("layer_norm", "instance_norm", "none"), by default "none"
    inner_skip : str, optional
        Type of inner skip connection ("linear", "identity", "none"), by default "none"
    outer_skip : str, optional
        Type of outer skip connection ("linear", "identity", "none"), by default "identity"
    use_mlp : bool, optional
        Whether to use MLP layers, by default True
    bias : bool, optional
        Whether to use bias, by default False

    Returns
    -------
    torch.Tensor
        Output tensor
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="none",
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
        bias=False,
    ):
        super().__init__()

        # init gain for the spectral convolution: 2.0 compensates the variance
        # reduction of the nonlinearity (He-style init); identity activation keeps 1.0
        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        # when an inner skip branch is summed in, halve the gain so the sum of
        # the two branches approximately preserves the activation variance
        if inner_skip in ("linear", "identity"):
            gain_factor /= 2.0

        self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif norm_layer == "none":
            self.norm = nn.Identity()
        else:
            # BUGFIX: previously referenced the non-existent attribute self.norm_layer,
            # which raised AttributeError instead of the intended NotImplementedError
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")

        # stochastic depth (no-op when drop_path == 0)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # gain for the MLP branch, mirroring the inner-skip logic above
        gain_factor = 1.0
        # BUGFIX: this condition previously tested inner_skip instead of outer_skip,
        # so the gain was not halved for the default outer_skip="identity"
        if outer_skip in ("linear", "identity"):
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

    def forward(self, x):
        # global spherical convolution; the second return value is the residual
        # stream consumed by the skip connections below
        x, residual = self.global_conv(x)

        x = self.norm(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x

174

Boris Bonev's avatar
Boris Bonev committed
175
class SphericalFourierNeuralOperator(nn.Module):
    """
    SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
    as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
    and approximately equivariant architecture.

    Parameters
    ----------
    img_size : tuple, optional
        Shape of the input channels, by default (128, 256)
    grid : str, optional
        Input grid type, by default "equiangular"
    grid_internal : str, optional
        Internal grid type for computations, by default "legendre-gauss"
    scale_factor : int, optional
        Scale factor to use, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 4
    activation_function : str, optional
        Activation function to use, by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    use_mlp : int, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : int, optional
        Ratio of MLP to use, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    residual_prediction : bool, optional
        Whether to add a single large skip connection, by default False
    pos_embed : str, optional
        Type of positional embedding to use, by default "none"
    bias : bool, optional
        Whether to use a bias, by default False

    Example:
    ----------
    >>> model = SphericalFourierNeuralOperator(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    ----------
    .. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        encoder_layers=1,
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        hard_thresholding_fraction=1.0,
        residual_prediction=False,
        pos_embed="none",
        bias=False,
    ):

        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.residual_prediction = residual_prediction

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # linearly increasing stochastic-depth rate across the blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # position embedding (nn.Identity when disabled)
        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
        else:
            raise ValueError(f"Unknown position embedding type {pos_embed}")

        # encoder: pointwise (1x1 conv) MLP lifting in_chans -> embed_dim
        # BUGFIX: the encoder_layers argument was previously ignored (depth hard-coded to 1)
        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.encoder = self._make_pointwise_mlp(self.in_chans, encoder_hidden_dim, self.embed_dim, self.encoder_layers, bias)

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we substract one mode here
        modes_lon = (self.w // 2 + 1) - 1
        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)

        # spectral transforms: down/up connect the external grid and resolution
        # to the internal one; trans/itrans operate at the internal resolution
        self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            block = SphericalFourierNeuralOperatorBlock(
                self.trans_down if first_layer else self.trans,
                self.itrans_up if last_layer else self.itrans,
                self.embed_dim,
                self.embed_dim,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=self.normalization_layer,
                use_mlp=use_mlp,
                bias=bias,
            )

            self.blocks.append(block)

        # decoder: pointwise (1x1 conv) MLP projecting embed_dim -> out_chans
        decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.decoder = self._make_pointwise_mlp(self.embed_dim, decoder_hidden_dim, self.out_chans, 1, bias)

    def _make_pointwise_mlp(self, in_dim, hidden_dim, out_dim, num_layers, bias):
        """Build a pointwise MLP of `num_layers` 1x1 convolutions with scaled normal init."""
        layers = []
        current_dim = in_dim
        for _ in range(num_layers - 1):
            fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
            # He-style init for the hidden (activated) layers
            nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(2.0 / current_dim))
            if fc.bias is not None:
                nn.init.constant_(fc.bias, 0.0)
            layers.append(fc)
            layers.append(self.activation_function())
            current_dim = hidden_dim
        fc = nn.Conv2d(current_dim, out_dim, 1, bias=bias)
        # variance-preserving init for the final (linear) layer
        nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(1.0 / current_dim))
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0.0)
        layers.append(fc)
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by downstream optimizers
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        """Apply dropout and the stack of SFNO blocks."""
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):

        if self.residual_prediction:
            # single large input-to-output skip; requires in_chans == out_chans
            residual = x

        x = self.encoder(x)

        # pos_embed is always a module (nn.Identity when "none"), so the
        # previous `is not None` guard was dead and has been removed
        x = self.pos_embed(x)

        x = self.forward_features(x)

        x = self.decoder(x)

        if self.residual_prediction:
            x = x + residual

        return x