# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
from functools import partial

import torch
import torch.nn as nn

from torch_harmonics import *

from .layers import *

class SpectralFilterLayer(nn.Module):
    """
    Fourier layer. Contains the convolution part of the FNO/SFNO.

    Dispatches to the dense spectral convolution when no tensor
    factorization is requested, and to the factorized variant otherwise.

    Parameters
    ----------
    forward_transform : nn.Module
        Transform into spectral space (SHT or FFT).
    inverse_transform : nn.Module
        Transform back into physical space.
    input_dim : int
        Number of input channels.
    output_dim : int
        Number of output channels.
    gain : float, optional
        Initialization gain forwarded to the spectral convolution, by default 2.
    operator_type : str, optional
        Type of spectral operator, by default "diagonal".
    hidden_size_factor : int, optional
        Accepted for interface compatibility; not forwarded to either
        convolution variant in the current implementation.
    factorization : str or None, optional
        Tensor factorization to use; None selects the dense convolution.
    separable : bool, optional
        Whether to use separable convolutions (factorized variant only).
    rank : (int, float, Tuple[int]), optional
        Rank of the factorization (factorized variant only).
    bias : bool, optional
        Whether to learn an additive bias, by default True.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        gain = 2.,
        operator_type = "diagonal",
        hidden_size_factor = 2,
        factorization = None,
        separable = False,
        rank = 1e-2,
        bias = True):
        super(SpectralFilterLayer, self).__init__()

        # factorization is either None or not, so a plain if/else covers all
        # cases (the original chain ended in an unreachable `else: raise`)
        if factorization is None:
            self.filter = SpectralConvS2(forward_transform,
                                         inverse_transform,
                                         input_dim,
                                         output_dim,
                                         gain = gain,
                                         operator_type = operator_type,
                                         bias = bias)
        else:
            self.filter = FactorizedSpectralConvS2(forward_transform,
                                                   inverse_transform,
                                                   input_dim,
                                                   output_dim,
                                                   gain = gain,
                                                   operator_type = operator_type,
                                                   rank = rank,
                                                   factorization = factorization,
                                                   separable = separable,
                                                   bias = bias)

    def forward(self, x):
        return self.filter(x)

class SphericalFourierNeuralOperatorBlock(nn.Module):
    """
    Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.

    Computation:
        filter -> norm0 -> (+ inner skip) -> activation -> MLP -> norm1 -> drop path -> (+ outer skip)

    Parameters
    ----------
    forward_transform : nn.Module
        Transform into spectral space (SHT or FFT).
    inverse_transform : nn.Module
        Transform back into physical space.
    input_dim : int
        Number of input channels.
    output_dim : int
        Number of output channels.
    operator_type : str, optional
        Type of spectral operator, by default "driscoll-healy".
    mlp_ratio : float, optional
        Ratio of MLP hidden dimension to output dimension, by default 2.
    drop_rate : float, optional
        Dropout rate inside the MLP, by default 0.
    drop_path : float, optional
        Stochastic-depth rate, by default 0.
    act_layer : type, optional
        Activation module class, by default nn.ReLU.
    norm_layer : callable, optional
        Factory for the normalization layers, by default nn.Identity.
    factorization : str or None, optional
        Tensor factorization forwarded to the spectral filter.
    separable : bool, optional
        Separable convolutions in the factorized filter.
    rank : (int, Tuple[int]), optional
        Rank of the factorization.
    inner_skip : str, optional
        Skip connection around the filter ("linear", "identity", "none").
    outer_skip : str or None, optional
        Skip connection around the whole block ("linear", "identity",
        "none" or None).
    use_mlp : bool, optional
        Whether to append an MLP after the activation, by default True.
    """
    def __init__(
            self,
            forward_transform,
            inverse_transform,
            input_dim,
            output_dim,
            operator_type = "driscoll-healy",
            mlp_ratio = 2.,
            drop_rate = 0.,
            drop_path = 0.,
            act_layer = nn.ReLU,
            norm_layer = nn.Identity,
            factorization = None,
            separable = False,
            rank = 128,
            inner_skip = "linear",
            outer_skip = None,
            use_mlp = True):
        super(SphericalFourierNeuralOperatorBlock, self).__init__()

        # gain 2 compensates the variance reduction of ReLU-style activations
        if act_layer == nn.Identity:
            gain_factor = 1.0
        else:
            gain_factor = 2.0

        # with a skip branch, each of the two summands contributes half
        if inner_skip in ("linear", "identity"):
            gain_factor /= 2.0

        # convolution layer
        self.filter = SpectralFilterLayer(forward_transform,
                                          inverse_transform,
                                          input_dim,
                                          output_dim,
                                          gain = gain_factor,
                                          operator_type = operator_type,
                                          hidden_size_factor = mlp_ratio,
                                          factorization = factorization,
                                          separable = separable,
                                          rank = rank,
                                          bias = True)

        if inner_skip == "linear":
            self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor/input_dim))
        elif inner_skip == "identity":
            assert input_dim == output_dim
            self.inner_skip = nn.Identity()
        elif inner_skip == "none" or inner_skip is None:
            pass
        else:
            raise ValueError(f"Unknown skip connection type {inner_skip}")

        self.act_layer = act_layer()

        # first normalisation layer
        self.norm0 = norm_layer()

        # dropout
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        gain_factor = 1.0
        # BUG FIX: the original tested `inner_skip == "identity"` here, so the
        # MLP gain was halved based on the wrong skip connection
        if outer_skip in ("linear", "identity"):
            gain_factor /= 2.

        if use_mlp:
            mlp_hidden_dim = int(output_dim * mlp_ratio)
            self.mlp = MLP(in_features = output_dim,
                           out_features = input_dim,
                           hidden_features = mlp_hidden_dim,
                           act_layer = act_layer,
                           drop_rate = drop_rate,
                           checkpointing = False,
                           gain = gain_factor)

        if outer_skip == "linear":
            self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor/input_dim))
        elif outer_skip == "identity":
            assert input_dim == output_dim
            self.outer_skip = nn.Identity()
        # BUG FIX: the documented default `outer_skip=None` previously fell
        # into the error branch; treat None the same as "none"
        elif outer_skip == "none" or outer_skip is None:
            pass
        else:
            raise ValueError(f"Unknown skip connection type {outer_skip}")

        # second normalisation layer
        self.norm1 = norm_layer()

    def forward(self, x):

        # the spectral filter returns the transformed signal together with
        # the residual (its input) used by the skip connections
        x, residual = self.filter(x)

        x = self.norm0(x)

        if hasattr(self, "inner_skip"):
            x = x + self.inner_skip(residual)

        if hasattr(self, "act_layer"):
            x = self.act_layer(x)

        if hasattr(self, "mlp"):
            x = self.mlp(x)

        x = self.norm1(x)

        x = self.drop_path(x)

        if hasattr(self, "outer_skip"):
            x = x + self.outer_skip(residual)

        return x

class SphericalFourierNeuralOperatorNet(nn.Module):
    """
    SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
    both linear and non-linear variants.

    Parameters
    ----------
    spectral_transform : str, optional
        Type of spectral transformation to use ("sht", "fft"), by default "sht"
    operator_type : str, optional
        Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
    img_size : tuple, optional
        Shape of the input channels, by default (128, 256)
    grid : str, optional
        Grid of the input signal, by default "equiangular"
    scale_factor : int, optional
        Scale factor to use, by default 3
    in_chans : int, optional
        Number of input channels, by default 3
    out_chans : int, optional
        Number of output channels, by default 3
    embed_dim : int, optional
        Dimension of the embeddings, by default 256
    num_layers : int, optional
        Number of layers in the network, by default 4
    activation_function : str, optional
        Activation function to use ("relu", "gelu", "identity"), by default "relu"
    encoder_layers : int, optional
        Number of layers in the encoder, by default 1
    use_mlp : bool, optional
        Whether to use MLPs in the SFNO blocks, by default True
    mlp_ratio : float, optional
        Ratio of MLP hidden dimension to embedding dimension, by default 2.0
    drop_rate : float, optional
        Dropout rate, by default 0.0
    drop_path_rate : float, optional
        Dropout path rate, by default 0.0
    normalization_layer : str, optional
        Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "none"
    hard_thresholding_fraction : float, optional
        Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
    use_complex_kernels : bool, optional
        Accepted for interface compatibility; currently unused, by default True
    big_skip : bool, optional
        Whether to add a single large skip connection, by default False
    factorization : Any, optional
        Type of factorization to use, by default None
    separable : bool, optional
        Whether to use separable convolutions, by default False
    rank : (int, Tuple[int]), optional
        If a factorization is used, which rank to use. Argument is passed to tensorly
    pos_embed : (bool, str), optional
        Type of positional embedding ("latlon"/True, "lat", "const" or False), by default False

    Example:
    --------
    >>> model = SphericalFourierNeuralOperatorNet(
    ...         img_size=(128, 256),
    ...         scale_factor=4,
    ...         in_chans=2,
    ...         out_chans=2,
    ...         embed_dim=16,
    ...         num_layers=4,
    ...         use_mlp=True,)
    >>> model(torch.randn(1, 2, 128, 256)).shape
    torch.Size([1, 2, 128, 256])
    """

    def __init__(
            self,
            spectral_transform = "sht",
            operator_type = "driscoll-healy",
            img_size = (128, 256),
            grid = "equiangular",
            scale_factor = 3,
            in_chans = 3,
            out_chans = 3,
            embed_dim = 256,
            num_layers = 4,
            activation_function = "relu",
            encoder_layers = 1,
            use_mlp = True,
            mlp_ratio = 2.,
            drop_rate = 0.,
            drop_path_rate = 0.,
            normalization_layer = "none",
            hard_thresholding_fraction = 1.0,
            use_complex_kernels = True,
            big_skip = False,
            factorization = None,
            separable = False,
            rank = 128,
            pos_embed = False):

        super(SphericalFourierNeuralOperatorNet, self).__init__()

        self.spectral_transform = spectral_transform
        self.operator_type = operator_type
        self.img_size = img_size
        self.grid = grid
        self.scale_factor = scale_factor
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.normalization_layer = normalization_layer
        self.use_mlp = use_mlp
        self.encoder_layers = encoder_layers
        self.big_skip = big_skip
        self.factorization = factorization
        # BUG FIX: the original line ended in a stray comma, which stored the
        # 1-tuple (separable,) — always truthy, even for separable=False
        self.separable = separable
        self.rank = rank

        # activation function
        if activation_function == "relu":
            self.activation_function = nn.ReLU
        elif activation_function == "gelu":
            self.activation_function = nn.GELU
        # for debugging purposes
        elif activation_function == "identity":
            self.activation_function = nn.Identity
        else:
            raise ValueError(f"Unknown activation function {activation_function}")

        # compute downsampled image size
        self.h = self.img_size[0] // scale_factor
        self.w = self.img_size[1] // scale_factor

        # dropout and per-block stochastic-depth rates (linearly increasing)
        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

        # pick norm layer: norm_layer0 operates on the full-resolution grid,
        # norm_layer1 on the downsampled internal grid
        if self.normalization_layer == "layer_norm":
            norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
            norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
        elif self.normalization_layer == "instance_norm":
            norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
            norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
        elif self.normalization_layer == "none":
            norm_layer0 = nn.Identity
            norm_layer1 = norm_layer0
        else:
            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

        # positional embedding, zero-initialized
        if pos_embed == "latlon" or pos_embed == True:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
            nn.init.constant_(self.pos_embed, 0.0)
        elif pos_embed == "lat":
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], 1))
            nn.init.constant_(self.pos_embed, 0.0)
        elif pos_embed == "const":
            self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
            nn.init.constant_(self.pos_embed, 0.0)
        else:
            self.pos_embed = None

        # encoder: point-wise MLP realized with 1x1 convolutions
        # NOTE(review): self.encoder_layers is stored but this head is
        # hard-coded to a single layer, matching the previous behavior —
        # confirm whether encoder_layers was meant to be honored here
        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.encoder = self._make_conv_head(self.in_chans, encoder_hidden_dim, self.embed_dim, 1)

        # prepare the spectral transform
        if self.spectral_transform == "sht":

            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int(self.w//2 * self.hard_thresholding_fraction)
            modes_lat = modes_lon = min(modes_lat, modes_lon)

            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
            self.itrans_up  = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
            self.trans      = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
            self.itrans     = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()

        elif self.spectral_transform == "fft":

            modes_lat = int(self.h * self.hard_thresholding_fraction)
            modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

            self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans_up  = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
            self.trans      = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
            self.itrans     = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()

        else:
            raise ValueError("Unknown spectral transform")

        self.blocks = nn.ModuleList([])
        for i in range(self.num_layers):

            first_layer = i == 0
            last_layer = i == self.num_layers-1

            # first block maps to the internal grid, last block maps back up
            forward_transform = self.trans_down if first_layer else self.trans
            inverse_transform = self.itrans_up if last_layer else self.itrans

            inner_skip = "none"
            outer_skip = "identity"

            if first_layer:
                norm_layer = norm_layer1
            elif last_layer:
                norm_layer = norm_layer0
            else:
                norm_layer = norm_layer1

            block = SphericalFourierNeuralOperatorBlock(forward_transform,
                                                        inverse_transform,
                                                        self.embed_dim,
                                                        self.embed_dim,
                                                        operator_type = self.operator_type,
                                                        mlp_ratio = mlp_ratio,
                                                        drop_rate = drop_rate,
                                                        drop_path = dpr[i],
                                                        act_layer = self.activation_function,
                                                        norm_layer = norm_layer,
                                                        inner_skip = inner_skip,
                                                        outer_skip = outer_skip,
                                                        use_mlp = use_mlp,
                                                        factorization = self.factorization,
                                                        separable = self.separable,
                                                        rank = self.rank)

            self.blocks.append(block)

        # decoder: point-wise MLP realized with 1x1 convolutions; the input
        # is widened by in_chans when the big skip connection is enabled
        decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
        self.decoder = self._make_conv_head(self.embed_dim + self.big_skip*self.in_chans, decoder_hidden_dim, self.out_chans, 1)

    def _make_conv_head(self, in_dim, hidden_dim, out_dim, num_layers):
        """Build a point-wise (1x1 conv) MLP head with fan-in weight init.

        Hidden layers use gain 2 (ReLU-style) init with biases; the final
        projection uses gain 1 and no bias, matching the original encoder
        and decoder construction.
        """
        layers = []
        current_dim = in_dim
        for _ in range(num_layers - 1):
            fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
            # initialize the weights correctly
            nn.init.normal_(fc.weight, mean=0., std=math.sqrt(2. / current_dim))
            if fc.bias is not None:
                nn.init.constant_(fc.bias, 0.0)
            layers.append(fc)
            layers.append(self.activation_function())
            current_dim = hidden_dim
        fc = nn.Conv2d(current_dim, out_dim, 1, bias=False)
        nn.init.normal_(fc.weight, mean=0., std=math.sqrt(1. / current_dim))
        if fc.bias is not None:
            nn.init.constant_(fc.bias, 0.0)
        layers.append(fc)
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by the optimizer setup
        return {"pos_embed", "cls_token"}

    def forward_features(self, x):

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):

        if self.big_skip:
            residual = x

        x = self.encoder(x)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        x = self.forward_features(x)

        if self.big_skip:
            x = torch.cat((x, residual), dim=1)

        x = self.decoder(x)

        return x