unet.py 19.9 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp as amp

from torch_harmonics.examples.models._layers import MLP, DropPath

from functools import partial



class DownsamplingBlock(nn.Module):
    """
    Downsampling block for the UNet model.

    The block applies ``nrep`` repetitions of convolution -> (optional) dropout ->
    batch norm -> activation at the input resolution, adds a residual connection
    (projected with a 1x1 convolution when requested or when the channel counts
    differ) and finally reduces the spatial resolution to ``out_shape``.

    Parameters
    ----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    nrep : int, optional
        Number of repetitions of conv blocks, by default 1
    kernel_shape : tuple or int, optional
        Kernel shape for convolutions, by default (3, 3). An int is used for
        both spatial dimensions.
    activation : callable, optional
        Activation function, by default nn.ReLU
    transform_skip : bool, optional
        Whether to transform skip connections with a 1x1 convolution, by default
        False. A projection is always used when ``in_channels != out_channels``.
    drop_conv_rate : float, optional
        Dropout rate for convolutions, by default 0.
    drop_path_rate : float, optional
        Drop path rate applied to the main (non-skip) branch, by default 0.
    drop_dense_rate : float, optional
        Dropout rate applied after the skip-connection projection, by default 0.
    downsampling_mode : str, optional
        Downsampling mode ("bilinear", "conv"), by default "bilinear"

    Raises
    ------
    ValueError
        If ``downsampling_mode`` is neither "bilinear" nor "conv".
    """

    def __init__(
        self,
        in_shape,
        out_shape,
        in_channels,
        out_channels,
        nrep=1,
        kernel_shape=(3, 3),
        activation=nn.ReLU,
        transform_skip=False,
        drop_conv_rate=0.,
        drop_path_rate=0.,
        drop_dense_rate=0.,
        downsampling_mode="bilinear",
    ):
        super().__init__()

        if downsampling_mode not in ("bilinear", "conv"):
            raise ValueError(f"Unknown downsampling mode {downsampling_mode}")

        # accept a single int for square kernels; the "conv" padding below indexes per dimension
        if isinstance(kernel_shape, int):
            kernel_shape = (kernel_shape, kernel_shape)

        self.in_shape = in_shape
        self.out_shape = out_shape
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsampling_mode = downsampling_mode
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        fwd = []
        for i in range(nrep):
            # conv: only the first repetition changes the channel count
            fwd.append(
                nn.Conv2d(
                    in_channels=(in_channels if i == 0 else out_channels),
                    out_channels=out_channels,
                    kernel_size=kernel_shape,
                    bias=False,
                    padding="same",
                )
            )

            if drop_conv_rate > 0.:
                fwd.append(nn.Dropout2d(p=drop_conv_rate))

            # batchnorm
            fwd.append(
                nn.BatchNorm2d(
                    out_channels,
                    eps=1e-05,
                    momentum=0.1,
                    affine=True,
                    track_running_stats=True,
                )
            )

            # activation
            fwd.append(activation())

        if downsampling_mode == "conv":
            stride_h = in_shape[0] // out_shape[0]
            stride_w = in_shape[1] // out_shape[1]
            # padding targeting floor((n_in + 2 * pad - k) / stride) + 1 == n_out; clamped at
            # zero because strides larger than the kernel would otherwise request negative
            # padding, which nn.Conv2d rejects
            pad_h = max(0, math.ceil(((out_shape[0] - 1) * stride_h - in_shape[0] + kernel_shape[0]) / 2))
            pad_w = max(0, math.ceil(((out_shape[1] - 1) * stride_w - in_shape[1] + kernel_shape[1]) / 2))
            self.downsample = nn.Conv2d(
                # applied after the residual sum, which always has out_channels channels
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_shape,
                bias=False,
                stride=(stride_h, stride_w),
                padding=(pad_h, pad_w),
            )
        else:
            # "bilinear": resampling happens via F.interpolate in forward
            self.downsample = nn.Identity()

        # make sequential
        self.fwd = nn.Sequential(*fwd)

        # optional 1x1 projection of the skip connection (mandatory when channel counts differ)
        if transform_skip or (in_channels != out_channels):
            self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)

            if drop_dense_rate > 0.:
                self.transform_skip = nn.Sequential(
                    self.transform_skip,
                    nn.Dropout2d(p=drop_dense_rate),
                )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Conv2d weights are drawn from a truncated normal with std 0.02 and their
        biases (if any) are set to zero; other modules keep their defaults.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the DownsamplingBlock.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with ``in_channels`` channels at resolution ``in_shape``

        Returns
        -------
        torch.Tensor
            Output tensor with ``out_channels`` channels at resolution ``out_shape``
        """
        # skip connection (projected when a transform was configured)
        residual = x
        if hasattr(self, "transform_skip"):
            residual = self.transform_skip(residual)

        # main path, with drop path applied to the non-skip branch only
        x = residual + self.drop_path(self.fwd(x))

        # downsample: strided conv in "conv" mode, interpolation in "bilinear" mode
        x = self.downsample(x)
        if self.downsampling_mode == "bilinear":
            x = F.interpolate(x, size=self.out_shape, mode="bilinear")

        return x

    
class UpsamplingBlock(nn.Module):
    """
    Upsampling block for UNet architecture.

    The block applies ``nrep`` repetitions of convolution -> (optional) dropout ->
    batch norm -> activation at the input resolution, adds a residual connection
    (projected with a 1x1 convolution when requested or when the channel counts
    differ) and finally resamples to ``out_shape``.

    Parameters
    ----------
    in_shape : tuple
        Input shape (height, width)
    out_shape : tuple
        Output shape (height, width)
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    nrep : int, optional
        Number of repetitions of conv blocks, by default 1
    kernel_shape : tuple or int, optional
        Kernel shape for convolutions, by default (3, 3). An int is used for
        both spatial dimensions.
    activation : callable, optional
        Activation function, by default nn.ReLU
    transform_skip : bool, optional
        Whether to transform skip connections with a 1x1 convolution, by default
        False. A projection is always used when ``in_channels != out_channels``.
    drop_conv_rate : float, optional
        Dropout rate for convolutions, by default 0.
    drop_path_rate : float, optional
        Drop path rate applied to the main (non-skip) branch, by default 0.
    drop_dense_rate : float, optional
        Dropout rate applied after the skip-connection projection, by default 0.
    upsampling_mode : str, optional
        Upsampling mode ("bilinear", "conv"), by default "bilinear"

    Raises
    ------
    ValueError
        If ``upsampling_mode`` is neither "bilinear" nor "conv".
    """

    def __init__(
        self,
        in_shape,
        out_shape,
        in_channels,
        out_channels,
        nrep=1,
        kernel_shape=(3, 3),
        activation=nn.ReLU,
        transform_skip=False,
        drop_conv_rate=0.,
        drop_path_rate=0.,
        drop_dense_rate=0.,
        upsampling_mode="bilinear",
    ):
        super().__init__()

        if upsampling_mode not in ("bilinear", "conv"):
            raise ValueError(f"Unknown upsampling mode {upsampling_mode}")

        # accept a single int for square kernels; the "conv" geometry below indexes per dimension
        if isinstance(kernel_shape, int):
            kernel_shape = (kernel_shape, kernel_shape)

        self.in_shape = in_shape
        self.out_shape = out_shape
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.upsampling_mode = upsampling_mode

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        if upsampling_mode == "conv" and tuple(in_shape) != tuple(out_shape):
            stride, padding, output_padding = [], [], []
            for n_in, n_out, k in zip(in_shape, out_shape, kernel_shape):
                s = n_out // n_in
                # ConvTranspose2d produces (n_in - 1) * s - 2 * p + k + output_padding pixels:
                # use the smallest non-negative symmetric padding and cover the remainder with
                # output_padding so the result is n_out (PyTorch requires output_padding < stride)
                p = max(0, math.ceil(((n_in - 1) * s - n_out + k) / 2))
                stride.append(s)
                padding.append(p)
                output_padding.append(n_out - ((n_in - 1) * s - 2 * p + k))

            self.upsample = nn.Sequential(
                nn.ConvTranspose2d(
                    # applied after the residual sum, which always has out_channels channels
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_shape,
                    stride=tuple(stride),
                    padding=tuple(padding),
                    output_padding=tuple(output_padding),
                ),
                nn.BatchNorm2d(
                    out_channels,
                    eps=1e-05,
                    momentum=0.1,
                    affine=True,
                    track_running_stats=True,
                ),
                activation(),
                nn.Conv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_shape,
                    bias=False,
                    padding="same",
                ),
            )
        else:
            # "bilinear" resamples via F.interpolate in forward; "conv" with equal shapes
            # has nothing to resample
            self.upsample = nn.Identity()

        fwd = []
        for i in range(nrep):
            # conv: only the first repetition changes the channel count
            fwd.append(
                nn.Conv2d(
                    in_channels=(in_channels if i == 0 else out_channels),
                    out_channels=out_channels,
                    kernel_size=kernel_shape,
                    bias=False,
                    padding="same",
                )
            )

            if drop_conv_rate > 0.:
                fwd.append(nn.Dropout2d(p=drop_conv_rate))

            # batchnorm: every conv above outputs out_channels channels
            fwd.append(
                nn.BatchNorm2d(
                    out_channels,
                    eps=1e-05,
                    momentum=0.1,
                    affine=True,
                    track_running_stats=True,
                )
            )

            # activation
            fwd.append(activation())

        # make sequential
        self.fwd = nn.Sequential(*fwd)

        # optional 1x1 projection of the skip connection (mandatory when channel counts differ)
        if transform_skip or (in_channels != out_channels):
            self.transform_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
            if drop_dense_rate > 0.:
                self.transform_skip = nn.Sequential(
                    self.transform_skip,
                    nn.Dropout2d(p=drop_dense_rate),
                )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Conv2d weights are drawn from a truncated normal with std 0.02 and their
        biases (if any) are set to zero; other modules keep their defaults.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the UpsamplingBlock.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with ``in_channels`` channels at resolution ``in_shape``

        Returns
        -------
        torch.Tensor
            Output tensor with ``out_channels`` channels at resolution ``out_shape``
        """
        # skip connection (projected when a transform was configured)
        residual = x
        if hasattr(self, "transform_skip"):
            residual = self.transform_skip(residual)

        # main path, with drop path applied to the non-skip branch only
        x = residual + self.drop_path(self.fwd(x))

        # upsampling
        if self.upsampling_mode == "bilinear":
            x = F.interpolate(x, size=self.out_shape, mode="bilinear")
        else:
            x = self.upsample(x)
        return x


class UNet(nn.Module):
    """
    Convolutional UNet mapping images of shape ``img_shape`` with ``in_chans``
    channels to outputs of the same spatial shape with ``num_classes`` channels.

    The encoder is a stack of ``DownsamplingBlock`` modules, each reducing the
    resolution by ``scale_factor`` (integer division). The decoder is a matching
    stack of ``UpsamplingBlock`` modules. Every decoder block except the deepest
    receives the channel-wise concatenation of the corresponding encoder output
    and the previous decoder output. A final 1x1 convolution produces the output.

    Parameters
    ----------
    img_shape : tuple, optional
        Spatial shape (height, width) of the input, by default (128, 256)
    in_chans : int, optional
        Number of input channels, by default 3
    num_classes : int, optional
        Number of output channels, by default 3
    embed_dims : sequence of int, optional
        Number of channels produced by each encoder block, by default
        (64, 128, 256, 512). Must have the same length as ``depths``.
    depths : sequence of int, optional
        Number of conv repetitions in each encoder block, by default
        (2, 2, 2, 2). Must have the same length as ``embed_dims``. Decoder
        blocks use a single repetition.
    scale_factor : int, optional
        Resolution reduction factor of each encoder block, by default 2
    activation_function : str, optional
        Activation function ("relu", "gelu", "identity"), by default "relu"
    kernel_shape : tuple or int, optional
        Kernel shape for convolutions, by default (3, 3)
    transform_skip : bool, optional
        Whether every block projects its skip connection with a 1x1 convolution,
        by default False (a projection is always used when channel counts differ)
    drop_conv_rate : float, optional
        Dropout rate after convolutions, by default 0.1
    drop_path_rate : float, optional
        Maximum drop path rate; encoder blocks use rates spaced linearly from 0
        to this value, decoder blocks use 0. By default 0.1
    drop_dense_rate : float, optional
        Dropout rate after skip-connection projections, by default 0.5
    downsampling_mode : str, optional
        Downsampling mode of the encoder ("bilinear", "conv"), by default "bilinear"
    upsampling_mode : str, optional
        Upsampling mode of the decoder ("bilinear", "conv"), by default "bilinear"

    Raises
    ------
    ValueError
        If ``embed_dims`` is empty, ``depths`` and ``embed_dims`` differ in
        length, ``activation_function`` is unknown, or ``img_shape`` is reduced
        to a zero size by the encoder.

    Example
    -------
    >>> model = UNet(
    ...         img_shape=(128, 256),
    ...         scale_factor=2,
    ...         in_chans=2,
    ...         num_classes=2,
    ...         embed_dims=[16, 32, 64, 128],
    ...         depths=[2, 2, 2, 2],)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
    """
    def __init__(
        self,
        img_shape=(128, 256),
        in_chans=3,
        num_classes=3,
        embed_dims=(64, 128, 256, 512),
        depths=(2, 2, 2, 2),
        scale_factor=2,
        activation_function="relu",
        kernel_shape=(3, 3),
        transform_skip=False,
        drop_conv_rate=0.1,
        drop_path_rate=0.1,
        drop_dense_rate=0.5,
        downsampling_mode="bilinear",
        upsampling_mode="bilinear",
    ):
        super().__init__()

        self.img_shape = img_shape
        self.in_chans = in_chans
        self.num_classes = num_classes
        # copies, so later mutation of the caller's sequences cannot affect the model
        self.embed_dims = list(embed_dims)
        self.num_blocks = len(self.embed_dims)
        self.depths = list(depths)
        self.kernel_shape = kernel_shape

        if self.num_blocks == 0:
            raise ValueError("embed_dims must contain at least one entry")
        if len(self.depths) != self.num_blocks:
            raise ValueError(
                f"depths and embed_dims must have the same length, got {len(self.depths)} and {self.num_blocks}"
            )

        # activation function ("identity" is meant for debugging purposes)
        activations = {"relu": nn.ReLU, "gelu": nn.GELU, "identity": nn.Identity}
        if activation_function not in activations:
            raise ValueError(f"Unknown activation function {activation_function}")
        self.activation_function = activations[activation_function]

        # spatial shape before/after every encoder block; validated up front so that an
        # over-deep configuration fails with a clear message instead of inside F.interpolate
        shapes = [tuple(img_shape)]
        for _ in range(self.num_blocks):
            h, w = shapes[-1]
            shapes.append((h // scale_factor, w // scale_factor))
        if min(shapes[-1]) < 1:
            raise ValueError(
                f"img_shape {tuple(img_shape)} is too small for {self.num_blocks} "
                f"downsampling steps with scale_factor={scale_factor}"
            )

        # set up drop path rates
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)]

        self.dblocks = nn.ModuleList([])
        in_channels = in_chans
        for i in range(self.num_blocks):
            out_channels = self.embed_dims[i]
            self.dblocks.append(
                DownsamplingBlock(
                    in_shape=shapes[i],
                    out_shape=shapes[i + 1],
                    in_channels=in_channels,
                    out_channels=out_channels,
                    nrep=self.depths[i],
                    kernel_shape=kernel_shape,
                    activation=self.activation_function,
                    drop_conv_rate=drop_conv_rate,
                    drop_path_rate=dpr[i],
                    drop_dense_rate=drop_dense_rate,
                    transform_skip=transform_skip,
                    downsampling_mode=downsampling_mode,
                )
            )
            in_channels = out_channels

        # decoder blocks mirror the encoder, from the coarsest resolution to the finest
        self.ublocks = nn.ModuleList([])
        for i in reversed(range(self.num_blocks)):
            dblock = self.dblocks[i]
            in_channels = dblock.out_channels
            if i != self.num_blocks - 1:
                # input is the concatenation of the encoder skip and the previous decoder output
                in_channels = 2 * in_channels
            # the finest decoder block feeds the head, which expects embed_dims[0] channels
            out_channels = self.embed_dims[0] if i == 0 else dblock.in_channels
            self.ublocks.append(
                UpsamplingBlock(
                    in_shape=dblock.out_shape,
                    out_shape=dblock.in_shape,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_shape=kernel_shape,
                    activation=self.activation_function,
                    drop_conv_rate=drop_conv_rate,
                    drop_path_rate=0.,
                    drop_dense_rate=drop_dense_rate,
                    transform_skip=transform_skip,
                    upsampling_mode=upsampling_mode,
                )
            )

        self.head = nn.Conv2d(self.embed_dims[0], self.num_classes, kernel_size=1, bias=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Initialize weights for the module.

        Conv2d weights are drawn from a truncated normal with std 0.02 and their
        biases (if any) are set to zero; LayerNorm modules are reset to unit
        weight and zero bias. Other modules keep their defaults.

        Parameters
        ----------
        m : torch.nn.Module
            Module to initialize weights for
        """
        if isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Forward pass through the UNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_chans, height, width)

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, num_classes, height, width)
        """
        # encoder: keep every block's output for the skip connections
        features = []
        feat = x
        for dblock in self.dblocks:
            feat = dblock(feat)
            features.append(feat)

        # the decoder consumes encoder outputs from the coarsest to the finest resolution
        features = features[::-1]

        # the deepest decoder block has no skip input; all others get [skip, previous] concatenated
        ufeat = self.ublocks[0](features[0])
        for feat, ublock in zip(features[1:], self.ublocks[1:]):
            ufeat = ublock(torch.cat([feat, ufeat], dim=1))

        # last layer
        return self.head(ufeat)

if __name__ == "__main__":
    # smoke test: build a small three-level UNet, show its structure and
    # push one random sample through it
    net = UNet(
        img_shape=(128, 256),
        scale_factor=2,
        in_chans=2,
        embed_dims=[64, 128, 256],
        depths=[2, 2, 2],
    )
    print(net)
    sample = torch.randn(1, 2, 128, 256)
    print(net(sample).shape)