# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch
import torch.nn as nn
from torch_harmonics import RealSHT, InverseRealSHT

from ._layers import *

from functools import partial


class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.Identity,
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
    ):
        super().__init__()

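        # gain used to initialize the spectral convolution below: 2.0 for a nonlinear activation,
        # 1.0 for an identity activation; halved when an inner skip connection adds a second signal path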
        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        if inner_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=False)

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        # first normalisation layer
        self.norm0 = norm_layer()

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        gain_factor = 1.0
        # halve the gain when an outer skip connection adds a second signal path
        if outer_skip == "linear" or outer_skip == "identity":
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

        # second normalisation layer
        self.norm1 = norm_layer()

    def forward(self, x):

        x, residual = self.global_conv(x)

        x = self.norm0(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.norm1(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x


class SphericalFourierNeuralOperatorNet(nn.Module):
    """
    SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
    as presented in [1]. Spherical convolutions are applied via spectral transforms, yielding a geometrically
    consistent and approximately equivariant architecture.

    Parameters
    ----------
    img_size : tuple, optional
        Shape of the input grid (nlat, nlon), by default (128, 256)
    grid : str, optional
        Grid type of the input and output data, by default "equiangular"
    grid_internal : str, optional
        Grid type used for the downsampled internal representation, by default "legendre-gauss"
    scale_factor : int, optional
        Scale factor to use, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 4
    activation_function : str, optional
        Activation function to use, by default "gelu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of the MLP hidden dimension to the embedding dimension, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    big_skip : bool, optional
        Whether to add a single large skip connection, by default False
    pos_embed : str or bool, optional
        Type of positional embedding to use ("latlon", "lat", "const") or False for none, by default False

    Example:
    --------
    >>> model = SphericalFourierNeuralOperatorNet(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    -----------
    .. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        encoder_layers=1,
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        hard_thresholding_fraction=1.0,
        use_complex_kernels=True,
        big_skip=False,
        pos_embed=False,
    ):

        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.big_skip = big_skip

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor
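        # e.g. with the defaults img_size=(128, 256) and scale_factor=3 this gives
        # self.h = (128 - 1) // 3 + 1 = 43 and self.w = 256 // 3 = 85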

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
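        # stochastic depth: drop probabilities increase linearly over the blocks,
        # e.g. drop_path_rate=0.3 with num_layers=4 gives dpr = [0.0, 0.1, 0.2, 0.3]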

        # pick norm layer
        if self.normalization_layer == "layer_norm":
            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
        elif self.normalization_layer == "instance_norm":
            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
            norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif self.normalization_layer == "none":
            norm_layer0 = nn.Identity
            norm_layer1 = norm_layer0
        else:
            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

        if pos_embed == "latlon" or pos_embed == True:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
            nn.init.constant_(self.pos_embed, 0.0)
        elif pos_embed == "lat":
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], 1))
            nn.init.constant_(self.pos_embed, 0.0)
        elif pos_embed == "const":
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
            nn.init.constant_(self.pos_embed, 0.0)
        else:
            self.pos_embed = None

        # construct an encoder with num_encoder_layers
        num_encoder_layers = self.encoder_layers
        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        current_dim = self.in_chans
        encoder_layers = []
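        # the loop below adds (num_encoder_layers - 1) hidden Conv2d + activation stages; the final
        # bias-free 1x1 convolution to embed_dim is always appended, so with encoder_layers == 1 the
        # encoder reduces to a single pointwise linear map from in_chans to embed_dim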
        for l in range(num_encoder_layers - 1):
            fc = nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True)
            # initialize the weights correctly
            scale = math.sqrt(2.0 / current_dim)
            nn.init.normal_(fc.weight, mean=0.0, std=scale)
            if fc.bias is not None:
                nn.init.constant_(fc.bias, 0.0)
            encoder_layers.append(fc)
            encoder_layers.append(self.activation_function())
            current_dim = encoder_hidden_dim
        fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)
        scale = math.sqrt(1.0 / current_dim)
        nn.init.normal_(fc.weight, mean=0.0, std=scale)
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0.0)
        encoder_layers.append(fc)
        self.encoder = nn.Sequential(*encoder_layers)

        # compute the modes for the SHT
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we subtract one mode here
        modes_lon = (self.w // 2 + 1) - 1

        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
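        # e.g. with h=43, w=85 and hard_thresholding_fraction=1.0 (defaults above):
        # modes_lat = 43, modes_lon = (85 // 2 + 1) - 1 = 42, so both are truncated to 42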

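        # SHT/ISHT pairs: trans_down / itrans_up operate on the full-resolution input grid (self.grid),
        # while trans / itrans operate on the downsampled internal grid (grid_internal); all share the
        # same mode truncation so the blocks can move between the two resolutions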
        self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            if first_layer:
                norm_layer = norm_layer1
            elif last_layer:
                norm_layer = norm_layer0
            else:
                norm_layer = norm_layer1

            block = SphericalFourierNeuralOperatorBlock(
                self.trans_down if first_layer else self.trans,
                self.itrans_up if last_layer else self.itrans,
                self.embed_dim,
                self.embed_dim,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=norm_layer,
                use_mlp=use_mlp,
            )

            self.blocks.append(block)

        # construct a decoder with num_decoder_layers
        num_decoder_layers = 1
        decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        current_dim = self.embed_dim + self.big_skip * self.in_chans
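        # when big_skip is enabled, forward() concatenates the raw in_chans input to the features,
        # which is why the decoder input is widened by self.in_chans channels above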
        decoder_layers = []
        for l in range(num_decoder_layers - 1):
            fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)
            # initialize the weights correctly
            scale = math.sqrt(2.0 / current_dim)
            nn.init.normal_(fc.weight, mean=0.0, std=scale)
            if fc.bias is not None:
                nn.init.constant_(fc.bias, 0.0)
            decoder_layers.append(fc)
            decoder_layers.append(self.activation_function())
            current_dim = decoder_hidden_dim
        fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False)
        scale = math.sqrt(1.0 / current_dim)
        nn.init.normal_(fc.weight, mean=0.0, std=scale)
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0.0)
        decoder_layers.append(fc)
        self.decoder = nn.Sequential(*decoder_layers)

    @torch.jit.ignore
    def no_weight_decay(self):
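        # parameter names returned here are conventionally excluded from weight decay by
        # timm-style training loops; note that this model defines pos_embed but no cls_token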
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):

        if self.big_skip:
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        x = self.forward_features(x)

        if self.big_skip:
            x = torch.cat((x, residual), dim=1)

        x = self.decoder(x)

        return x
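
# Minimal usage sketch (illustrative only; it assumes this file is imported as part of its package so
# that the relative `._layers` import above resolves):
#
#   model = SphericalFourierNeuralOperatorNet(img_size=(128, 256), in_chans=3, out_chans=3, embed_dim=64)
#   x = torch.randn(4, 3, 128, 256)   # batch of 4 fields on a 128 x 256 equiangular lat-lon grid
#   y = model(x)                      # output keeps the spatial shape: (4, 3, 128, 256)
#   y.mean().backward()               # the model is a regular nn.Module and is fully differentiable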