lsno.py 17.9 KB
Newer Older
1
2
# coding=utf-8

Boris Bonev's avatar
Boris Bonev committed
3
# SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
33
import math

34
35
36
37
38
39
import torch
import torch.nn as nn
import torch.amp as amp

from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
40
from torch_harmonics import ResampleS2
41

Boris Bonev's avatar
Boris Bonev committed
42
from torch_harmonics.examples.models._layers import MLP, SpectralConvS2, SequencePositionEmbedding, SpectralPositionEmbedding, LearnablePositionEmbedding
43
44
45

from functools import partial

Boris Bonev's avatar
Boris Bonev committed
46
47
48
49
50
# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
    theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}

    return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
51
52
53
54

class DiscreteContinuousEncoder(nn.Module):
    """Spherical encoder based on a single discrete-continuous (DISCO) convolution.

    Lifts ``inp_chans`` channels to ``out_chans`` while resampling from
    ``in_shape`` on ``grid_in`` to ``out_shape`` on ``grid_out``. The cutoff
    radius of the convolution is derived from the input latitude resolution.
    """

    def __init__(
        self,
        in_shape=(721, 1440),
        out_shape=(480, 960),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
        kernel_shape=(3, 3),
        basis_type="morlet",
        groups=1,
        bias=False,
    ):
        super().__init__()

        # one DISCO convolution handles both the channel lift and the spatial
        # downsampling in a single operation
        cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
            in_shape=in_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid_in,
            grid_out=grid_out,
            groups=groups,
            bias=bias,
            theta_cutoff=cutoff,
        )

    def forward(self, x):
        """Apply the encoder convolution in fp32 and restore the input dtype."""
        input_dtype = x.dtype

        # the convolution is always evaluated in fp32 with autocast disabled;
        # the caller's dtype is restored afterwards
        with amp.autocast(device_type="cuda", enabled=False):
            out = self.conv(x.float())
            out = out.to(dtype=input_dtype)

        return out
class DiscreteContinuousDecoder(nn.Module):
    """Spherical decoder: upsampling followed by a DISCO convolution.

    Upsamples from ``in_shape`` on ``grid_in`` to ``out_shape`` on ``grid_out``
    (either via an SHT round-trip or via spatial resampling) and then projects
    ``inp_chans`` channels to ``out_chans`` with a discrete-continuous
    convolution evaluated at the output resolution.

    Parameters
    ----------
    in_shape : tuple, optional
        Input (nlat, nlon) resolution, by default (480, 960)
    out_shape : tuple, optional
        Output (nlat, nlon) resolution, by default (721, 1440)
    grid_in, grid_out : str, optional
        Latitude grids of input and output, by default "equiangular"
    inp_chans, out_chans : int, optional
        Channel counts, by default 2
    kernel_shape : tuple, optional
        DISCO kernel shape, by default (3, 3)
    basis_type : str, optional
        DISCO filter basis, by default "morlet"
    groups : int, optional
        Convolution groups, by default 1
    bias : bool, optional
        Whether the convolution uses a bias, by default False
    upsample_sht : bool, optional
        Use SHT-based upsampling if True, else ResampleS2, by default False
    """

    def __init__(
        self,
        in_shape=(480, 960),
        out_shape=(721, 1440),
        grid_in="equiangular",
        grid_out="equiangular",
        inp_chans=2,
        out_chans=2,
        kernel_shape=(3, 3),
        basis_type="morlet",
        groups=1,
        bias=False,
        upsample_sht=False,
    ):
        super().__init__()

        # set up upsampling: spectral round-trip (band-limited) or direct resampling
        if upsample_sht:
            self.sht = RealSHT(*in_shape, grid=grid_in).float()
            self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
            self.upsample = nn.Sequential(self.sht, self.isht)
        else:
            self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)

        # set up DISCO convolution, evaluated entirely at the output resolution.
        # BUGFIX: `bias` was previously hardcoded to False here, silently
        # ignoring the constructor parameter; it is now forwarded.
        # NOTE(review): theta_cutoff is derived from in_shape[0] although the
        # convolution runs at out_shape resolution — this keeps the receptive
        # field tied to the coarse grid; confirm this is intended.
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
            in_shape=out_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid_out,
            grid_out=grid_out,
            groups=groups,
            bias=bias,
            theta_cutoff=_compute_cutoff_radius(in_shape[0], kernel_shape, basis_type),
        )

    def forward(self, x):
        """Upsample and convolve in fp32, then restore the input dtype."""
        dtype = x.dtype

        # numerically sensitive ops run in fp32 with autocast disabled
        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
            x = self.upsample(x)
            x = self.conv(x)
            x = x.to(dtype=dtype)

        return x
class SphericalNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    Parameters
    ----------
    forward_transform : nn.Module
        Forward transform; its nlat/nlon/grid attributes define the input grid.
    inverse_transform : nn.Module
        Inverse transform; its nlat/nlon/grid attributes define the output grid.
    input_dim, output_dim : int
        Channel counts in and out of the block.
    conv_type : str, optional
        "local" (DISCO convolution) or "global" (spectral convolution), by default "local"
    mlp_ratio : float, optional
        Hidden-dimension multiplier of the MLP, by default 2.0
    drop_rate : float, optional
        Dropout rate inside the MLP, by default 0.0
    drop_path : float, optional
        Stochastic-depth rate, by default 0.0
    act_layer : type, optional
        Activation class, by default nn.GELU
    norm_layer : str, optional
        "layer_norm", "instance_norm" or "none", by default "none"
    inner_skip, outer_skip : str, optional
        Skip-connection types: "linear", "identity" or "none".
    use_mlp : bool, optional
        Whether to append an MLP, by default True
    disco_kernel_shape : tuple, optional
        Kernel shape of the local convolution, by default (3, 3)
    disco_basis_type : str, optional
        Filter basis of the local convolution, by default "morlet"
    bias : bool, optional
        Whether convolutions use a bias, by default False
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        conv_type="local",
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="none",
        inner_skip="none",
        outer_skip="identity",
        use_mlp=True,
        disco_kernel_shape=(3, 3),
        disco_basis_type="morlet",
        bias=False,
    ):
        super().__init__()

        # initialization gain for the convolution: 2.0 compensates the variance
        # reduction of a non-trivial activation; halved when an inner skip adds
        # a second signal path
        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        if inner_skip == "linear" or inner_skip == "identity":
            gain_factor /= 2.0

        # convolution layer
        if conv_type == "local":
            # use twice the encoder/decoder heuristic radius for the block convolutions
            theta_cutoff = 2.0 * _compute_cutoff_radius(forward_transform.nlat, disco_kernel_shape, disco_basis_type)
            self.local_conv = DiscreteContinuousConvS2(
                input_dim,
                output_dim,
                in_shape=(forward_transform.nlat, forward_transform.nlon),
                out_shape=(inverse_transform.nlat, inverse_transform.nlon),
                kernel_shape=disco_kernel_shape,
                basis_type=disco_basis_type,
                grid_in=forward_transform.grid,
                grid_out=inverse_transform.grid,
                bias=bias,
                theta_cutoff=theta_cutoff,
            )
        elif conv_type == "global":
            self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, bias=bias)
        else:
            raise ValueError(f"Unknown convolution type {conv_type}")

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        # normalisation layer
        if norm_layer == "layer_norm":
            self.norm = nn.LayerNorm(normalized_shape=(inverse_transform.nlat, inverse_transform.nlon), eps=1e-6)
        elif norm_layer == "instance_norm":
            self.norm = nn.InstanceNorm2d(num_features=output_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif norm_layer == "none":
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Error, normalization {norm_layer} not implemented.")

        # stochastic depth (requires DropPath to be imported from the examples layers)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # fresh initialization gain for the MLP, halved when the outer skip adds
        # a second signal path.
        # BUGFIX: this previously tested `inner_skip == "identity"` (copy-paste
        # from the inner-skip gain above), so the default configuration
        # (outer_skip="identity") never halved the MLP gain.
        gain_factor = 1.0
        if outer_skip == "linear" or outer_skip == "identity":
            gain_factor /= 2.0

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(
                in_features=output_dim,
                out_features=input_dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop_rate=drop_rate,
                checkpointing=False,
                gain=gain_factor,
            )

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        elif outer_skip == "none":
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

    def forward(self, x):
        """Apply convolution, norm, inner skip, MLP, drop path and outer skip."""
        residual = x

        # exactly one of the two convolutions exists, depending on conv_type
        if hasattr(self, "global_conv"):
            x, _ = self.global_conv(x)
        elif hasattr(self, "local_conv"):
            x = self.local_conv(x)

        x = self.norm(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x
class LocalSphericalNeuralOperator(nn.Module):
    """
    LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
    operators to accurately model both types of solution operators [1]. The architecture is based on the Spherical
    Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
    as well as in the encoder and decoders.

    Parameters
    -----------
    img_size : tuple, optional
        Shape (nlat, nlon) of the input/output fields, by default (128, 256)
    grid : str, optional
        Latitude grid of the input/output data, by default "equiangular"
    grid_internal : str, optional
        Latitude grid used on the internal (downsampled) representation, by default "legendre-gauss"
    scale_factor : int, optional
        Downsampling factor between the input grid and the internal grid, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of operator blocks in the network, by default 4
    activation_function : str, optional
        Activation function to use ("relu", "gelu", "identity"), by default "gelu"
    kernel_shape : tuple, optional
        Shape of the DISCO kernels used inside the operator blocks, by default (3, 3)
    encoder_kernel_shape : tuple, optional
        Shape of the DISCO kernels used in encoder and decoder, by default (3, 3)
    filter_basis_type : str, optional
        Filter basis type for the DISCO convolutions, by default "morlet"
    use_mlp : bool, optional
        Whether to use MLPs in the operator blocks, by default True
    mlp_ratio : float, optional
        Ratio of MLP hidden dimension to embedding dimension, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Maximum stochastic-depth rate; per-block rates increase linearly up to this value, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    sfno_block_frequency : int, optional
        How often a (global) SFNO block is used, by default 2
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    residual_prediction : bool, optional
        Whether to add the input back onto the output (big skip connection), by default False
    pos_embed : str, optional
        Type of position embedding: "sequence", "spectral", "learnable lat",
        "learnable latlon" or "none", by default "none"
    upsample_sht : bool, optional
        Use SHT upsampling in the decoder if true, else spatial resampling
    bias : bool, optional
        Whether to use a bias in the operator blocks, by default False

    Example
    -----------
    >>> model = LocalSphericalNeuralOperator(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])

    References
    -----------
    .. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
        "Neural Operators with Localized Integral and Differential Kernels" (2024).
        ICML 2024, https://arxiv.org/pdf/2402.16845.

    .. [2] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
        "Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
        ICML 2023, https://arxiv.org/abs/2306.03838.
    """

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        scale_factor=3,
        in_chans=3,
        out_chans=3,
        embed_dim=256,
        num_layers=4,
        activation_function="gelu",
        kernel_shape=(3, 3),
        encoder_kernel_shape=(3, 3),
        filter_basis_type="morlet",
        use_mlp=True,
        mlp_ratio=2.0,
        drop_rate=0.0,
        drop_path_rate=0.0,
        normalization_layer="none",
        sfno_block_frequency=2,
        hard_thresholding_fraction=1.0,
        residual_prediction=False,
        pos_embed="none",
        upsample_sht=False,
        bias=False,
    ):
        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.encoder_kernel_shape = encoder_kernel_shape
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.residual_prediction = residual_prediction

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size. We assume that the latitude-grid includes both poles
        self.h = (self.img_size[0] - 1) // scale_factor + 1
        self.w = self.img_size[1] // scale_factor

        # dropout; per-block stochastic-depth rates grow linearly up to drop_path_rate
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # position embedding, applied on the internal grid after the encoder
        if pos_embed == "sequence":
            self.pos_embed = SequencePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "spectral":
            self.pos_embed = SpectralPositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal)
        elif pos_embed == "learnable lat":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="lat")
        elif pos_embed == "learnable latlon":
            self.pos_embed = LearnablePositionEmbedding((self.h, self.w), num_chans=self.embed_dim, grid=grid_internal, embed_type="latlon")
        elif pos_embed == "none":
            self.pos_embed = nn.Identity()
        else:
            raise ValueError(f"Unknown position embedding type {pos_embed}")

        # encoder: lift in_chans -> embed_dim and downsample onto the internal grid
        self.encoder = DiscreteContinuousEncoder(
            in_shape=self.img_size,
            out_shape=(self.h, self.w),
            grid_in=grid,
            grid_out=grid_internal,
            inp_chans=self.in_chans,
            out_chans=self.embed_dim,
            kernel_shape=self.encoder_kernel_shape,
            basis_type=filter_basis_type,
            groups=1,
            bias=False,
        )

        # compute the modes for the sht
        modes_lat = self.h
        # due to some spectral artifacts with cufft, we substract one mode here
        modes_lon = (self.w // 2 + 1) - 1

        modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)

        # shared SHT pair used by all operator blocks
        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            # every sfno_block_frequency-th block is global (spectral), the rest are local (DISCO)
            block = SphericalNeuralOperatorBlock(
                self.trans,
                self.itrans,
                self.embed_dim,
                self.embed_dim,
                conv_type="global" if i % sfno_block_frequency == (sfno_block_frequency-1) else "local",
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                drop_path=dpr[i],
                act_layer=self.activation_function,
                norm_layer=self.normalization_layer,
                use_mlp=use_mlp,
                disco_kernel_shape=kernel_shape,
                disco_basis_type=filter_basis_type,
                bias=bias,
            )

            self.blocks.append(block)

        # decoder: upsample back to the input grid and project embed_dim -> out_chans
        self.decoder = DiscreteContinuousDecoder(
            in_shape=(self.h, self.w),
            out_shape=self.img_size,
            grid_in=grid_internal,
            grid_out=grid,
            inp_chans=self.embed_dim,
            out_chans=self.out_chans,
            kernel_shape=self.encoder_kernel_shape,
            basis_type=filter_basis_type,
            groups=1,
            bias=False,
            upsample_sht=upsample_sht,
        )

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by compatible optim setups
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):
        # apply dropout, then run all operator blocks on the internal grid
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        if self.residual_prediction:
            residual = x

        x = self.encoder(x)

        # NOTE(review): self.pos_embed is always a module (nn.Identity for
        # "none"), so this check is always True; kept as-is.
        if self.pos_embed is not None:
            x = self.pos_embed(x)

        x = self.forward_features(x)

        x = self.decoder(x)

        if self.residual_prediction:
            x = x + residual

        return x