# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
from functools import partial

import torch
import torch.nn as nn

from torch_harmonics import *

from .layers import *

41

Boris Bonev's avatar
Boris Bonev committed
42
43
44
45
46
47
48
49
50
class SpectralFilterLayer(nn.Module):
    """
    Fourier layer. Contains the convolution part of the FNO/SFNO.

    This is a thin wrapper that instantiates either a ``SpectralConvS2`` (when
    ``factorization`` is ``None``) or a ``FactorizedSpectralConvS2`` (for any
    other value) and delegates ``forward`` to it.

    Parameters
    ----------
    forward_transform, inverse_transform
        Transform modules handed unchanged to the underlying convolution.
    input_dim, output_dim
        Channel counts handed unchanged to the underlying convolution.
    gain : float, optional
        Forwarded as ``gain`` to the underlying convolution, by default 2.0
    operator_type : str, optional
        Forwarded as ``operator_type``, by default "diagonal"
    hidden_size_factor : optional
        Accepted for interface compatibility; not used by this class.
    factorization : Any, optional
        Selects the factorized convolution when not ``None``, by default None
    separable : bool, optional
        Only forwarded to the factorized convolution, by default False
    rank : optional
        Only forwarded to the factorized convolution, by default 1e-2
    bias : bool, optional
        Forwarded as ``bias`` to the underlying convolution, by default True
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        gain=2.0,
        operator_type="diagonal",
        hidden_size_factor=2,
        factorization=None,
        separable=False,
        rank=1e-2,
        bias=True,
    ):
        super().__init__()

        # `factorization is None` / `is not None` is exhaustive, so a plain
        # if/else is enough; the former trailing `else` could never run.
        if factorization is None:
            self.filter = SpectralConvS2(
                forward_transform,
                inverse_transform,
                input_dim,
                output_dim,
                gain=gain,
                operator_type=operator_type,
                bias=bias,
            )
        else:
            self.filter = FactorizedSpectralConvS2(
                forward_transform,
                inverse_transform,
                input_dim,
                output_dim,
                gain=gain,
                operator_type=operator_type,
                rank=rank,
                factorization=factorization,
                separable=separable,
                bias=bias,
            )

    def forward(self, x):
        """Apply the wrapped spectral convolution and return its result as-is."""
        return self.filter(x)

86

Boris Bonev's avatar
Boris Bonev committed
87
88
89
90
class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    ``forward`` applies, in order: the spectral filter, the first norm, the
    optional inner skip, the activation, the optional MLP, the second norm,
    drop-path and finally the optional outer skip.

    Parameters
    ----------
    forward_transform, inverse_transform
        Transform modules handed to ``SpectralFilterLayer``.
    input_dim, output_dim : int
        Channel counts of the filter input and output. The MLP maps
        ``output_dim`` back to ``input_dim``.
    operator_type : str, optional
        Forwarded to ``SpectralFilterLayer``, by default "driscoll-healy"
    mlp_ratio : float, optional
        Hidden width of the MLP is ``int(output_dim * mlp_ratio)``, by default 2.0
    drop_rate : float, optional
        Forwarded to the MLP as ``drop_rate``, by default 0.0
    drop_path : float, optional
        ``DropPath`` is used when this is > 0, otherwise identity, by default 0.0
    act_layer : callable, optional
        Zero-argument factory for the activation module, by default ``nn.ReLU``
    norm_layer : callable, optional
        Zero-argument factory called twice to create both norm layers,
        by default ``nn.Identity``
    factorization, separable, rank : optional
        Forwarded to ``SpectralFilterLayer``.
    inner_skip : {"linear", "identity", "none", None}, optional
        Skip connection added after the first norm, by default "linear"
    outer_skip : {"linear", "identity", "none", None}, optional
        Skip connection added at the end of the block, by default None
        (``None`` is treated like "none").
    use_mlp : bool, optional
        Whether to create the MLP, by default True

    Raises
    ------
    ValueError
        If ``inner_skip`` or ``outer_skip`` is not one of the accepted values.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        operator_type="driscoll-healy",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.ReLU,
        norm_layer=nn.Identity,
        factorization=None,
        separable=False,
        rank=128,
        inner_skip="linear",
        outer_skip=None,
        use_mlp=True,
    ):
        super().__init__()

        # The signature defaults `outer_skip` to None, which previously fell
        # through to the "unknown skip connection" error. Treat None as "none".
        if inner_skip is None:
            inner_skip = "none"
        if outer_skip is None:
            outer_skip = "none"

        # gain used for the filter and the inner skip initialisation
        gain_factor = 1.0 if act_layer is nn.Identity else 2.0
        if inner_skip in ("linear", "identity"):
            gain_factor /= 2.0

        # convolution layer
        self.filter = SpectralFilterLayer(
            forward_transform,
            inverse_transform,
            input_dim,
            output_dim,
            gain=gain_factor,
            operator_type=operator_type,
            hidden_size_factor=mlp_ratio,
            factorization=factorization,
            separable=separable,
            rank=rank,
            bias=True,
        )

        # inner skip: the attribute is only created when a skip is requested,
        # forward() checks for its presence
        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim, "identity inner skip requires input_dim == output_dim"
            self.inner_skip = nn.Identity()
        elif inner_skip != "none":
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        self.act_layer = act_layer()

        # first normalisation layer
        self.norm0 = norm_layer()

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # gain used for the MLP and the outer skip initialisation
        # NOTE(review): the second operand tests `inner_skip`, not `outer_skip`.
        # This looks like a copy-paste slip, but it is kept unchanged because
        # altering it changes the initialisation of existing configurations.
        gain_factor = 1.0
        if outer_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        if use_mlp:
            self.mlp = MLP(
                in_features=output_dim,
                out_features=input_dim,
                hidden_features=int(output_dim * mlp_ratio),
                act_layer=act_layer,
                drop_rate=drop_rate,
                checkpointing=False,
                gain=gain_factor,
            )

        # outer skip: same presence-based convention as the inner skip
        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim, "identity outer skip requires input_dim == output_dim"
            self.outer_skip = nn.Identity()
        elif outer_skip != "none":
            raise ValueError(f"Unknown skip connection type {outer_skip}")

        # second normalisation layer
        self.norm1 = norm_layer()

    def forward(self, x):
        """Run the block on ``x`` and return the transformed tensor."""

        # the filter returns the filtered signal and the tensor used by the skips
        x, residual = self.filter(x)

        x = self.norm0(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "act_layer"):
            x = self.act_layer(x)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.norm1(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x

203

Boris Bonev's avatar
Boris Bonev committed
204
205
206
207
208
209
210
211
212
213
class SphericalFourierNeuralOperatorNet(nn.Module):
    """
    SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
    both linear and non-linear variants.

    Parameters
    ----------
    spectral_transform : str, optional
        Type of spectral transformation to use ("sht" or "fft"), by default "sht"
    operator_type : str, optional
        Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
    img_size : tuple, optional
        Spatial shape (lat, lon) of the input, by default (128, 256)
    grid : str, optional
        Grid passed to the full-resolution SHT / inverse SHT (only used when
        ``spectral_transform == "sht"``), by default "equiangular"
    scale_factor : int, optional
        The internal resolution is ``img_size // scale_factor``, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of SFNO blocks in the network, by default 4
    activation_function : str, optional
        Activation function to use ("relu", "gelu", "identity"), by default "relu"
    encoder_layers : int, optional
        Stored on the module. NOTE: the encoder and decoder are currently
        always built with a single layer regardless of this value, by default 1
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of MLP hidden width to embedding width, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Maximum drop path rate; rates are spaced linearly from 0 across the
        blocks, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    use_complex_kernels : bool, optional
        Accepted for interface compatibility; not used by this class, by default True
    big_skip : bool, optional
        Whether to add a single large skip connection that concatenates the
        network input to the decoder input, by default False
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    rank : (int, Tuple[int]), optional
        If a factorization is used, which rank to use. Argument is passed to tensorly, by default 128
    pos_embed : bool or str, optional
        Positional embedding to use: ``True`` or "latlon" (one value per
        channel and grid point), "lat" (per channel and latitude), "const"
        (per channel); anything else disables it, by default False

    Raises
    ------
    ValueError
        If ``activation_function`` or ``spectral_transform`` is unknown.
    NotImplementedError
        If ``normalization_layer`` is unknown.

    Example:
    --------
    >>> model = SphericalFourierNeuralOperatorNet(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
    """

    def __init__(
        self,
        spectral_transform="sht",
        operator_type="driscoll-healy",
        img_size=(128, 256),
        grid="equiangular",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="relu",
        encoder_layers=1,
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        hard_thresholding_fraction=1.0,
        use_complex_kernels=True,
        big_skip=False,
        factorization=None,
        separable=False,
        rank=128,
        pos_embed=False,
    ):

        super().__init__()

        self.spectral_transform = spectral_transform
        self.operator_type = operator_type
        self.img_size = img_size
        self.grid = grid
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.big_skip = big_skip
        self.factorization = factorization
        # store the flag itself; a stray trailing comma used to wrap it in a
        # one-element tuple, which is always truthy
        self.separable = separable
        self.rank = rank

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size
        self.h = self.img_size[0] // scale_factor
        self.w = self.img_size[1] // scale_factor

        # dropout
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        dpr = torch.linspace(0, drop_path_rate, self.num_layers).tolist()

        # pick norm layer: norm_layer0 acts at full resolution, norm_layer1 at
        # the downsampled resolution
        if self.normalization_layer == "layer_norm":
            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
        elif self.normalization_layer == "instance_norm":
            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
            norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif self.normalization_layer == "none":
            norm_layer0 = nn.Identity
            norm_layer1 = norm_layer0
        else:
            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

        # positional embedding, initialised to zero
        if pos_embed in ("latlon", True):
            pos_embed_shape = (1, self.embed_dim, self.img_size[0], self.img_size[1])
        elif pos_embed == "lat":
            pos_embed_shape = (1, self.embed_dim, self.img_size[0], 1)
        elif pos_embed == "const":
            pos_embed_shape = (1, self.embed_dim, 1, 1)
        else:
            pos_embed_shape = None
        self.pos_embed = None if pos_embed_shape is None else nn.Parameter(torch.zeros(*pos_embed_shape))

        # construct an encoder with num_encoder_layers
        # NOTE(review): the `encoder_layers` argument is not honoured here;
        # kept hard-coded so existing checkpoints keep loading.
        num_encoder_layers = 1
        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.encoder = self._pointwise_mlp(self.in_chans, encoder_hidden_dim, self.embed_dim, num_encoder_layers, self.activation_function)

        # prepare the spectral transform
        if self.spectral_transform == "sht":

            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
            modes_lat = modes_lon = min(modes_lat, modes_lon)

            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()

        elif self.spectral_transform == "fft":

            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

            self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()

        else:
            raise ValueError("Unknown spectral transform")

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            first_layer = i == 0
            last_layer = i == self.num_layers - 1

            # the first block reads the full-resolution input, the last block
            # writes back to full resolution
            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans

            inner_skip = "none"
            outer_skip = "identity"

            # The norm layers act on the block output, whose resolution is set
            # by the inverse transform. Choosing by `last_layer` alone also
            # covers num_layers == 1, where the block is both first and last.
            norm_layer = norm_layer0 if last_layer else norm_layer1

            block = SphericalFourierNeuralOperatorBlock(
                forward_transform,
                inverse_transform,
                self.embed_dim,
                self.embed_dim,
                operator_type=self.operator_type,
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=norm_layer,
                inner_skip=inner_skip,
                outer_skip=outer_skip,
                use_mlp=use_mlp,
                factorization=self.factorization,
                separable=self.separable,
                rank=self.rank,
            )

            self.blocks.append(block)

        # construct an decoder with num_decoder_layers
        num_decoder_layers = 1
        decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        # with the big skip, forward() concatenates the raw input channels
        decoder_in_dim = self.embed_dim + (self.in_chans if self.big_skip else 0)
        self.decoder = self._pointwise_mlp(decoder_in_dim, decoder_hidden_dim, self.out_chans, num_decoder_layers, self.activation_function)

    @staticmethod
    def _pointwise_mlp(in_dim, hidden_dim, out_dim, num_layers, act_layer):
        """Build the 1x1-convolution stack shared by the encoder and decoder.

        ``num_layers - 1`` biased hidden convolutions (each followed by
        ``act_layer()``) precede one bias-free output convolution. Weights are
        drawn from a normal distribution with std ``sqrt(2 / fan_in)`` for
        hidden layers and ``sqrt(1 / fan_in)`` for the output layer.
        """
        layers = []
        current_dim = in_dim
        for _ in range(num_layers - 1):
            fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
            nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(2.0 / current_dim))
            nn.init.constant_(fc.bias, 0.0)
            layers.append(fc)
            layers.append(act_layer())
            current_dim = hidden_dim
        fc = nn.Conv2d(current_dim, out_dim, 1, bias=False)
        nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(1.0 / current_dim))
        layers.append(fc)
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay."""
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        """Apply dropout and then every SFNO block in sequence."""

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        """Encode ``x``, run the SFNO blocks and decode the result."""

        if self.big_skip:
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        x = self.forward_features(x)

        if self.big_skip:
            # concatenate the raw input along the channel dimension
            x = torch.cat((x, residual), dim=1)

        x = self.decoder(x)

        return x