# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
from functools import partial

import torch
import torch.nn as nn

from torch_harmonics import *

from ._layers import *


class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    The forward pass is: spectral convolution (``SpectralConvS2``) -> ``norm0`` -> optional inner skip
    -> optional pointwise MLP -> ``norm1`` -> drop path -> optional outer skip.

    Parameters
    ----------
    forward_transform : nn.Module
        Forward spectral transform handed to ``SpectralConvS2``.
    inverse_transform : nn.Module
        Inverse spectral transform handed to ``SpectralConvS2``.
    input_dim : int
        Number of input channels.
    output_dim : int
        Number of output channels of the spectral convolution.
    operator_type : str, optional
        Operator type forwarded to ``SpectralConvS2``, by default "driscoll-healy"
    mlp_ratio : float, optional
        Ratio of the MLP hidden dimension to ``output_dim``, by default 2.0
    drop_rate : float, optional
        Dropout rate used inside the MLP, by default 0.0
    drop_path : float, optional
        Drop-path rate; a ``DropPath`` layer is only created when this is > 0, by default 0.0
    act_layer : type, optional
        Activation class used by the MLP; also selects the spectral-convolution gain, by default nn.ReLU
    norm_layer : callable, optional
        Zero-argument factory for the two normalization layers, by default nn.Identity
    factorization, separable, rank
        Accepted for interface compatibility; not used by this block.
    inner_skip : str or None, optional
        Skip around the spectral convolution: "linear", "identity" or "none" (``None`` is treated
        like "none"), by default "linear"
    outer_skip : str or None, optional
        Skip around the whole block: "linear", "identity" or "none" (``None`` is treated like "none"),
        by default None

    Raises
    ------
    ValueError
        If ``inner_skip`` or ``outer_skip`` is not one of the supported values.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        operator_type="driscoll-healy",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.ReLU,
        norm_layer=nn.Identity,
        factorization=None,
        separable=False,
        rank=128,
        inner_skip="linear",
        outer_skip=None,
        use_mlp=True,
    ):
        super().__init__()

        # the default outer_skip=None previously fell through to the ValueError below;
        # treat a missing specification as "no skip connection"
        if inner_skip is None:
            inner_skip = "none"
        if outer_skip is None:
            outer_skip = "none"

        # gain of the spectral convolution: 2.0 for non-identity activations, 1.0 otherwise,
        # halved when an inner skip is added on top of the convolution output
        gain_factor = 1.0 if act_layer == nn.Identity else 2.0
        if inner_skip in ("linear", "identity"):
            gain_factor /= 2.0

        self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim, f"identity inner skip requires input_dim == output_dim, got {input_dim} and {output_dim}"
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            # no attribute is created; forward() checks for its presence
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        # first normalisation layer
        self.norm0 = norm_layer()

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # gain of the MLP and of a linear outer skip: halved when the outer skip adds the block
        # input back onto the MLP output, mirroring the inner-skip handling above.
        # NOTE(review): this condition previously tested inner_skip == "identity", which looked
        # like a copy-paste slip; it now tests the outer skip it is meant to account for.
        gain_factor = 1.0
        if outer_skip in ("linear", "identity"):
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim, f"identity outer skip requires input_dim == output_dim, got {input_dim} and {output_dim}"
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            # no attribute is created; forward() checks for its presence
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

        # second normalisation layer
        self.norm1 = norm_layer()

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        # the spectral convolution returns its output together with the residual used by both skips
        x, residual = self.global_conv(x)

        x = self.norm0(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.norm1(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x


class SphericalFourierNeuralOperatorNet(nn.Module):
    """
    SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
    as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
    and approximately equivariant architecture.

    The network lifts the input with a pointwise encoder, optionally adds a learned positional embedding, applies
    ``num_layers`` SFNO blocks and projects to ``out_chans`` with a pointwise decoder. The first block transforms from
    the input grid and the last block transforms back to it. All intermediate work happens on a grid downsampled by
    ``scale_factor``.

    Parameters
    ----------
    img_size : tuple, optional
        Spatial shape (nlat, nlon) of the input, by default (128, 256)
    operator_type : str, optional
        Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
    grid : str, optional
        Grid of the input and output data, by default "equiangular"
    grid_internal : str, optional
        Grid used for the internal, downsampled representation, by default "legendre-gauss"
    scale_factor : int, optional
        Downsampling factor between the input grid and the internal grid, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of SFNO blocks in the network, by default 4
    activation_function : str, optional
        Activation function to use ("relu", "gelu", "identity"), by default "relu"
    encoder_layers : int, optional
        Stored as an attribute but currently ignored: encoder and decoder are always a single 1x1 convolution,
        by default 1
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of the MLP hidden dimension to the embedding dimension, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Drop-path rate of the last block; the rate grows linearly from 0 over the blocks, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    use_complex_kernels : bool, optional
        Currently unused, by default True
    big_skip : bool, optional
        Whether to concatenate the network input to the decoder input (single large skip connection),
        by default False
    pos_embed : bool or str, optional
        Learned positional embedding: "latlon" (or True), "lat", "const", or False for none, by default False

    Raises
    ------
    ValueError
        If ``activation_function`` is unknown.
    NotImplementedError
        If ``normalization_layer`` is unknown.

    Example
    -------
    >>> model = SphericalFourierNeuralOperatorNet(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    ----------
    .. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        operator_type="driscoll-healy",
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="relu",
        encoder_layers=1,
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        hard_thresholding_fraction=1.0,
        use_complex_kernels=True,
        big_skip=False,
        pos_embed=False,
    ):
        super().__init__()

        self.operator_type = operator_type
        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.big_skip = big_skip

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        # per-block drop-path rates, increasing linearly from 0 to drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # pick norm layer: norm_layer0 is used on full-resolution (img_size) outputs,
        # norm_layer1 on outputs on the downsampled (h, w) grid
        if self.normalization_layer == "layer_norm":
            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
        elif self.normalization_layer == "instance_norm":
            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
            norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif self.normalization_layer == "none":
            norm_layer0 = nn.Identity
            norm_layer1 = norm_layer0
        else:
            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

        # learned positional embedding (zero-initialised), broadcast over the spatial dimensions it omits
        if pos_embed in ("latlon", True):
            pos_embed_shape = (1, self.embed_dim, self.img_size[0], self.img_size[1])
        elif pos_embed == "lat":
            pos_embed_shape = (1, self.embed_dim, self.img_size[0], 1)
        elif pos_embed == "const":
            pos_embed_shape = (1, self.embed_dim, 1, 1)
        else:
            pos_embed_shape = None
        self.pos_embed = nn.Parameter(torch.zeros(*pos_embed_shape)) if pos_embed_shape is not None else None

        # construct an encoder with num_encoder_layers
        # NOTE(review): hard-coded to 1, the encoder_layers argument is not honoured
        num_encoder_layers = 1
        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.encoder = self._make_pointwise_mlp(self.in_chans, encoder_hidden_dim, self.embed_dim, num_encoder_layers)

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we subtract one mode here
        modes_lon = self.w // 2
        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)

        # transforms between the input grid and the spectral domain (used by the first/last block)
        self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
        # transforms on the internal downsampled grid
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):
            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans

            inner_skip = "none"
            outer_skip = "identity"

            # the block output lives on the grid of its inverse transform, which is the full-resolution
            # grid exactly for the last block. Checking last_layer first also covers num_layers == 1,
            # where the single block is both first and last (previously it got the (h, w) norm).
            norm_layer = norm_layer0 if last_layer else norm_layer1

            block = SphericalFourierNeuralOperatorBlock(
                forward_transform,
                inverse_transform,
                self.embed_dim,
                self.embed_dim,
                operator_type=self.operator_type,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=norm_layer,
                inner_skip=inner_skip,
                outer_skip=outer_skip,
                use_mlp=use_mlp,
            )

            self.blocks.append(block)

        # construct an decoder with num_decoder_layers
        # NOTE(review): hard-coded to 1, mirroring the encoder
        num_decoder_layers = 1
        decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        # with big_skip the network input is concatenated to the decoder input
        decoder_in_dim = self.embed_dim + self.big_skip * self.in_chans
        self.decoder = self._make_pointwise_mlp(decoder_in_dim, decoder_hidden_dim, self.out_chans, num_decoder_layers)

    def _make_pointwise_mlp(self, in_dim, hidden_dim, out_dim, num_layers):
        """
        Build a stack of ``num_layers`` 1x1 convolutions.

        The first ``num_layers - 1`` convolutions map to ``hidden_dim``, carry a zero-initialised bias, and are
        each followed by ``self.activation_function``. The final convolution maps to ``out_dim`` without bias.
        """
        layers = []
        current_dim = in_dim
        for _ in range(num_layers - 1):
            fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
            # gain 2 for layers followed by an activation
            nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(2.0 / current_dim))
            nn.init.constant_(fc.bias, 0.0)
            layers.append(fc)
            layers.append(self.activation_function())
            current_dim = hidden_dim
        fc = nn.Conv2d(current_dim, out_dim, 1, bias=False)
        # gain 1 for the final, linear projection
        nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(1.0 / current_dim))
        layers.append(fc)
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        """Apply positional dropout followed by all SFNO blocks."""
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        """Encode, process with the SFNO blocks and decode ``x``."""
        if self.big_skip:
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        x = self.forward_features(x)

        if self.big_skip:
            x = torch.cat((x, residual), dim=1)

        x = self.decoder(x)

        return x