# s2segformer.py (last committed by Boris Bonev)
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch
import torch.nn as nn
import torch.amp as amp

from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import AttentionS2, NeighborhoodAttentionS2
from torch_harmonics import ResampleS2
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.quadrature import _precompute_latitudes

from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath

from functools import partial


# heuristic for finding theta_cutoff
def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
    theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)}

    return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)


class OverlapPatchMerging(nn.Module):
    """
    Overlap patch merging module for spherical segformer.

    This module performs patch merging with overlapping patches using a discrete-continuous
    (DISCO) convolution on the sphere that maps the input grid to the output grid, followed
    by layer normalization over the channel dimension.

    Parameters
    -----------
    in_shape : tuple, optional
        Input shape (nlat, nlon), by default (721, 1440)
    out_shape : tuple, optional
        Output shape (nlat, nlon), by default (481, 960)
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    in_channels : int, optional
        Number of input channels, by default 3
    out_channels : int, optional
        Number of output channels, by default 64
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    bias : bool, optional
        Whether the convolution uses a bias, by default False
    """

    def __init__(
        self,
        in_shape=(721, 1440),
        out_shape=(481, 960),
        grid_in="equiangular",
        grid_out="equiangular",
        in_channels=3,
        out_channels=64,
        kernel_shape=(3, 3),
        basis_type="morlet",
        bias=False,
    ):
        super().__init__()

        # convolution for patches; the cutoff radius is inferred from the kernel shape
        # and the latitude resolution of the output grid
        theta_cutoff = _compute_cutoff_radius(out_shape[0], kernel_shape, basis_type)
        self.conv = DiscreteContinuousConvS2(
            in_channels,
            out_channels,
            in_shape=in_shape,
            out_shape=out_shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=bias,
            theta_cutoff=theta_cutoff,
        )

        # layer norm over the channel dimension (applied channels-last in forward)
        self.norm = nn.LayerNorm(out_channels, eps=1e-05, elementwise_affine=True, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Reset LayerNorm submodules to an identity affine transform (weight 1, bias 0)."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Apply the patch-merging convolution followed by channel-wise layer norm.

        Parameters
        -----------
        x : torch.Tensor
            4D tensor with channels in dimension 1; presumably (batch, in_channels, *in_shape)

        Returns
        -------
        torch.Tensor
            4D tensor with ``out_channels`` channels in dimension 1
        """
        dtype = x.dtype

        # run the DISCO convolution in float32 with CUDA autocast disabled,
        # then cast the result back to the input dtype
        # NOTE(review): only the "cuda" autocast context is disabled here
        with amp.autocast(device_type="cuda", enabled=False):
            x = x.float()
            x = self.conv(x).to(dtype=dtype)

        # LayerNorm normalizes the last dimension, so move channels last and back
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        out = x.permute(0, 3, 1, 2)

        return out


class MixFFN(nn.Module):
    """
    Mix FFN module for spherical segformer.

    This module implements a pre-norm residual feed-forward block that combines pointwise
    projections with a depthwise discrete-continuous convolution on the sphere:
    ``x + drop_path(mlp_out(act(conv(mlp_in(norm(x))))))``.

    Parameters
    -----------
    shape : tuple
        Shape (nlat, nlon) of the input
    inout_channels : int
        Number of input/output channels
    hidden_channels : int
        Number of hidden channels in MLP (only used when ``use_mlp`` is True)
    mlp_bias : bool, optional
        Whether the 1x1 convolution projections use a bias (only used when ``use_mlp`` is False),
        by default True
    grid : str, optional
        Grid type, by default "equiangular"
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    conv_bias : bool, optional
        Whether to use bias in convolution, by default False
    activation : nn.Module, optional
        Activation function, by default nn.GELU
    use_mlp : bool, optional
        Whether to use MLP instead of linear layers, by default False
    drop_path : float, optional
        Drop path rate, by default 0.0
    """

    def __init__(
        self,
        shape,
        inout_channels,
        hidden_channels,
        mlp_bias=True,
        grid="equiangular",
        kernel_shape=(3, 3),
        basis_type="morlet",
        conv_bias=False,
        activation=nn.GELU,
        use_mlp=False,
        drop_path=0.0,
    ):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm = nn.LayerNorm(inout_channels, eps=1e-05, elementwise_affine=True, bias=True)

        # first pointwise projection
        self.mlp_in = self._make_projection(inout_channels, hidden_channels, mlp_bias, activation, use_mlp)

        # depthwise (groups == channels) DISCO convolution; the cutoff radius is
        # inferred from the kernel shape and the latitude resolution of the grid
        theta_cutoff = _compute_cutoff_radius(shape[0], kernel_shape, basis_type)
        self.conv = DiscreteContinuousConvS2(
            inout_channels,
            inout_channels,
            in_shape=shape,
            out_shape=shape,
            kernel_shape=kernel_shape,
            basis_type=basis_type,
            grid_in=grid,
            grid_out=grid,
            groups=inout_channels,
            bias=conv_bias,
            theta_cutoff=theta_cutoff,
        )

        # second pointwise projection
        self.mlp_out = self._make_projection(inout_channels, hidden_channels, mlp_bias, activation, use_mlp)

        self.act = activation()

        self.apply(self._init_weights)

    @staticmethod
    def _make_projection(channels, hidden_channels, bias, activation, use_mlp):
        """Build a channel-preserving pointwise projection: an MLP or a single 1x1 convolution."""
        if use_mlp:
            return MLP(channels, hidden_features=hidden_channels, out_features=channels, act_layer=activation, output_bias=False, drop_rate=0.0)
        # although the paper says MLP, it uses a single linear layer
        return nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, bias=bias)

    def _init_weights(self, m):
        """Truncated-normal init (std 0.02) for Conv2d weights, zero biases, identity LayerNorms."""
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the residual Mix FFN.

        ``x`` is expected to carry ``inout_channels`` channels in dimension 1 (the permutes
        around the LayerNorm rely on this); the output has the same shape as ``x``.
        """
        residual = x

        # pre-norm over channels: LayerNorm acts on the last dim, so go channels-last and back
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)

        # first pointwise projection
        x = self.mlp_in(x)

        # depthwise conv followed by an activation.
        # NOTE: we add another activation here because in the paper they only use the
        # depthwise conv, but without this activation the pointwise projections and the
        # disco conv would compose into a single linear operator
        x = self.act(self.conv(x))

        # second pointwise projection
        x = self.mlp_out(x)

        return residual + self.drop_path(x)


class AttentionWrapper(nn.Module):
    """
    Attention wrapper for spherical segformer.

    This module wraps attention mechanisms (neighborhood or global) in a residual
    connection with optional pre-normalization and drop path regularization:
    ``x + drop_path(att(norm(x)))``.

    Parameters
    -----------
    channels : int
        Number of channels
    shape : tuple
        Shape (nlat, nlon) of the input
    grid : str
        Grid type
    heads : int
        Number of attention heads
    pre_norm : bool, optional
        Whether to apply normalization before attention, by default False
    attention_drop_rate : float, optional
        Dropout rate for attention; only passed to the global ``AttentionS2``, not to
        ``NeighborhoodAttentionS2``, by default 0.0
    drop_path : float, optional
        Drop path rate, by default 0.0
    attention_mode : str, optional
        Attention mode ("neighborhood" or "global"); any value other than "neighborhood"
        selects global attention, by default "neighborhood"
    theta_cutoff : float, optional
        Cutoff radius for neighborhood attention; if None, ``(7 / sqrt(pi)) * pi / (nlat - 1)``
        is used, by default None
    bias : bool, optional
        Whether to use bias, by default True
    """

    def __init__(
        self,
        channels,
        shape,
        grid,
        heads,
        pre_norm=False,
        attention_drop_rate=0.0,
        drop_path=0.0,
        attention_mode="neighborhood",
        theta_cutoff=None,
        bias=True
    ):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.attention_mode = attention_mode

        if attention_mode == "neighborhood":
            if theta_cutoff is None:
                # default radius expressed in units of the latitude spacing pi / (nlat - 1)
                theta_cutoff = (7.0 / math.sqrt(math.pi)) * math.pi / (shape[0] - 1)
            # NOTE: attention_drop_rate is not forwarded to the neighborhood attention
            self.att = NeighborhoodAttentionS2(
                in_channels=channels,
                in_shape=shape,
                out_shape=shape,
                grid_in=grid,
                grid_out=grid,
                theta_cutoff=theta_cutoff,
                out_channels=channels,
                num_heads=heads,
                bias=bias,
            )
        else:
            self.att = AttentionS2(
                in_channels=channels,
                num_heads=heads,
                in_shape=shape,
                out_shape=shape,
                grid_in=grid,
                grid_out=grid,
                out_channels=channels,
                drop_rate=attention_drop_rate,
                bias=bias,
            )

        self.norm = None
        if pre_norm:
            self.norm = nn.LayerNorm(channels, eps=1e-05, elementwise_affine=True, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Reset LayerNorm submodules to an identity affine transform (weight 1, bias 0)."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply residual attention ``x + drop_path(att(norm(x)))``; norm only if ``pre_norm``.

        Neighborhood attention is evaluated in float32 with CUDA autocast disabled and the
        result is cast back to the input dtype; global attention runs in the ambient precision.
        ``x`` is expected to carry ``channels`` channels in dimension 1.
        """
        residual = x
        if self.norm is not None:
            # LayerNorm acts on the last dim, so go channels-last and back
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            x = x.permute(0, 3, 1, 2)

        if self.attention_mode == "neighborhood":
            dtype = x.dtype
            with amp.autocast(device_type="cuda", enabled=False):
                x = x.float()
                x = self.att(x).to(dtype=dtype)
        else:
            x = self.att(x)

        return residual + self.drop_path(x)


class TransformerBlock(nn.Module):
    """
    Transformer block for spherical segformer.

    The block first merges patches with an ``OverlapPatchMerging`` layer, mapping
    ``in_shape``/``in_channels`` on ``grid_in`` to ``out_shape``/``out_channels`` on
    ``grid_out``. It then applies ``nrep`` repetitions of (pre-norm ``AttentionWrapper``,
    ``MixFFN``) and finally a layer norm over the channel dimension.

    Parameters
    -----------
    in_shape : tuple
        Input shape (nlat, nlon)
    out_shape : tuple
        Output shape (nlat, nlon)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    mlp_hidden_channels : int
        Number of hidden channels in MLP
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    nrep : int, optional
        Number of (attention, Mix FFN) repetitions, by default 1
    heads : int, optional
        Number of attention heads, by default 1
    kernel_shape : tuple, optional
        Kernel shape for convolution, by default (3, 3)
    basis_type : str, optional
        Filter basis type, by default "morlet"
    activation : nn.Module, optional
        Activation function, by default nn.GELU
    att_drop_rate : float, optional
        Dropout rate for attention, by default 0.0
    drop_path_rates : float or sequence of float, optional
        Drop path rate per repetition. A scalar ``r`` is expanded to ``nrep`` rates linearly
        spaced from 0 to ``r``; a sequence must have length ``nrep``. By default 0.0
    attention_mode : str, optional
        Attention mode ("neighborhood" or "global"), by default "neighborhood"
    theta_cutoff : float, optional
        Cutoff radius for neighborhood attention, by default None
    bias : bool, optional
        Whether the attention layers use bias, by default True

    Raises
    ------
    ValueError
        If a sequence of drop path rates does not have length ``nrep``
    """

    def __init__(
        self,
        in_shape,
        out_shape,
        in_channels,
        out_channels,
        mlp_hidden_channels,
        grid_in="equiangular",
        grid_out="equiangular",
        nrep=1,
        heads=1,
        kernel_shape=(3, 3),
        basis_type="morlet",
        activation=nn.GELU,
        att_drop_rate=0.0,
        drop_path_rates=0.0,
        attention_mode="neighborhood",
        theta_cutoff=None,
        bias=True
    ):
        super().__init__()

        self.in_shape = in_shape
        self.out_shape = out_shape
        self.in_channels = in_channels
        self.out_channels = out_channels

        drop_path_rates = self._expand_drop_path_rates(drop_path_rates, nrep)

        layers = [
            OverlapPatchMerging(
                in_shape=in_shape,
                out_shape=out_shape,
                grid_in=grid_in,
                grid_out=grid_out,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_shape=kernel_shape,
                basis_type=basis_type,
                bias=False,
            )
        ]

        for rate in drop_path_rates:
            layers.append(
                AttentionWrapper(
                    channels=out_channels,
                    shape=out_shape,
                    grid=grid_out,
                    heads=heads,
                    pre_norm=True,
                    attention_drop_rate=att_drop_rate,
                    drop_path=rate,
                    attention_mode=attention_mode,
                    theta_cutoff=theta_cutoff,
                    bias=bias,
                )
            )

            layers.append(
                MixFFN(
                    out_shape,
                    inout_channels=out_channels,
                    hidden_channels=mlp_hidden_channels,
                    mlp_bias=True,
                    grid=grid_out,
                    kernel_shape=kernel_shape,
                    basis_type=basis_type,
                    conv_bias=False,
                    activation=activation,
                    use_mlp=False,
                    drop_path=rate,
                )
            )

        # sequential container (attribute name kept for state_dict compatibility)
        self.fwd = nn.Sequential(*layers)

        # final norm over the channel dimension
        self.norm = nn.LayerNorm(out_channels, eps=1e-05, elementwise_affine=True, bias=True)

        self.apply(self._init_weights)

    @staticmethod
    def _expand_drop_path_rates(drop_path_rates, nrep):
        """Return a list of ``nrep`` drop path rates from a scalar (linear ramp from 0) or a sequence."""
        if isinstance(drop_path_rates, (int, float)):
            drop_path_rates = [x.item() for x in torch.linspace(0, drop_path_rates, nrep)]
        drop_path_rates = list(drop_path_rates)
        if len(drop_path_rates) != nrep:
            raise ValueError(f"Expected {nrep} drop path rates, got {len(drop_path_rates)}")
        return drop_path_rates

    def _init_weights(self, m):
        """Reset LayerNorm submodules to an identity affine transform (weight 1, bias 0)."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply patch merging, the attention / Mix FFN stack and the final channel layer norm."""
        x = self.fwd(x)

        # LayerNorm acts on the last dim, so go channels-last and back
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)

        return x


class Upsampling(nn.Module):
    """
    Upsampling module for spherical segformer.

    This module applies a pointwise projection followed by upsampling, using either a
    discrete-continuous transposed convolution or bilinear resampling on the sphere.

    Parameters
    -----------
    in_shape : tuple
        Input shape (nlat, nlon)
    out_shape : tuple
        Output shape (nlat, nlon)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    hidden_channels : int
        Number of hidden channels in MLP (only used when ``use_mlp`` is True)
    mlp_bias : bool, optional
        Whether the 1x1 convolution projection uses a bias (only used when ``use_mlp`` is False),
        by default True
    grid_in : str, optional
        Input grid type, by default "equiangular"
    grid_out : str, optional
        Output grid type, by default "equiangular"
    kernel_shape : tuple, optional
        Kernel shape for convolution (only used with ``upsampling_method="conv"``), by default (3, 3)
    basis_type : str, optional
        Filter basis type (only used with ``upsampling_method="conv"``), by default "morlet"
    conv_bias : bool, optional
        Whether to use bias in convolution (only used with ``upsampling_method="conv"``), by default False
    activation : nn.Module, optional
        Activation function (only used when ``use_mlp`` is True), by default nn.GELU
    use_mlp : bool, optional
        Whether to use MLP instead of a single linear layer, by default False
    upsampling_method : str, optional
        Upsampling method ("conv" or "bilinear"), by default "conv"

    Raises
    ------
    ValueError
        If ``upsampling_method`` is neither "conv" nor "bilinear"
    """

    def __init__(
        self,
        in_shape,
        out_shape,
        in_channels,
        out_channels,
        hidden_channels,
        mlp_bias=True,
        grid_in="equiangular",
        grid_out="equiangular",
        kernel_shape=(3, 3),
        basis_type="morlet",
        conv_bias=False,
        activation=nn.GELU,
        use_mlp=False,
        upsampling_method="conv"
    ):
        super().__init__()

        # pointwise channel projection applied before upsampling
        if use_mlp:
            self.mlp = MLP(in_channels, hidden_features=hidden_channels, out_features=out_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
        else:
            self.mlp = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=mlp_bias)

        if upsampling_method == "conv":
            # the cutoff radius is inferred from the kernel shape and the input latitude resolution
            theta_cutoff = _compute_cutoff_radius(in_shape[0], kernel_shape, basis_type)
            self.upsample = DiscreteContinuousConvTransposeS2(
                out_channels,
                out_channels,
                in_shape=in_shape,
                out_shape=out_shape,
                kernel_shape=kernel_shape,
                basis_type=basis_type,
                grid_in=grid_in,
                grid_out=grid_out,
                bias=conv_bias,
                theta_cutoff=theta_cutoff,
            )
        elif upsampling_method == "bilinear":
            self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
        else:
            raise ValueError(f"Unknown upsampling method {upsampling_method}")

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init (std 0.02) for Conv2d weights, zero biases, identity LayerNorms."""
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project channels with ``self.mlp``, then resample to the output grid."""
        x = self.upsample(self.mlp(x))

        return x


class SphericalSegformer(nn.Module):
    """
    Spherical segformer model designed to approximate mappings from spherical signals to spherical segmentation masks.

    The encoder is a stack of ``TransformerBlock`` modules. Block ``i`` downsamples its input by
    ``scale_factor`` (integer division) along each spatial dimension and produces ``embed_dims[i]``
    channels; the first block reads from ``grid`` and all blocks output on ``grid_internal``.
    The output of every block is mapped back to ``img_size`` on ``grid`` by an ``Upsampling``
    module. The results are concatenated along the channel dimension and projected to
    ``out_chans`` channels by a 1x1 convolution.

    Parameters
    -----------
    img_size : tuple, optional
        Spatial shape (nlat, nlon) of the input and output, by default (128, 256)
    grid : str, optional
        Grid of the input and output, by default "equiangular"
    grid_internal : str, optional
        Grid used for the encoder outputs, by default "legendre-gauss"
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels (classes), by default 3
    embed_dims : List[int], optional
        Number of channels of each encoder block, by default [64, 128, 256, 512]
    heads : List[int], optional
        Number of attention heads of each encoder block; must have the same length as embed_dims,
        by default [1, 2, 4, 8]
    depths : List[int], optional
        Number of repetitions of attention blocks and FFN mixers per encoder block; must have the
        same length as embed_dims, by default [3, 4, 6, 3]
    scale_factor : int, optional
        Downsampling factor of each encoder block along each spatial dimension, by default 2
    activation_function : str, optional
        Activation function ("relu", "gelu" or "identity"), by default "gelu"
    kernel_shape : tuple, optional
        Kernel shape of the DISCO convolutions, by default (3, 3)
    filter_basis_type : str, optional
        Filter basis type of the DISCO convolutions, by default "morlet"
    mlp_ratio : float, optional
        Ratio of hidden to embedding channels in the MLPs, by default 2.0
    att_drop_rate : float, optional
        Dropout rate for global attention (not used by neighborhood attention), by default 0.0
    drop_path_rate : float, optional
        Maximum drop path rate; rates increase linearly from 0 over all ``sum(depths)``
        repetitions, by default 0.1
    attention_mode : str, optional
        Attention mode ("neighborhood" or "global"), by default "neighborhood"
    theta_cutoff : float, optional
        Cutoff radius for neighborhood attention; None uses the per-block default, by default None
    upsampling_method : str, optional
        Upsampling method of the decoder ("conv" or "bilinear"), by default "bilinear"
    bias : bool, optional
        Whether the attention layers use bias, by default True

    Raises
    ------
    ValueError
        If ``heads`` or ``depths`` differ in length from ``embed_dims``, or if
        ``activation_function`` is unknown

    Example
    -----------
    >>> model = SphericalSegformer(
    ...         img_size=(128, 256),
    ...         in_chans=2,
    ...         out_chans=2,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
    """

    # supported values of ``activation_function``; "identity" is for debugging purposes
    _ACTIVATIONS = {"relu": nn.ReLU, "gelu": nn.GELU, "identity": nn.Identity}

    def __init__(
        self,
        img_size=(128, 256),
        grid="equiangular",
        grid_internal="legendre-gauss",
        in_chans=3,
        out_chans=3,
        embed_dims=[64, 128, 256, 512],
        heads=[1, 2, 4, 8],
        depths=[3, 4, 6, 3],
        scale_factor=2,
        activation_function="gelu",
        kernel_shape=(3, 3),
        filter_basis_type="morlet",
        mlp_ratio=2.0,
        att_drop_rate=0.0,
        drop_path_rate=0.1,
        attention_mode="neighborhood",
        theta_cutoff=None,
        upsampling_method="bilinear",
        bias=True
    ):
        super().__init__()

        self.img_size = img_size
        self.grid = grid
        self.grid_internal = grid_internal
        self.in_chans = in_chans
        self.out_chans = out_chans
        # copy the list arguments so instances never alias the shared default lists
        self.embed_dims = list(embed_dims)
        self.heads = list(heads)
        self.num_blocks = len(self.embed_dims)
        self.depths = list(depths)
        self.kernel_shape = kernel_shape

        if len(self.heads) != self.num_blocks or len(self.depths) != self.num_blocks:
            raise ValueError(
                f"embed_dims, heads and depths must have the same length, "
                f"got {self.num_blocks}, {len(self.heads)} and {len(self.depths)}"
            )

        # activation function
        if activation_function not in self._ACTIVATIONS:
            raise ValueError(f"Unknown activation function {activation_function}")
        self.activation_function = self._ACTIVATIONS[activation_function]

        # drop path rates increase linearly over all repetitions of all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]

        # encoder
        self.blocks = nn.ModuleList([])
        in_shape = img_size
        grid_in = grid
        in_channels = in_chans
        cur = 0
        for depth, num_heads, out_channels in zip(self.depths, self.heads, self.embed_dims):
            out_shape = (in_shape[0] // scale_factor, in_shape[1] // scale_factor)
            self.blocks.append(
                TransformerBlock(
                    in_shape=in_shape,
                    out_shape=out_shape,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    mlp_hidden_channels=int(mlp_ratio * out_channels),
                    grid_in=grid_in,
                    grid_out=grid_internal,
                    nrep=depth,
                    heads=num_heads,
                    kernel_shape=kernel_shape,
                    basis_type=filter_basis_type,
                    activation=self.activation_function,
                    att_drop_rate=att_drop_rate,
                    drop_path_rates=dpr[cur : cur + depth],
                    attention_mode=attention_mode,
                    theta_cutoff=theta_cutoff,
                    bias=bias,
                )
            )
            cur += depth
            in_shape = out_shape
            grid_in = grid_internal
            in_channels = out_channels

        # decoder: map every encoder output back to img_size on the input grid
        self.upsamplers = nn.ModuleList([])
        for block, dim in zip(self.blocks, self.embed_dims):
            self.upsamplers.append(
                Upsampling(
                    in_shape=block.out_shape,
                    out_shape=img_size,
                    in_channels=dim,
                    out_channels=dim,
                    hidden_channels=int(mlp_ratio * dim),
                    mlp_bias=True,
                    grid_in=grid_internal,
                    grid_out=grid,
                    kernel_shape=kernel_shape,
                    basis_type=filter_basis_type,
                    conv_bias=False,
                    activation=self.activation_function,
                    upsampling_method=upsampling_method,
                )
            )

        # 1x1 prediction head over the concatenated decoder features
        segmentation_head_dim = sum(self.embed_dims)
        self.segmentation_head = nn.Conv2d(in_channels=segmentation_head_dim, out_channels=out_chans, kernel_size=1, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init (std 0.02) for Conv2d weights, zero biases, identity LayerNorms."""
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Predict per-pixel outputs.

        ``x`` is expected to carry ``in_chans`` channels in dimension 1; presumably
        (batch, in_chans, *img_size). The result carries ``out_chans`` channels in dimension 1.
        """
        # encoder: keep the output of every block
        features = []
        feat = x
        for block in self.blocks:
            feat = block(feat)
            features.append(feat)

        # decoder: bring every feature map back to the output resolution
        upfeats = [upsampler(feat) for feat, upsampler in zip(features, self.upsamplers)]

        # concatenate along the channel dimension
        upfeats = torch.cat(upfeats, dim=1)

        # final prediction
        out = self.segmentation_head(upfeats)

        return out