sfno.py 14.4 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
Boris Bonev's avatar
Boris Bonev committed
31
import math
Boris Bonev's avatar
Boris Bonev committed
32
33
34

import torch
import torch.nn as nn
Boris Bonev's avatar
Boris Bonev committed
35

36
from torch_harmonics import RealSHT, InverseRealSHT
Boris Bonev's avatar
Boris Bonev committed
37

Boris Bonev's avatar
Boris Bonev committed
38
from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
Boris Bonev's avatar
Boris Bonev committed
39

Srikumar Sastry's avatar
Srikumar Sastry committed
40
41
from functools import partial

42

Boris Bonev's avatar
Boris Bonev committed
43
44
45
46
class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    The block applies, in this order: a spectral convolution, a normalization layer, an optional
    inner skip connection, an optional MLP, drop-path and an optional outer skip connection.

    Parameters
    ----------
    forward_transform : object
        Forward transform handed to ``SpectralConvS2``
    inverse_transform : object
        Inverse transform handed to ``SpectralConvS2``. Its ``nlat`` and ``nlon`` attributes are
        used as the normalized shape when ``norm_layer="layer_norm"``
    input_dim : int
        Number of input channels
    output_dim : int
        Number of output channels of the spectral convolution
    mlp_ratio : float, optional
        The MLP hidden width is ``int(output_dim * mlp_ratio)``, by default 2.0
    drop_rate : float, optional
        Dropout rate forwarded to the MLP, by default 0.0
    drop_path : float, optional
        Drop-path rate; drop-path is only instantiated when this is > 0, by default 0.0
    act_layer : callable, optional
        Activation layer class forwarded to the MLP, by default nn.GELU
    norm_layer : str, optional
        One of "layer_norm", "instance_norm", "none", by default "none"
    inner_skip : str, optional
        Skip connection added after the normalization ("linear", "identity", "none"), by default "none"
    outer_skip : str, optional
        Skip connection added at the end of the block ("linear", "identity", "none"), by default "identity"
    use_mlp : bool, optional
        Whether to append an MLP after the spectral convolution, by default True
    bias : bool, optional
        Forwarded to ``SpectralConvS2``, by default False

    Raises
    ------
    ValueError
        If ``inner_skip`` or ``outer_skip`` is not a known skip connection type
    NotImplementedError
        If ``norm_layer`` is not a known normalization type
    AssertionError
        If an "identity" skip connection is requested while ``input_dim != output_dim``
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="none",
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
        bias=False,
    ):
        super().__init__()

        # Validate the configuration before any layer is built, so that a bad argument
        # fails immediately. The order of the checks mirrors the order in which the
        # corresponding layers are constructed below.
        skip_types = ("linear", "identity", "none")
        if inner_skip not in skip_types:
            raise ValueError(f"Unknown skip connection type {inner_skip}")
        if inner_skip == "identity":
            assert input_dim == output_dim
        if norm_layer not in ("layer_norm", "instance_norm", "none"):
            # report the offending argument; the module has no `norm_layer` attribute
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")
        if outer_skip not in skip_types:
            raise ValueError(f"Unknown skip connection type {outer_skip}")
        if outer_skip == "identity":
            assert input_dim == output_dim

        # gain used to initialize the spectral convolution and the linear inner skip;
        # it is halved whenever an inner skip branch is added to the convolution output
        gain_factor = 1.0 if act_layer is nn.Identity else 2.0
        if inner_skip in ("linear", "identity"):
            gain_factor /= 2.0

        self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)

        # inner skip connection (no attribute is created for "none"; forward() relies on that)
        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            self.inner_skip = nn.Identity()

        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        else:
            self.norm = nn.Identity()

        # dropout
        if drop_path > 0.0:
            # DropPath is not part of this file's module-level imports; fetch it here from
            # the layers module the other building blocks come from.
            from torch_harmonics.examples.models._layers import DropPath

            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # gain used to initialize the MLP and the linear outer skip
        # NOTE(review): the second operand tests `inner_skip`, not `outer_skip`. This looks like
        # a typo for `outer_skip == "identity"`, but changing it would alter the default
        # initialization, so the original condition is kept - confirm the intent.
        gain_factor = 1.0
        if outer_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
            )

        # outer skip connection (no attribute is created for "none"; forward() relies on that)
        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            self.outer_skip = nn.Identity()

    def forward(self, x):
        """
        Apply the block to ``x``.

        The spectral convolution returns a pair ``(x, residual)``; ``residual`` feeds both the
        inner and the outer skip connection when they exist.
        """

        x, residual = self.global_conv(x)

        x = self.norm(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x

141

Boris Bonev's avatar
Boris Bonev committed
142
class SphericalFourierNeuralOperator(nn.Module):
    """
    SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
    as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
    and approximately equivariant architecture.

    Parameters
    ----------
    img_size : tuple, optional
        Shape (latitude, longitude) of the input grid, by default (128, 256)
    grid : str, optional
        Grid type of the transforms operating at the input resolution, by default "equiangular"
    grid_internal : str, optional
        Grid type of the transforms and position embeddings operating at the internal
        (downsampled) resolution, by default "legendre-gauss"
    scale_factor : int, optional
        Scale factor to use, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 4
    activation_function : str, optional
        Activation function to use ("relu", "gelu", "identity"), by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1. The value is stored on the module but
        currently not used: the encoder is always built with a single layer
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of MLP to use, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate of the last block; the per-block rates are linearly spaced
        from 0 to this value, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    residual_prediction : bool, optional
        Whether to add a single large skip connection, by default False
    pos_embed : str, optional
        Type of positional embedding to use ("sequence", "spectral", "learnable lat",
        "learnable latlon", "none"), by default "none"
    bias : bool, optional
        Whether to use a bias, by default False

    Raises
    ------
    ValueError
        If ``activation_function`` or ``pos_embed`` is not one of the supported values

    Example:
    --------
    >>> model = SphericalFourierNeuralOperator(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    -----------
    .. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        encoder_layers=1,
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        hard_thresholding_fraction=1.0,
        residual_prediction=False,
        pos_embed="none",
        bias=False,
    ):

        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.residual_prediction = residual_prediction

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # per-block drop-path rates, linearly increasing from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # position embedding applied to the encoded features
        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
        else:
            raise ValueError(f"Unknown position embedding type {pos_embed}")

        # construct an encoder with num_encoder_layers
        # NOTE: the depth is fixed to 1 here; the `encoder_layers` argument is only stored above.
        num_encoder_layers = 1
        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.encoder = self._make_pointwise_stack(self.in_chans, encoder_hidden_dim, self.embed_dim, num_encoder_layers, self.activation_function, bias)

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we subtract one mode here
        modes_lon = (self.w // 2 + 1) - 1

        # use the same cutoff in both directions, reduced by the thresholding fraction
        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)

        # transforms at the input resolution (trans_down / itrans_up) and at the internal resolution
        self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            # the first block reads from the input grid, the last block writes back to it
            block = SphericalFourierNeuralOperatorBlock(
                self.trans_down if first_layer else self.trans,
                self.itrans_up if last_layer else self.itrans,
                self.embed_dim,
                self.embed_dim,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=self.normalization_layer,
                use_mlp=use_mlp,
                bias=bias,
            )

            self.blocks.append(block)

        # construct a decoder with num_decoder_layers
        num_decoder_layers = 1
        decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.decoder = self._make_pointwise_stack(self.embed_dim, decoder_hidden_dim, self.out_chans, num_decoder_layers, self.activation_function, bias)

    @staticmethod
    def _make_pointwise_stack(in_dim, hidden_dim, out_dim, num_layers, act_layer, bias):
        """Build a stack of 1x1 convolutions: ``num_layers - 1`` hidden layers followed by an output layer."""
        layers = []
        current_dim = in_dim
        for _ in range(num_layers - 1):
            fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
            # initialize the weights correctly
            nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(2.0 / current_dim))
            if fc.bias is not None:
                nn.init.constant_(fc.bias, 0.0)
            layers.append(fc)
            layers.append(act_layer())
            current_dim = hidden_dim

        # output layer: no activation afterwards, hence the smaller init gain
        fc = nn.Conv2d(current_dim, out_dim, 1, bias=bias)
        nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(1.0 / current_dim))
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0.0)
        layers.append(fc)

        return nn.Sequential(*layers)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the parameters that should be excluded from weight decay."""
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        """Apply dropout and then all SFNO blocks in sequence."""

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        """
        Run the model: encoder, position embedding, SFNO blocks, decoder.

        When ``residual_prediction`` is set, the input is added to the decoder output.
        """

        if self.residual_prediction:
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
            x = self.pos_embed(x)

        x = self.forward_features(x)

        x = self.decoder(x)

        if self.residual_prediction:
            x = x + residual

        return x