segformer.py 19 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch
import torch.nn as nn

from natten import NeighborhoodAttention2D as NeighborhoodAttention
from torch_harmonics.examples.models._layers import MLP, LayerNorm, DropPath

from functools import partial


class OverlapPatchMerging(nn.Module):
    def __init__(
        self,
        in_shape=(721, 1440),
        out_shape=(481, 960),
        in_channels=3,
        out_channels=64,
        kernel_shape=(3, 3),
        bias=False,
    ):
        super().__init__()

        # conv
        stride_h = in_shape[0] // out_shape[0]
        stride_w = in_shape[1] // out_shape[1]
        pad_h = math.ceil(((out_shape[0] - 1) * stride_h - in_shape[0] + kernel_shape[0]) / 2)
        pad_w = math.ceil(((out_shape[1] - 1) * stride_w - in_shape[1] + kernel_shape[1]) / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_shape,
            bias=bias,
            stride=(stride_h, stride_w),
            padding=(pad_h, pad_w),
        )

        # layer norm
        self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.conv(x)

        # permute
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        out = x.permute(0, 3, 1, 2)

        return out


class MixFFN(nn.Module):
    def __init__(
        self,
        shape,
        inout_channels,
        hidden_channels,
        mlp_bias=True,
        kernel_shape=(3, 3),
        conv_bias=False,
        activation=nn.GELU,
        use_mlp=False,
        drop_path=0.0,
    ):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm = nn.LayerNorm((inout_channels), eps=1e-05, elementwise_affine=True, bias=True)

        if use_mlp:
            # although the paper says MLP, it uses a single linear layer
            self.mlp_in = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
        else:
            self.mlp_in = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)

        self.conv = nn.Conv2d(inout_channels, inout_channels, kernel_size=kernel_shape, groups=inout_channels, bias=conv_bias, padding="same")

        if use_mlp:
            self.mlp_out = MLP(inout_channels, hidden_features=hidden_channels, out_features=inout_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
        else:
            self.mlp_out = nn.Conv2d(in_channels=inout_channels, out_channels=inout_channels, kernel_size=1, bias=True)

        self.act = activation()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        residual = x

        # norm
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)

        # NOTE: we add another activation here
        # because in the paper they only use depthwise conv,
        # but without this activation it would just be a fused MM
        # with the disco conv
        x = self.mlp_in(x)

        # conv parth
        x = self.act(self.conv(x))

        # second linear
        x = self.mlp_out(x)

        return residual + self.drop_path(x)


class GlobalAttention(nn.Module):
    """
    Global (all-to-all) multi-head self-attention over the spatial positions of a 2D feature map.

    Input shape: (B, H, W, C), channels-last; the tensor may be non-contiguous (e.g. a permuted view).
    Output shape: (B, H, W, C). No residual connection is added here; that is left to the caller.

    Parameters
    -----------
    chans : int
        Embedding dimension C
    num_heads : int, optional
        Number of attention heads, by default 8
    dropout : float, optional
        Dropout on the attention weights, by default 0.0
    bias : bool, optional
        Whether the input/output projections use a bias, by default True
    """

    def __init__(self, chans, num_heads=8, dropout=0.0, bias=True):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=chans, num_heads=num_heads, dropout=dropout, batch_first=True, bias=bias)

    def forward(self, x):
        B, H, W, C = x.shape
        # flatten spatial dims into a token axis; reshape (not view) because callers
        # typically pass a permuted, non-contiguous tensor
        tokens = x.reshape(B, H * W, C)
        # the attention weights are discarded, so don't ask for them
        out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        # restore the spatial layout
        return out.reshape(B, H, W, C)


class AttentionWrapper(nn.Module):
    """
    Residual attention block on channel-first feature maps.

    Permutes (B, C, H, W) to channels-last, optionally applies a LayerNorm, runs either NATTEN
    neighborhood attention or ``GlobalAttention``, permutes back and adds the (drop-path) result
    to the input.

    Parameters
    -----------
    channels : int
        Number of channels C
    shape : tuple
        Spatial shape of the input; not used by this block
    heads : int
        Number of attention heads
    pre_norm : bool, optional
        Apply a LayerNorm before attention, by default False
    attention_drop_rate : float, optional
        Dropout on the attention weights, by default 0.0
    drop_path : float, optional
        Drop path rate applied to the residual branch, by default 0.0
    attention_mode : str, optional
        "neighborhood" or "global", by default "neighborhood"
    kernel_shape : tuple, optional
        Neighborhood attention kernel shape (ignored in global mode), by default (7, 7)
    bias : bool, optional
        Whether the attention projections use a bias, by default True

    Raises
    ------
    ValueError
        If ``attention_mode`` is not recognized.
    """

    def __init__(self, channels, shape, heads, pre_norm=False, attention_drop_rate=0.0, drop_path=0.0, attention_mode="neighborhood", kernel_shape=(7, 7), bias=True):
        super().__init__()

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.attention_mode = attention_mode

        if attention_mode == "global":
            attention = GlobalAttention(channels, num_heads=heads, dropout=attention_drop_rate, bias=bias)
        elif attention_mode == "neighborhood":
            attention = NeighborhoodAttention(
                channels,
                kernel_size=kernel_shape,
                dilation=1,
                num_heads=heads,
                qk_scale=None,
                attn_drop=attention_drop_rate,
                proj_drop=0.0,
                qkv_bias=bias,
            )
        else:
            raise ValueError(f"Unknown attention mode function {attention_mode}")
        self.att = attention

        # optional pre-attention normalization over channels
        self.norm = nn.LayerNorm((channels), eps=1e-05, elementwise_affine=True, bias=True) if pre_norm else None

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if not isinstance(m, nn.LayerNorm):
            return
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the wrapped attention modules are called on channels-last tensors
        y = x.permute(0, 2, 3, 1)
        y = y if self.norm is None else self.norm(y)
        y = self.att(y).permute(0, 3, 1, 2)
        return x + self.drop_path(y)


class TransformerBlock(nn.Module):
    """
    One encoder stage of the Segformer.

    Applies ``OverlapPatchMerging`` (``in_shape`` -> ``out_shape``, ``in_channels`` -> ``out_channels``),
    then ``nrep`` pairs of residual (attention, Mix-FFN) blocks, then a LayerNorm over channels.
    Input: (B, in_channels, *in_shape); output: (B, out_channels, *out_shape).

    Parameters
    -----------
    in_shape : tuple
        Spatial shape of the input
    out_shape : tuple
        Spatial shape of the output
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    mlp_hidden_channels : int
        Hidden width passed to the Mix-FFN blocks
    nrep : int, optional
        Number of (attention, Mix-FFN) pairs, by default 1
    heads : int, optional
        Number of attention heads, by default 1
    kernel_shape : tuple, optional
        Kernel shape of the patch merging and depthwise convolutions, by default (3, 3)
    activation : type, optional
        Activation module class used in the Mix-FFN blocks, by default nn.GELU
    att_drop_rate : float, optional
        Dropout on the attention weights, by default 0.0
    drop_path_rates : float or sequence of float, optional
        Either one maximum rate (rates then increase linearly from 0 over the ``nrep`` pairs)
        or one rate per pair, by default 0.0
    attention_mode : str, optional
        "neighborhood" or "global", by default "neighborhood"
    attn_kernel_shape : tuple, optional
        Odd neighborhood attention kernel; shrunk if it is not smaller than ``out_shape``, by default (7, 7)
    bias : bool, optional
        Whether the attention projections use a bias, by default True

    Raises
    ------
    ValueError
        If an entry of ``attn_kernel_shape`` is even.
    AssertionError
        If a sequence of ``drop_path_rates`` does not have length ``nrep``.
    """

    def __init__(
        self,
        in_shape,
        out_shape,
        in_channels,
        out_channels,
        mlp_hidden_channels,
        nrep=1,
        heads=1,
        kernel_shape=(3, 3),
        activation=nn.GELU,
        att_drop_rate=0.0,
        drop_path_rates=0.0,
        attention_mode="neighborhood",
        attn_kernel_shape=(7, 7),
        bias=True
    ):
        super().__init__()

        # the attention kernel must be odd in both dimensions
        if attn_kernel_shape[0] % 2 == 0:
            raise ValueError(f"Attn Kernel shape {attn_kernel_shape} is even, use odd kernel shape")
        if attn_kernel_shape[1] % 2 == 0:
            raise ValueError(f"Attn Kernel shape {attn_kernel_shape} is even, use odd kernel shape")

        # the attention runs at the stage output resolution, so fit the kernel to out_shape
        attn_kernel_shape = self._fit_attn_kernel_shape(attn_kernel_shape, out_shape)

        self.in_shape = in_shape
        self.out_shape = out_shape
        self.in_channels = in_channels
        self.out_channels = out_channels

        # a scalar rate is expanded into a linear ramp over the nrep pairs
        if isinstance(drop_path_rates, (int, float)):
            drop_path_rates = torch.linspace(0, drop_path_rates, nrep).tolist()

        assert len(drop_path_rates) == nrep

        layers = [
            OverlapPatchMerging(
                in_shape=in_shape,
                out_shape=out_shape,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_shape=kernel_shape,
                bias=False,
            )
        ]

        for i in range(nrep):
            layers.append(
                AttentionWrapper(
                    channels=out_channels,
                    shape=out_shape,
                    heads=heads,
                    pre_norm=True,
                    attention_drop_rate=att_drop_rate,
                    drop_path=drop_path_rates[i],
                    attention_mode=attention_mode,
                    kernel_shape=attn_kernel_shape,
                    bias=bias
                )
            )

            layers.append(
                MixFFN(
                    out_shape,
                    inout_channels=out_channels,
                    hidden_channels=mlp_hidden_channels,
                    mlp_bias=True,
                    kernel_shape=kernel_shape,
                    conv_bias=False,
                    activation=activation,
                    use_mlp=False,
                    drop_path=drop_path_rates[i],
                )
            )

        self.fwd = nn.Sequential(*layers)

        # final norm
        self.norm = nn.LayerNorm((out_channels), eps=1e-05, elementwise_affine=True, bias=True)

        self.apply(self._init_weights)

    @staticmethod
    def _fit_attn_kernel_shape(attn_kernel_shape, out_shape):
        """
        Shrink an odd attention kernel so each dimension is strictly smaller than ``out_shape``.

        An oversized dimension becomes the largest odd number below the matching ``out_shape``
        entry (at least 1). If the original kernel was square and its first dimension had to
        shrink, the second dimension follows it. Returns a tuple.
        """

        def largest_odd_below(n):
            k = n - 1
            if k % 2 == 0:
                k -= 1
            # never produce a non-positive kernel for tiny (size 1) dimensions
            return max(k, 1)

        fitted = list(attn_kernel_shape)
        was_square = fitted[0] == fitted[1]

        if fitted[0] >= out_shape[0]:
            fitted[0] = largest_odd_below(out_shape[0])
            # keep the kernel square if it started out square
            if was_square:
                fitted[1] = fitted[0]
        if fitted[1] >= out_shape[1]:
            fitted[1] = largest_odd_below(out_shape[1])

        return tuple(fitted)

    def _init_weights(self, m):
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fwd(x)

        # LayerNorm normalizes the last dimension, so go channels-last and back
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)

        return x


class Upsampling(nn.Module):
    """
    Decoder branch: pointwise channel projection followed by bilinear resizing to ``out_shape``.

    Input: (B, in_channels, H, W); output: (B, out_channels, *out_shape).

    Parameters
    -----------
    in_shape : tuple
        Spatial shape of the input; not used by this block (the resize only needs ``out_shape``)
    out_shape : tuple
        Target spatial shape
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    hidden_channels : int
        Hidden width, only used by the ``MLP`` projection (``use_mlp=True``)
    mlp_bias : bool, optional
        Whether the 1x1-convolution projection uses a bias, by default True
    kernel_shape : tuple, optional
        Not used by this block, by default (3, 3)
    conv_bias : bool, optional
        Not used by this block, by default False
    activation : type, optional
        Activation module class for the ``MLP`` projection, by default nn.GELU
    use_mlp : bool, optional
        Use the ``MLP`` helper instead of a 1x1 convolution, by default False
    """

    def __init__(
        self,
        in_shape,
        out_shape,
        in_channels,
        out_channels,
        hidden_channels,
        mlp_bias=True,
        kernel_shape=(3, 3),
        conv_bias=False,
        activation=nn.GELU,
        use_mlp=False,
    ):
        super().__init__()
        self.out_shape = out_shape
        if use_mlp:
            self.mlp = MLP(in_channels, hidden_features=hidden_channels, out_features=out_channels, act_layer=activation, output_bias=False, drop_rate=0.0)
        else:
            self.mlp = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=mlp_bias)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # project channels first (cheaper at the lower resolution), then resize
        x = nn.functional.interpolate(self.mlp(x), size=self.out_shape, mode="bilinear")
        return x


class Segformer(nn.Module):
    """
    Segformer-style segmentation model on regular 2D grids.

    The encoder is a stack of ``TransformerBlock`` stages; each downsamples the previous
    resolution by ``scale_factor`` and applies ``depths[i]`` (attention, Mix-FFN) pairs. The
    decoder projects every stage's features with ``Upsampling``, resizes them to ``img_size``,
    concatenates them along channels and applies a 1x1 convolution head.
    Input: (B, in_chans, *img_size); output: (B, out_chans, *img_size).

    Parameters
    -----------
    img_size : tuple, optional
        Spatial shape (H, W) of the input and the output, by default (128, 256)
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels (classes), by default 3
    embed_dims : List[int], optional
        Number of channels of each encoder stage, by default [64, 128, 256, 512]
    heads : List[int], optional
        Number of attention heads per stage, has to be the same length as embed_dims, by default [1, 2, 4, 8]
    depths : List[int], optional
        Number of (attention, Mix-FFN) pairs per stage, has to be the same length as embed_dims, by default [3, 4, 6, 3]
    scale_factor : int, optional
        Spatial downsampling factor applied by each stage, by default 2
    activation_function : str, optional
        "relu", "gelu" or "identity" (for debugging), by default "gelu"
    kernel_shape : tuple, optional
        Kernel shape of the patch merging and depthwise convolutions, by default (3, 3)
    mlp_ratio : float, optional
        Hidden-width multiplier forwarded to the Mix-FFN and upsampling layers; with their current
        1x1-convolution projections it has no effect, by default 2.0
    att_drop_rate : float, optional
        Dropout on the attention weights, by default 0.0
    drop_path_rate : float, optional
        Maximum drop path rate; rates increase linearly from 0 over all attention/Mix-FFN pairs, by default 0.1
    attention_mode : str, optional
        "neighborhood" (NATTEN neighborhood attention) or "global", by default "neighborhood"
    attn_kernel_shape : tuple, optional
        Odd neighborhood attention kernel, shrunk per stage if not smaller than the stage resolution, by default (7, 7)
    bias : bool, optional
        Whether the attention projections use a bias, by default True

    Raises
    ------
    ValueError
        If ``activation_function`` is unknown.
    AssertionError
        If ``heads`` or ``depths`` do not have the same length as ``embed_dims``.

    Example
    -----------
    >>> model = Segformer(
    ...     img_size=(128, 256),
    ...     in_chans=3,
    ...     out_chans=3,
    ...     embed_dims=[64, 128, 256, 512],
    ...     heads=[1, 2, 4, 8],
    ...     depths=[3, 4, 6, 3],
    ...     scale_factor=2,
    ...     activation_function="gelu",
    ...     kernel_shape=(3, 3),
    ...     mlp_ratio=2.0,
    ...     att_drop_rate=0.0,
    ...     drop_path_rate=0.1,
    ...     attention_mode="global",
    ... )
    >>> model(torch.randn(1, 3, 128, 256)).shape
    torch.Size([1, 3, 128, 256])
    """

    def __init__(
        self,
        img_size=(128, 256),
        in_chans=3,
        out_chans=3,
        embed_dims=[64, 128, 256, 512],
        heads=[1, 2, 4, 8],
        depths=[3, 4, 6, 3],
        scale_factor=2,
        activation_function="gelu",
        kernel_shape=(3, 3),
        mlp_ratio=2.0,
        att_drop_rate=0.0,
        drop_path_rate=0.1,
        attention_mode="neighborhood",
        attn_kernel_shape=(7, 7),
        bias=True
    ):
        super().__init__()

        self.img_size = img_size
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dims = embed_dims
        self.heads = heads
        self.num_blocks = len(self.embed_dims)
        self.depths = depths
        self.kernel_shape = kernel_shape

        assert len(self.heads) == self.num_blocks
        assert len(self.depths) == self.num_blocks

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # drop path rates ramp linearly from 0 to drop_path_rate across all (attention, Mix-FFN) pairs
        dpr = torch.linspace(0, drop_path_rate, sum(self.depths)).tolist()

        # encoder stages: each one downsamples the previous resolution by scale_factor
        self.blocks = nn.ModuleList([])
        out_shape = img_size
        in_channels = in_chans
        cur = 0
        for i in range(self.num_blocks):
            out_shape_new = (out_shape[0] // scale_factor, out_shape[1] // scale_factor)
            out_channels = self.embed_dims[i]
            self.blocks.append(
                TransformerBlock(
                    in_shape=out_shape,
                    out_shape=out_shape_new,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    mlp_hidden_channels=int(mlp_ratio * out_channels),
                    nrep=self.depths[i],
                    heads=self.heads[i],
                    kernel_shape=kernel_shape,
                    activation=self.activation_function,
                    att_drop_rate=att_drop_rate,
                    drop_path_rates=dpr[cur : cur + self.depths[i]],
                    attention_mode=attention_mode,
                    attn_kernel_shape=attn_kernel_shape,
                    bias=bias
                )
            )
            cur += self.depths[i]
            out_shape = out_shape_new
            in_channels = out_channels

        # decoder: bring every stage's features back to the input resolution
        self.upsamplers = nn.ModuleList([])
        out_shape = img_size
        for i in range(self.num_blocks):
            in_shape = self.blocks[i].out_shape
            self.upsamplers.append(
                Upsampling(
                    in_shape=in_shape,
                    out_shape=out_shape,
                    in_channels=self.embed_dims[i],
                    out_channels=self.embed_dims[i],
                    hidden_channels=int(mlp_ratio * self.embed_dims[i]),
                    mlp_bias=True,
                    kernel_shape=kernel_shape,
                    conv_bias=False,
                    # same activation as the encoder (only used by MLP-based upsampling)
                    activation=self.activation_function,
                )
            )

        # the head sees the concatenation of all upsampled stage features
        segmentation_head_dim = sum(self.embed_dims)
        self.segmentation_head = nn.Conv2d(in_channels=segmentation_head_dim, out_channels=out_chans, kernel_size=1, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):

        # encoder: each stage consumes the previous stage's output; keep every stage's features
        features = []
        feat = x
        for block in self.blocks:
            feat = block(feat)
            features.append(feat)

        # decoder: project and resize every stage to img_size, then stack along channels
        upfeats = torch.cat([upsampler(feat) for feat, upsampler in zip(features, self.upsamplers)], dim=1)

        # per-pixel prediction
        out = self.segmentation_head(upfeats)

        return out