Unverified commit 7fb0c483 authored by Boris Bonev, committed by GitHub

Bbonev/sfno update (#10)

* reworked SFNO example

* updated changelog
parent cec07d7a
......@@ -5,6 +5,7 @@
### v0.6.3
* Adding gradient check in unit tests
* Updated SFNO example
### v0.6.2
......
......@@ -334,7 +334,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
# SFNO models
models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector')
num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')
models['sfno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='sht', filter_type='non-linear', img_size=(nlat, nlon),
num_layers=4, scale_factor=3, embed_dim=256, complex_activation = 'real', operator_type='diagonal')
# FNO models
......
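For orientation, here is a hedged sketch of how a registry entry like the one above is consumed later in the example script (the SFNO import alias and the default channel count of 3 are assumptions, not shown in this diff):

from functools import partial
import torch
# assumed alias used by the example script
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO

nlat, nlon = 128, 256
factory = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
                  num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')
model = factory()                            # instantiated lazily, once selected
out = model(torch.randn(1, 3, nlat, nlon))   # (1, 3, 128, 256) -> (1, 3, 128, 256)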
This diff is collapsed.
......@@ -36,32 +36,27 @@ Contains complex contractions wrapped into jit for harmonic layers
"""
@torch.jit.script
def compl_contract2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bixys,kixr->srbkx", a, b)
res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
return res
@torch.jit.script
def compl_contract2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
def contract_diagonal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
res = torch.einsum("bixy,kix->bkx", ac, bc)
res = torch.einsum("bixy,kixy->bkxy", ac, bc)
return torch.view_as_real(res)
@torch.jit.script
def compl_contract_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bins,kinr->srbkn", a, b)
res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
return res
def contract_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
res = torch.einsum("bixy,kix->bkxy", ac, bc)
return torch.view_as_real(res)
@torch.jit.script
def compl_contract_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
def contract_blockdiag(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
res = torch.einsum("bin,kin->bkn", ac, bc)
res = torch.einsum("bixy,kixyz->bkxz", ac, bc)
return torch.view_as_real(res)
# Helper routines for spherical MLPs
# Helper routines for the non-linear FNOs (Attention-like)
@torch.jit.script
def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bixs,ior->srbox", a, b)
......@@ -124,18 +119,3 @@ def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
return compl_mul2d_fwd_c(a, b) + c
# for all the experimental layers
# @torch.jit.script
# def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# ac = torch.view_as_complex(a)
# bc = torch.view_as_complex(b)
# resc = torch.einsum("bixy,xio->boxy", ac, bc)
# res = torch.view_as_real(resc)
# return res
# @torch.jit.script
# def compl_exp_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b))
# cc = torch.view_as_complex(c)
# return torch.view_as_real(tmpcc + cc)
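A quick hedged shape check for the three renamed contraction helpers above (illustrative dimensions; assumes the functions are imported from this module):

import torch

B, C, L, M = 2, 8, 16, 9                 # batch, channels, lmax, mmax
x = torch.view_as_real(torch.randn(B, C, L, M, dtype=torch.complex64))

w_diag = torch.randn(C, C, L, M, 2)      # independent complex weight per (l, m) mode
w_dh   = torch.randn(C, C, L, 2)         # Driscoll-Healy: one weight per degree l
w_blk  = torch.randn(C, C, L, M, M, 2)   # block-diagonal: dense mixing over orders m

assert contract_diagonal(x, w_diag).shape == (B, C, L, M, 2)
assert contract_dhconv(x, w_dh).shape == (B, C, L, M, 2)
assert contract_blockdiag(x, w_blk).shape == (B, C, L, M, 2)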
......@@ -59,7 +59,7 @@ def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
elif operator_type == 'block-diagonal':
weight_syms.insert(-1, einsum_symbols[order+1])
out_syms[-1] = weight_syms[-2]
elif operator_type == 'vector':
elif operator_type == 'driscoll-healy':
weight_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
......@@ -92,7 +92,7 @@ def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
elif operator_type == 'block-diagonal':
out_syms[-1] = einsum_symbols[order+2]
factor_syms += [out_syms[-1] + rank_sym]
elif operator_type == 'vector':
elif operator_type == 'driscoll-healy':
factor_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
......@@ -148,7 +148,7 @@ def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
elif operator_type == 'block-diagonal':
weight_syms.insert(-1, einsum_symbols[order+1])
out_syms[-1] = weight_syms[-2]
elif operator_type == 'vector':
elif operator_type == 'driscoll-healy':
weight_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
......
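The rename from 'vector' to 'driscoll-healy' matches what the contraction does: by the convolution theorem on S2, the weight carries no longitudinal (m) index. As a hedged illustration, the dense einsum each operator type reduces to for a (batch, channel, lat, lon) input looks like this (illustrative strings, not the exact ones the symbol generator emits):

import torch

x = torch.randn(2, 4, 8, 5, dtype=torch.complex64)   # (b, i, x, y)
w = torch.randn(4, 4, 8, dtype=torch.complex64)      # (i, o, x): no y index

eqs = {
    'diagonal':       'bixy,ioxy->boxy',   # independent weight per (l, m) mode
    'block-diagonal': 'bixy,ioxyz->boxz',  # dense mixing across orders m
    'driscoll-healy': 'bixy,iox->boxy',    # weight depends on degree l only
}
y = torch.einsum(eqs['driscoll-healy'], x, w)
print(y.shape)   # torch.Size([2, 4, 8, 5])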
......@@ -40,8 +40,6 @@ from torch_harmonics import *
from .contractions import *
from .activations import *
from .factorizations import get_contract_fun
# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
......@@ -207,7 +205,7 @@ class InverseRealFFT2(nn.Module):
def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class SpectralConvS2(nn.Module):
"""
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
......@@ -221,7 +219,95 @@ class SpectralConvS2(nn.Module):
in_channels,
out_channels,
scale = 'auto',
operator_type = 'diagonal',
operator_type = 'driscoll-healy',
lr_scale_exponent = 0,
bias = False):
super(SpectralConvS2, self).__init__()
if scale == 'auto':
scale = (2 / in_channels)**0.5
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.operator_type = operator_type
assert self.forward_transform.lmax == self.modes_lat
assert self.forward_transform.mmax == self.modes_lon
weight_shape = [in_channels, out_channels]
if self.operator_type == 'diagonal':
weight_shape += [self.modes_lat, self.modes_lon]
from .contractions import contract_diagonal as _contract
elif self.operator_type == 'block-diagonal':
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
from .contractions import contract_blockdiag as _contract
elif self.operator_type == 'driscoll-healy':
weight_shape += [self.modes_lat]
from .contractions import contract_dhconv as _contract
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
# rescale the learning rate for better training of spectral parameters
lr_scale = (torch.arange(self.modes_lat)+1).reshape(-1, 1)**(lr_scale_exponent)
self.register_buffer("lr_scale", lr_scale)
# self.weight.register_hook(lambda grad: self.lr_scale*grad)
# get the right contraction function
self._contract = _contract
if bias:
self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
with amp.autocast(enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
x = torch.view_as_real(x)
x = self._contract(x, self.weight)
x = torch.view_as_complex(x)
with amp.autocast(enabled=False):
x = self.inverse_transform(x)
if hasattr(self, 'bias'):
x = x + self.bias
x = x.type(dtype)
return x, residual
class FactorizedSpectralConvS2(nn.Module):
"""
Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
"""
def __init__(self,
forward_transform,
inverse_transform,
in_channels,
out_channels,
scale = 'auto',
operator_type = 'driscoll-healy',
rank = 0.2,
factorization = None,
separable = False,
......@@ -231,7 +317,7 @@ class SpectralConvS2(nn.Module):
super(FactorizedSpectralConvS2, self).__init__()
if scale == 'auto':
scale = (1 / (in_channels * out_channels))
scale = (2 / in_channels)**0.5
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
......@@ -266,7 +352,7 @@ class SpectralConvS2(nn.Module):
weight_shape += [self.modes_lat, self.modes_lon]
elif self.operator_type == 'block-diagonal':
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
elif self.operator_type == 'vector':
elif self.operator_type == 'driscoll-healy':
weight_shape += [self.modes_lat]
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
......@@ -278,6 +364,8 @@ class SpectralConvS2(nn.Module):
# initialization of weights
self.weight.normal_(0, scale)
# get the right contraction function
from .factorizations import get_contract_fun
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)
if bias:
......@@ -289,7 +377,6 @@ class SpectralConvS2(nn.Module):
dtype = x.dtype
x = x.float()
residual = x
B, C, H, W = x.shape
with amp.autocast(enabled=False):
x = self.forward_transform(x)
......@@ -467,7 +554,7 @@ class SpectralAttentionS2(nn.Module):
for l in range(0, self.spectral_layers):
self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))
elif operator_type == 'vector':
elif operator_type == 'driscoll-healy':
self.mul_add_handle = compl_exp_muladd2d_fwd
self.mul_handle = compl_exp_mul2d_fwd
......
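A hedged usage sketch for the reworked dense SpectralConvS2 (grid sizes are illustrative; assumes torch_harmonics' RealSHT/InverseRealSHT and that the class is in scope):

import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon, dim = 32, 64, 16
sht = RealSHT(nlat, nlon, grid="equiangular").float()
isht = InverseRealSHT(nlat, nlon, grid="equiangular").float()

conv = SpectralConvS2(sht, isht, dim, dim, operator_type="driscoll-healy", bias=True)
y, residual = conv(torch.randn(1, dim, nlat, nlon))
print(y.shape, residual.shape)   # torch.Size([1, 16, 32, 64]) for both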
......@@ -48,20 +48,21 @@ class SpectralFilterLayer(nn.Module):
forward_transform,
inverse_transform,
embed_dim,
filter_type = 'non-linear',
operator_type = 'diagonal',
filter_type = "non-linear",
operator_type = "diagonal",
sparsity_threshold = 0.0,
use_complex_kernels = True,
hidden_size_factor = 2,
lr_scale_exponent = 0,
factorization = None,
separable = False,
rank = 1e-2,
complex_activation = 'real',
complex_activation = "real",
spectral_layers = 1,
drop_rate = 0):
super(SpectralFilterLayer, self).__init__()
if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
if filter_type == "non-linear" and isinstance(forward_transform, RealSHT):
self.filter = SpectralAttentionS2(forward_transform,
inverse_transform,
embed_dim,
......@@ -73,7 +74,7 @@ class SpectralFilterLayer(nn.Module):
drop_rate = drop_rate,
bias = False)
elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
elif filter_type == "non-linear" and isinstance(forward_transform, RealFFT2):
self.filter = SpectralAttention2d(forward_transform,
inverse_transform,
embed_dim,
......@@ -85,16 +86,25 @@ class SpectralFilterLayer(nn.Module):
drop_rate = drop_rate,
bias = False)
elif filter_type == 'linear':
elif filter_type == "linear" and factorization is None:
self.filter = SpectralConvS2(forward_transform,
inverse_transform,
embed_dim,
embed_dim,
operator_type = operator_type,
rank = rank,
factorization = factorization,
separable = separable,
lr_scale_exponent = lr_scale_exponent,
bias = True)
elif filter_type == "linear" and factorization is not None:
self.filter = FactorizedSpectralConvS2(forward_transform,
inverse_transform,
embed_dim,
embed_dim,
operator_type = operator_type,
rank = rank,
factorization = factorization,
separable = separable,
bias = True)
else:
raise NotImplementedError(f"Unknown filter type {filter_type}")
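A hedged sketch of what the dispatch above resolves to (transforms as in the previous sketch; "cp" is an illustrative tensorly factorization name):

# factorization=None selects the dense spectral convolution
dense = SpectralFilterLayer(sht, isht, 16, filter_type="linear",
                            operator_type="driscoll-healy", factorization=None)
assert isinstance(dense.filter, SpectralConvS2)

# any factorization selects the tensorly-backed factorized variant
factored = SpectralFilterLayer(sht, isht, 16, filter_type="linear",
                               operator_type="driscoll-healy", factorization="cp", rank=0.2)
assert isinstance(factored.filter, FactorizedSpectralConvS2)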
......@@ -111,29 +121,27 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
forward_transform,
inverse_transform,
embed_dim,
filter_type = 'non-linear',
operator_type = 'diagonal',
filter_type = "non-linear",
operator_type = "driscoll-healy",
mlp_ratio = 2.,
drop_rate = 0.,
drop_path = 0.,
act_layer = nn.GELU,
norm_layer = (nn.LayerNorm, nn.LayerNorm),
norm_layer = nn.Identity,
sparsity_threshold = 0.0,
use_complex_kernels = True,
lr_scale_exponent = 0,
factorization = None,
separable = False,
rank = 128,
inner_skip = 'linear',
outer_skip = None, # None, nn.linear or nn.Identity
inner_skip = "linear",
outer_skip = "none",
concat_skip = False,
use_mlp = True,
complex_activation = 'real',
complex_activation = "real",
spectral_layers = 3):
super(SphericalFourierNeuralOperatorBlock, self).__init__()
# norm layer
self.norm0 = norm_layer[0]() #((h,w))
# convolution layer
self.filter = SpectralFilterLayer(forward_transform,
inverse_transform,
......@@ -143,6 +151,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
sparsity_threshold = sparsity_threshold,
use_complex_kernels = use_complex_kernels,
hidden_size_factor = mlp_ratio,
lr_scale_exponent = lr_scale_exponent,
factorization = factorization,
separable = separable,
rank = rank,
......@@ -150,24 +159,28 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
spectral_layers = spectral_layers,
drop_rate = drop_rate)
if inner_skip == 'linear':
if inner_skip == "linear":
self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
elif inner_skip == 'identity':
elif inner_skip == "identity":
self.inner_skip = nn.Identity()
elif inner_skip == "none":
pass
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
self.concat_skip = concat_skip
if concat_skip and inner_skip != "none":
self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
if filter_type == 'linear' or filter_type == 'local':
if filter_type == "linear":
self.act_layer = act_layer()
# first normalisation layer
self.norm0 = norm_layer()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# norm layer
self.norm1 = norm_layer[1]() #((h,w))
if use_mlp:
mlp_hidden_dim = int(embed_dim * mlp_ratio)
......@@ -177,44 +190,51 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
drop_rate = drop_rate,
checkpointing = False)
if outer_skip == 'linear':
if outer_skip == "linear":
self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
elif outer_skip == 'identity':
elif outer_skip == "identity":
self.outer_skip = nn.Identity()
elif outer_skip == "none":
pass
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
if concat_skip and outer_skip != "none":
self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
# second normalisation layer
self.norm1 = norm_layer()
def forward(self, x):
x = self.norm0(x)
x, residual = self.filter(x)
if hasattr(self, 'inner_skip'):
if hasattr(self, "inner_skip"):
if self.concat_skip:
x = torch.cat((x, self.inner_skip(residual)), dim=1)
x = self.inner_skip_conv(x)
else:
x = x + self.inner_skip(residual)
if hasattr(self, 'act_layer'):
if hasattr(self, "act_layer"):
x = self.act_layer(x)
x = self.norm1(x)
x = self.norm0(x)
if hasattr(self, 'mlp'):
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.drop_path(x)
if hasattr(self, 'outer_skip'):
if hasattr(self, "outer_skip"):
if self.concat_skip:
x = torch.cat((x, self.outer_skip(residual)), dim=1)
x = self.outer_skip_conv(x)
else:
x = x + self.outer_skip(residual)
x = self.norm1(x)
return x
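Reading off the forward pass above, the new block order is filter -> inner skip -> activation -> norm0 -> MLP -> drop path -> outer skip -> norm1. A hedged instantiation sketch (transforms sht/isht as before; skip settings chosen to exercise the new string options):

import torch

block = SphericalFourierNeuralOperatorBlock(sht, isht, 16,
                                            filter_type="linear",
                                            operator_type="driscoll-healy",
                                            inner_skip="linear",
                                            outer_skip="none",
                                            use_mlp=True)
y = block(torch.randn(1, 16, 32, 64))   # shape-preserving: (1, 16, 32, 64)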
class SphericalFourierNeuralOperatorNet(nn.Module):
......@@ -229,7 +249,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
Type of operator to use ('vector', 'diagonal'), by default "vector"
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
img_size : tuple, optional
Size of the input grid (nlat, nlon), by default (128, 256)
scale_factor : int, optional
......@@ -247,7 +267,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
encoder_layers : int, optional
Number of layers in the encoder, by default 1
use_mlp : bool, optional
Whether to use MLP, by default True
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : float, optional
Ratio of the MLP hidden dimension to the embedding dimension, by default 2.0
drop_rate : float, optional
......@@ -266,6 +286,8 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True
rank : float, optional
Rank of the factorized spectral weights, by default 128
lr_scale_exponent : float, optional
Exponent for the degree-dependent rescaling of the spectral weights' learning rate, by default 0.0 (no rescaling)
factorization : Any, optional
Type of factorization to use, by default None
separable : bool, optional
......@@ -287,10 +309,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
... in_chans=2,
... out_chans=2,
... embed_dim=16,
... num_layers=2,
... encoder_layers=1,
... num_blocks=4,
... spectral_layers=2,
... num_layers=4,
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
......@@ -298,30 +317,31 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
def __init__(
self,
filter_type = 'linear',
spectral_transform = 'sht',
operator_type = 'vector',
filter_type = "linear",
spectral_transform = "sht",
operator_type = "driscoll-healy",
img_size = (128, 256),
scale_factor = 3,
in_chans = 3,
out_chans = 3,
embed_dim = 256,
num_layers = 4,
activation_function = 'gelu',
activation_function = "gelu",
encoder_layers = 1,
use_mlp = True,
mlp_ratio = 2.,
drop_rate = 0.,
drop_path_rate = 0.,
sparsity_threshold = 0.0,
normalization_layer = 'instance_norm',
normalization_layer = "none",
hard_thresholding_fraction = 1.0,
use_complex_kernels = True,
big_skip = True,
lr_scale_exponent = 0,
factorization = None,
separable = False,
rank = 128,
complex_activation = 'real',
complex_activation = "real",
spectral_layers = 2,
pos_embed = True):
......@@ -342,6 +362,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.use_mlp = use_mlp
self.encoder_layers = encoder_layers
self.big_skip = big_skip
self.lr_scale_exponent = lr_scale_exponent
self.factorization = factorization
self.separable = separable
self.rank = rank
......@@ -349,9 +370,9 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.spectral_layers = spectral_layers
# activation function
if activation_function == 'relu':
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == 'gelu':
elif activation_function == "gelu":
self.activation_function = nn.GELU
else:
raise ValueError(f"Unknown activation function {activation_function}")
......@@ -383,28 +404,28 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.pos_embed = None
# encoder
encoder_hidden_dim = self.embed_dim
current_dim = self.in_chans
encoder_modules = []
for i in range(self.encoder_layers):
encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
encoder_modules.append(self.activation_function())
current_dim = encoder_hidden_dim
encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
self.encoder = nn.Sequential(*encoder_modules)
encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
encoder = MLP(in_features = self.in_chans,
out_features = self.embed_dim,
hidden_features = encoder_hidden_dim,
act_layer = self.activation_function,
drop_rate = drop_rate,
checkpointing = False)
self.encoder = encoder
# self.encoder = nn.Sequential(encoder, norm_layer0())
# prepare the spectral transform
if self.spectral_transform == 'sht':
if self.spectral_transform == "sht":
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
elif self.spectral_transform == 'fft':
elif self.spectral_transform == "fft":
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
......@@ -415,7 +436,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
else:
raise(ValueError('Unknown spectral transform'))
raise(ValueError("Unknown spectral transform"))
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
......@@ -430,11 +451,11 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
outer_skip = 'identity'
if first_layer:
norm_layer = (norm_layer0, norm_layer1)
norm_layer = norm_layer1
elif last_layer:
norm_layer = (norm_layer1, norm_layer0)
norm_layer = norm_layer0
else:
norm_layer = (norm_layer1, norm_layer1)
norm_layer = norm_layer1
block = SphericalFourierNeuralOperatorBlock(forward_transform,
inverse_transform,
......@@ -451,6 +472,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
inner_skip = inner_skip,
outer_skip = outer_skip,
use_mlp = use_mlp,
lr_scale_exponent = self.lr_scale_exponent,
factorization = self.factorization,
separable = self.separable,
rank = self.rank,
......@@ -460,15 +482,13 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.blocks.append(block)
# decoder
decoder_hidden_dim = self.embed_dim
current_dim = self.embed_dim + self.big_skip*self.in_chans
decoder_modules = []
for i in range(self.encoder_layers):
decoder_modules.append(nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True))
decoder_modules.append(self.activation_function())
current_dim = decoder_hidden_dim
decoder_modules.append(nn.Conv2d(current_dim, self.out_chans, 1, bias=False))
self.decoder = nn.Sequential(*decoder_modules)
decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans,
out_features = self.out_chans,
hidden_features = decoder_hidden_dim,
act_layer = self.activation_function,
drop_rate = drop_rate,
checkpointing = False)
# trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
......@@ -482,7 +502,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
return {"pos_embed", "cls_token"}
def forward_features(self, x):
......
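Putting it together, a hedged end-to-end sketch mirroring the docstring example with the reworked defaults (the import path is an assumption):

import torch
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet

model = SphericalFourierNeuralOperatorNet(filter_type="linear",
                                          operator_type="driscoll-healy",
                                          img_size=(128, 256),
                                          in_chans=2,
                                          out_chans=2,
                                          embed_dim=16,
                                          num_layers=4,
                                          normalization_layer="none",
                                          use_mlp=True)
print(model(torch.randn(1, 2, 128, 256)).shape)   # torch.Size([1, 2, 128, 256])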