Unverified Commit 7fb0c483 authored by Boris Bonev, committed by GitHub

Bbonev/sfno update (#10)

* reworked SFNO example

* updated changelog
parent cec07d7a
@@ -5,6 +5,7 @@
### v0.6.3
* Adding gradient check in unit tests
+ * Updated SFNO example
### v0.6.2
......
@@ -334,7 +334,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
# SFNO models
models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
- num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector')
+ num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')
models['sfno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='sht', filter_type='non-linear', img_size=(nlat, nlon),
num_layers=4, scale_factor=3, embed_dim=256, complex_activation = 'real', operator_type='diagonal')
# FNO models
......
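For reference, the example registers constructors via functools.partial and instantiates a model later by name. A minimal sketch of driving the updated SFNO entry (the import path and grid size are assumptions for illustration, not part of this diff):

```python
from functools import partial

import torch
# import path assumed for illustration; the example script defines/imports SFNO itself
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO

nlat, nlon = 128, 256  # illustrative resolution

models = {}
models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear',
                                                   img_size=(nlat, nlon), num_layers=4, scale_factor=3,
                                                   embed_dim=256, operator_type='driscoll-healy')

model = models['sfno_sc3_layer4_edim256_linear']()   # instantiate on demand
out = model(torch.randn(1, 3, nlat, nlon))           # default in_chans/out_chans is 3
print(out.shape)                                     # torch.Size([1, 3, 128, 256])
```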
(The source diff for one file in this commit is too large to display and is omitted here.)
@@ -36,32 +36,27 @@ Contains complex contractions wrapped into jit for harmonic layers
"""
@torch.jit.script
- def compl_contract2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
- tmp = torch.einsum("bixys,kixr->srbkx", a, b)
- res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
- return res
- @torch.jit.script
- def compl_contract2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ def contract_diagonal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
- res = torch.einsum("bixy,kix->bkx", ac, bc)
+ res = torch.einsum("bixy,kixy->bkxy", ac, bc)
return torch.view_as_real(res)
@torch.jit.script
- def compl_contract_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
- tmp = torch.einsum("bins,kinr->srbkn", a, b)
- res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
- return res
+ def contract_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ ac = torch.view_as_complex(a)
+ bc = torch.view_as_complex(b)
+ res = torch.einsum("bixy,kix->bkxy", ac, bc)
+ return torch.view_as_real(res)
@torch.jit.script
- def compl_contract_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ def contract_blockdiag(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ac = torch.view_as_complex(a)
bc = torch.view_as_complex(b)
- res = torch.einsum("bin,kin->bkn", ac, bc)
+ res = torch.einsum("bixy,kixyz->bkxz", ac, bc)
return torch.view_as_real(res)
- # Helper routines for spherical MLPs
+ # Helper routines for the non-linear FNOs (Attention-like)
@torch.jit.script
def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
tmp = torch.einsum("bixs,ior->srbox", a, b)
@@ -124,18 +119,3 @@ def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
return compl_mul2d_fwd_c(a, b) + c
- # for all the experimental layers
- # @torch.jit.script
- # def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
- # ac = torch.view_as_complex(a)
- # bc = torch.view_as_complex(b)
- # resc = torch.einsum("bixy,xio->boxy", ac, bc)
- # res = torch.view_as_real(resc)
- # return res
- # @torch.jit.script
- # def compl_exp_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
- # tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b))
- # cc = torch.view_as_complex(c)
- # return torch.view_as_real(tmpcc + cc)
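The three renamed contraction routines above differ only in how much structure the weight carries over the spherical-harmonic modes (x indexes the degree l, y the order m): contract_diagonal has one weight per (l, m) mode, contract_dhconv depends on l only (the Driscoll-Healy convolution), and contract_blockdiag additionally mixes the m modes. A quick shape check with illustrative sizes, written directly in terms of the same einsum equations:

```python
import torch

B, Ci, Co, L, M = 2, 4, 6, 8, 5  # batch, in/out channels, number of l and m modes (illustrative)

x = torch.randn(B, Ci, L, M, dtype=torch.cfloat)            # spectral coefficients

w_diag  = torch.randn(Co, Ci, L, M, dtype=torch.cfloat)     # diagonal: one weight per (l, m)
w_dh    = torch.randn(Co, Ci, L, dtype=torch.cfloat)        # Driscoll-Healy: weight depends on l only
w_block = torch.randn(Co, Ci, L, M, M, dtype=torch.cfloat)  # block-diagonal: mixes the m modes

print(torch.einsum("bixy,kixy->bkxy", x, w_diag).shape)     # torch.Size([2, 6, 8, 5])
print(torch.einsum("bixy,kix->bkxy", x, w_dh).shape)        # torch.Size([2, 6, 8, 5])
print(torch.einsum("bixy,kixyz->bkxz", x, w_block).shape)   # torch.Size([2, 6, 8, 5])
```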
@@ -59,7 +59,7 @@ def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
elif operator_type == 'block-diagonal':
weight_syms.insert(-1, einsum_symbols[order+1])
out_syms[-1] = weight_syms[-2]
- elif operator_type == 'vector':
+ elif operator_type == 'driscoll-healy':
weight_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
@@ -92,7 +92,7 @@ def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
elif operator_type == 'block-diagonal':
out_syms[-1] = einsum_symbols[order+2]
factor_syms += [out_syms[-1] + rank_sym]
- elif operator_type == 'vector':
+ elif operator_type == 'driscoll-healy':
factor_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
@@ -148,7 +148,7 @@ def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
elif operator_type == 'block-diagonal':
weight_syms.insert(-1, einsum_symbols[order+1])
out_syms[-1] = weight_syms[-2]
- elif operator_type == 'vector':
+ elif operator_type == 'driscoll-healy':
weight_syms.pop()
else:
raise ValueError(f"Unkonw operator type {operator_type}")
......
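The only change in these factorized contraction helpers is the rename of the 'vector' operator type to 'driscoll-healy'; the weight still drops its trailing m index via weight_syms.pop(). Roughly, the dense contraction then reduces to the following einsums (shapes and weight layout are illustrative, not the exact strings built by _contract_dense):

```python
import torch

b, c_in, c_out, lmax, mmax = 2, 4, 6, 8, 5  # illustrative sizes

x = torch.randn(b, c_in, lmax, mmax, dtype=torch.cfloat)

# 'diagonal': an independent weight for every (l, m) mode
w_diag = torch.randn(c_in, c_out, lmax, mmax, dtype=torch.cfloat)
y_diag = torch.einsum("bilm,iolm->bolm", x, w_diag)

# 'driscoll-healy': the trailing m index is dropped, so the weight depends on the degree l only
w_dh = torch.randn(c_in, c_out, lmax, dtype=torch.cfloat)
y_dh = torch.einsum("bilm,iol->bolm", x, w_dh)

print(y_diag.shape, y_dh.shape)  # both torch.Size([2, 6, 8, 5])
```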
@@ -40,8 +40,6 @@ from torch_harmonics import *
from .contractions import *
from .activations import *
- from .factorizations import get_contract_fun
# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
@@ -207,7 +205,7 @@ class InverseRealFFT2(nn.Module):
def forward(self, x):
return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")
class SpectralConvS2(nn.Module):
"""
Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
@@ -221,7 +219,95 @@ class SpectralConvS2(nn.Module):
in_channels,
out_channels,
scale = 'auto',
- operator_type = 'diagonal',
+ operator_type = 'driscoll-healy',
lr_scale_exponent = 0,
bias = False):
super(SpectralConvS2, self).__init__()
if scale == 'auto':
scale = (2 / in_channels)**0.5
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
# remember factorization details
self.operator_type = operator_type
assert self.inverse_transform.lmax == self.modes_lat
assert self.inverse_transform.mmax == self.modes_lon
weight_shape = [in_channels, out_channels]
if self.operator_type == 'diagonal':
weight_shape += [self.modes_lat, self.modes_lon]
from .contractions import contract_diagonal as _contract
elif self.operator_type == 'block-diagonal':
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
from .contractions import contract_blockdiag as _contract
elif self.operator_type == 'driscoll-healy':
weight_shape += [self.modes_lat]
from .contractions import contract_dhconv as _contract
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
# rescale the learning rate for better training of spectral parameters
lr_scale = (torch.arange(self.modes_lat)+1).reshape(-1, 1)**(lr_scale_exponent)
self.register_buffer("lr_scale", lr_scale)
# self.weight.register_hook(lambda grad: self.lr_scale*grad)
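# note: lr_scale[l] = (l + 1)**lr_scale_exponent, so the default exponent of 0 yields an all-ones buffer;
# since the gradient hook above is left commented out, no gradient rescaling is applied by default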
# get the right contraction function
self._contract = _contract
if bias:
self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
with amp.autocast(enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
x = torch.view_as_real(x)
x = self._contract(x, self.weight)
x = torch.view_as_complex(x)
with amp.autocast(enabled=False):
x = self.inverse_transform(x)
if hasattr(self, 'bias'):
x = x + self.bias
x = x.type(dtype)
return x, residual
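A minimal sketch of exercising the new SpectralConvS2 on its own, using the spherical harmonic transforms from torch_harmonics that the surrounding module already imports (grid size and channel counts are illustrative, and the default lmax/mmax of the transforms are used):

```python
import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 64, 128  # illustrative grid
sht = RealSHT(nlat, nlon, grid="equiangular").float()
isht = InverseRealSHT(nlat, nlon, grid="equiangular").float()

conv = SpectralConvS2(sht, isht, in_channels=8, out_channels=8,
                      operator_type="driscoll-healy", bias=True)

x = torch.randn(2, 8, nlat, nlon)
y, residual = conv(x)               # residual is the (possibly re-gridded) input
print(y.shape, residual.shape)      # torch.Size([2, 8, 64, 128]) for both
```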
class FactorizedSpectralConvS2(nn.Module):
"""
Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
"""
def __init__(self,
forward_transform,
inverse_transform,
in_channels,
out_channels,
scale = 'auto',
operator_type = 'driscoll-healy',
rank = 0.2,
factorization = None,
separable = False,
@@ -231,7 +317,7 @@ class SpectralConvS2(nn.Module):
super(SpectralConvS2, self).__init__()
if scale == 'auto':
- scale = (1 / (in_channels * out_channels))
+ scale = (2 / in_channels)**0.5
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
@@ -266,7 +352,7 @@ class SpectralConvS2(nn.Module):
weight_shape += [self.modes_lat, self.modes_lon]
elif self.operator_type == 'block-diagonal':
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
- elif self.operator_type == 'vector':
+ elif self.operator_type == 'driscoll-healy':
weight_shape += [self.modes_lat]
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
@@ -278,6 +364,8 @@ class SpectralConvS2(nn.Module):
# initialization of weights
self.weight.normal_(0, scale)
+ # get the right contraction function
+ from .factorizations import get_contract_fun
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)
if bias:
@@ -289,7 +377,6 @@ class SpectralConvS2(nn.Module):
dtype = x.dtype
x = x.float()
residual = x
- B, C, H, W = x.shape
with amp.autocast(enabled=False):
x = self.forward_transform(x)
@@ -467,7 +554,7 @@ class SpectralAttentionS2(nn.Module):
for l in range(0, self.spectral_layers):
self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))
- elif operator_type == 'vector':
+ elif operator_type == 'driscoll-healy':
self.mul_add_handle = compl_exp_muladd2d_fwd
self.mul_handle = compl_exp_mul2d_fwd
......
@@ -48,20 +48,21 @@ class SpectralFilterLayer(nn.Module):
forward_transform,
inverse_transform,
embed_dim,
- filter_type = 'non-linear',
- operator_type = 'diagonal',
+ filter_type = "non-linear",
+ operator_type = "diagonal",
sparsity_threshold = 0.0,
use_complex_kernels = True,
hidden_size_factor = 2,
+ lr_scale_exponent = 0,
factorization = None,
separable = False,
rank = 1e-2,
- complex_activation = 'real',
+ complex_activation = "real",
spectral_layers = 1,
drop_rate = 0):
super(SpectralFilterLayer, self).__init__()
- if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
+ if filter_type == "non-linear" and isinstance(forward_transform, RealSHT):
self.filter = SpectralAttentionS2(forward_transform,
inverse_transform,
embed_dim,
@@ -73,7 +74,7 @@ class SpectralFilterLayer(nn.Module):
drop_rate = drop_rate,
bias = False)
- elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
+ elif filter_type == "non-linear" and isinstance(forward_transform, RealFFT2):
self.filter = SpectralAttention2d(forward_transform,
inverse_transform,
embed_dim,
@@ -85,16 +86,25 @@ class SpectralFilterLayer(nn.Module):
drop_rate = drop_rate,
bias = False)
- elif filter_type == 'linear':
+ elif filter_type == "linear" and factorization is None:
self.filter = SpectralConvS2(forward_transform,
inverse_transform,
embed_dim,
embed_dim,
operator_type = operator_type,
- rank = rank,
- factorization = factorization,
- separable = separable,
+ lr_scale_exponent = lr_scale_exponent,
bias = True)
elif filter_type == "linear" and factorization is not None:
self.filter = FactorizedSpectralConvS2(forward_transform,
inverse_transform,
embed_dim,
embed_dim,
operator_type = operator_type,
rank = rank,
factorization = factorization,
separable = separable,
bias = True)
else:
raise(NotImplementedError)
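Seen from the top-level model, this dispatch means the factorization argument alone decides whether the dense SpectralConvS2 or the tensorly-backed FactorizedSpectralConvS2 is used for linear filters. A hedged sketch (the 'cp' factorization string and rank value are illustrative):

```python
# dense spectral convolution (SpectralConvS2)
dense_model = SphericalFourierNeuralOperatorNet(filter_type="linear",
                                                operator_type="driscoll-healy",
                                                factorization=None)

# factorized spectral convolution (FactorizedSpectralConvS2), e.g. with a CP factorization
factorized_model = SphericalFourierNeuralOperatorNet(filter_type="linear",
                                                     operator_type="driscoll-healy",
                                                     factorization="cp", rank=0.2)
```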
@@ -111,29 +121,27 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
forward_transform,
inverse_transform,
embed_dim,
- filter_type = 'non-linear',
- operator_type = 'diagonal',
+ filter_type = "non-linear",
+ operator_type = "driscoll-healy",
mlp_ratio = 2.,
drop_rate = 0.,
drop_path = 0.,
act_layer = nn.GELU,
- norm_layer = (nn.LayerNorm, nn.LayerNorm),
+ norm_layer = nn.Identity,
sparsity_threshold = 0.0,
use_complex_kernels = True,
+ lr_scale_exponent = 0,
factorization = None,
separable = False,
rank = 128,
- inner_skip = 'linear',
- outer_skip = None, # None, nn.linear or nn.Identity
+ inner_skip = "linear",
+ outer_skip = None,
concat_skip = False,
use_mlp = True,
- complex_activation = 'real',
+ complex_activation = "real",
spectral_layers = 3):
super(SphericalFourierNeuralOperatorBlock, self).__init__()
- # norm layer
- self.norm0 = norm_layer[0]() #((h,w))
# convolution layer
self.filter = SpectralFilterLayer(forward_transform,
inverse_transform,
@@ -143,6 +151,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
sparsity_threshold = sparsity_threshold,
use_complex_kernels = use_complex_kernels,
hidden_size_factor = mlp_ratio,
+ lr_scale_exponent = lr_scale_exponent,
factorization = factorization,
separable = separable,
rank = rank,
@@ -150,24 +159,28 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
spectral_layers = spectral_layers,
drop_rate = drop_rate)
- if inner_skip == 'linear':
+ if inner_skip == "linear":
self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
- elif inner_skip == 'identity':
+ elif inner_skip == "identity":
self.inner_skip = nn.Identity()
+ elif inner_skip == "none":
+ pass
+ else:
+ raise ValueError(f"Unknown skip connection type {inner_skip}")
self.concat_skip = concat_skip
if concat_skip and inner_skip is not None:
self.inner_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
- if filter_type == 'linear' or filter_type == 'local':
+ if filter_type == "linear":
self.act_layer = act_layer()
+ # first normalisation layer
+ self.norm0 = norm_layer()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- # norm layer
- self.norm1 = norm_layer[1]() #((h,w))
if use_mlp == True:
mlp_hidden_dim = int(embed_dim * mlp_ratio)
@@ -177,44 +190,51 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
drop_rate = drop_rate,
checkpointing = False)
- if outer_skip == 'linear':
+ if outer_skip == "linear":
self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
- elif outer_skip == 'identity':
+ elif outer_skip == "identity":
self.outer_skip = nn.Identity()
+ elif outer_skip == "none":
+ pass
+ else:
+ raise ValueError(f"Unknown skip connection type {outer_skip}")
if concat_skip and outer_skip is not None:
self.outer_skip_conv = nn.Conv2d(2*embed_dim, embed_dim, 1, bias=False)
+ # second normalisation layer
+ self.norm1 = norm_layer()
def forward(self, x):
- x = self.norm0(x)
x, residual = self.filter(x)
- if hasattr(self, 'inner_skip'):
+ if hasattr(self, "inner_skip"):
if self.concat_skip:
x = torch.cat((x, self.inner_skip(residual)), dim=1)
x = self.inner_skip_conv(x)
else:
x = x + self.inner_skip(residual)
- if hasattr(self, 'act_layer'):
+ if hasattr(self, "act_layer"):
x = self.act_layer(x)
- x = self.norm1(x)
+ x = self.norm0(x)
- if hasattr(self, 'mlp'):
+ if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.drop_path(x)
- if hasattr(self, 'outer_skip'):
+ if hasattr(self, "outer_skip"):
if self.concat_skip:
x = torch.cat((x, self.outer_skip(residual)), dim=1)
x = self.outer_skip_conv(x)
else:
x = x + self.outer_skip(residual)
+ x = self.norm1(x)
return x
class SphericalFourierNeuralOperatorNet(nn.Module):
@@ -229,7 +249,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
- Type of operator to use ('vector', 'diagonal'), by default "vector"
+ Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
img_shape : tuple, optional
Shape of the input channels, by default (128, 256)
scale_factor : int, optional
@@ -247,7 +267,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
encoder_layers : int, optional
Number of layers in the encoder, by default 1
use_mlp : int, optional
- Whether to use MLP, by default True
+ Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
@@ -266,6 +286,8 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True
rank : float, optional
Rank of the approximation, by default 1.0
+ lr_scale_exponent : float, optional
+ exponential rescaling of spectral coefficients, by default 0.0 (no rescaling)
factorization : Any, optional
Type of factorization to use, by default None
separable : bool, optional
@@ -287,10 +309,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
... in_chans=2,
... out_chans=2,
... embed_dim=16,
- ... num_layers=2,
- ... encoder_layers=1,
- ... num_blocks=4,
- ... spectral_layers=2,
+ ... num_layers=4,
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
@@ -298,30 +317,31 @@
def __init__(
self,
- filter_type = 'linear',
- spectral_transform = 'sht',
- operator_type = 'vector',
+ filter_type = "linear",
+ spectral_transform = "sht",
+ operator_type = "driscoll-healy",
img_size = (128, 256),
scale_factor = 3,
in_chans = 3,
out_chans = 3,
embed_dim = 256,
num_layers = 4,
- activation_function = 'gelu',
+ activation_function = "gelu",
encoder_layers = 1,
use_mlp = True,
mlp_ratio = 2.,
drop_rate = 0.,
drop_path_rate = 0.,
sparsity_threshold = 0.0,
- normalization_layer = 'instance_norm',
+ normalization_layer = "none",
hard_thresholding_fraction = 1.0,
use_complex_kernels = True,
big_skip = True,
+ lr_scale_exponent = 0,
factorization = None,
separable = False,
rank = 128,
- complex_activation = 'real',
+ complex_activation = "real",
spectral_layers = 2,
pos_embed = True):
@@ -342,6 +362,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.use_mlp = use_mlp
self.encoder_layers = encoder_layers
self.big_skip = big_skip
+ self.lr_scale_exponent = lr_scale_exponent
self.factorization = factorization
self.separable = separable,
self.rank = rank
@@ -349,9 +370,9 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.spectral_layers = spectral_layers
# activation function
- if activation_function == 'relu':
+ if activation_function == "relu":
self.activation_function = nn.ReLU
- elif activation_function == 'gelu':
+ elif activation_function == "gelu":
self.activation_function = nn.GELU
else:
raise ValueError(f"Unknown activation function {activation_function}")
@@ -383,28 +404,28 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.pos_embed = None
# encoder
- encoder_hidden_dim = self.embed_dim
- current_dim = self.in_chans
- encoder_modules = []
- for i in range(self.encoder_layers):
- encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
- encoder_modules.append(self.activation_function())
- current_dim = encoder_hidden_dim
- encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
- self.encoder = nn.Sequential(*encoder_modules)
+ encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
+ encoder = MLP(in_features = self.in_chans,
+ out_features = self.embed_dim,
+ hidden_features = encoder_hidden_dim,
+ act_layer = self.activation_function,
+ drop_rate = drop_rate,
+ checkpointing = False)
+ self.encoder = encoder
+ # self.encoder = nn.Sequential(encoder, norm_layer0())
# prepare the spectral transform
- if self.spectral_transform == 'sht':
+ if self.spectral_transform == "sht":
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
- self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
- self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
- self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
- self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
+ self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
+ self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
+ self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
+ self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
- elif self.spectral_transform == 'fft':
+ elif self.spectral_transform == "fft":
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
@@ -415,7 +436,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
else:
- raise(ValueError('Unknown spectral transform'))
+ raise(ValueError("Unknown spectral transform"))
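The two transform pairs map between the native equiangular grid and the internal, scale_factor-reduced Legendre-Gauss grid. A small standalone shape check using the same constructions (values are illustrative; hard_thresholding_fraction is taken as 1.0):

```python
import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon, scale_factor = 128, 256, 3
h, w = nlat // scale_factor, nlon // scale_factor   # internal resolution
modes_lat = h                                       # hard_thresholding_fraction = 1.0
modes_lon = w // 2 + 1

trans_down = RealSHT(nlat, nlon, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
itrans = InverseRealSHT(h, w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()

x = torch.randn(1, 3, nlat, nlon)
coeffs = trans_down(x)          # complex coefficients of shape (1, 3, modes_lat, modes_lon)
x_internal = itrans(coeffs)     # back to a grid of shape (1, 3, h, w) on the internal resolution
print(coeffs.shape, x_internal.shape)
```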
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
@@ -430,11 +451,11 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
outer_skip = 'identity'
if first_layer:
- norm_layer = (norm_layer0, norm_layer1)
+ norm_layer = norm_layer1
elif last_layer:
- norm_layer = (norm_layer1, norm_layer0)
+ norm_layer = norm_layer0
else:
- norm_layer = (norm_layer1, norm_layer1)
+ norm_layer = norm_layer1
block = SphericalFourierNeuralOperatorBlock(forward_transform,
inverse_transform,
@@ -451,6 +472,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
inner_skip = inner_skip,
outer_skip = outer_skip,
use_mlp = use_mlp,
+ lr_scale_exponent = self.lr_scale_exponent,
factorization = self.factorization,
separable = self.separable,
rank = self.rank,
@@ -460,15 +482,13 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.blocks.append(block)
# decoder
- decoder_hidden_dim = self.embed_dim
- current_dim = self.embed_dim + self.big_skip*self.in_chans
- decoder_modules = []
- for i in range(self.encoder_layers):
- decoder_modules.append(nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True))
- decoder_modules.append(self.activation_function())
- current_dim = decoder_hidden_dim
- decoder_modules.append(nn.Conv2d(current_dim, self.out_chans, 1, bias=False))
- self.decoder = nn.Sequential(*decoder_modules)
+ encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
+ self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans,
+ out_features = self.out_chans,
+ hidden_features = encoder_hidden_dim,
+ act_layer = self.activation_function,
+ drop_rate = drop_rate,
+ checkpointing = False)
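Both the encoder and decoder stacks are replaced by the shared MLP module; its definition is not part of this diff. A sketch of what such a point-wise MLP typically looks like, given the constructor arguments used here (the actual implementation in the repo may differ, e.g. in dropout placement or checkpointing support):

```python
import torch.nn as nn

class MLP(nn.Module):
    """Point-wise two-layer MLP built from 1x1 convolutions (illustrative sketch)."""
    def __init__(self, in_features, out_features, hidden_features,
                 act_layer=nn.GELU, drop_rate=0.0, checkpointing=False):
        super().__init__()
        self.checkpointing = checkpointing  # ignored in this sketch
        self.fwd = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1),
            act_layer(),
            nn.Dropout(drop_rate),
            nn.Conv2d(hidden_features, out_features, 1),
            nn.Dropout(drop_rate),
        )

    def forward(self, x):
        return self.fwd(x)
```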
# trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
@@ -482,7 +502,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
- return {'pos_embed', 'cls_token'}
+ return {"pos_embed", "cls_token"}
def forward_features(self, x):
......