Commit 9dc07e9b authored by Boris Bonev, committed by Boris Bonev

adding local SNO architecture and updating training script for SWE prediction with power spectrum computation
parent b91f517c
This diff is collapsed.
@@ -31,3 +31,4 @@
 from .utils.pde_dataset import PdeDataset
 from .models.sfno import SphericalFourierNeuralOperatorNet
+from .models.local_sfno import LocalSphericalNeuralOperatorNet
This diff is collapsed.
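For orientation, a minimal usage sketch of the newly exported model. Its constructor is not visible in this commit view (the `local_sfno.py` diff is collapsed), so the arguments below are an assumption, mirroring the `SphericalFourierNeuralOperatorNet` signature shown further down:

    import torch

    # hypothetical: the local variant is assumed to accept the same arguments as the SFNO
    model = LocalSphericalNeuralOperatorNet(img_size=(128, 256), in_chans=3, out_chans=3, embed_dim=256, num_layers=4)
    x = torch.randn(1, 3, 128, 256)  # (batch, channels, nlat, nlon)
    y = model(x)                     # predicted fields on the same grid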
@@ -50,64 +50,64 @@ class SpectralFilterLayer(nn.Module):
         inverse_transform,
         input_dim,
         output_dim,
-        gain = 2.,
-        operator_type = "diagonal",
-        hidden_size_factor = 2,
-        factorization = None,
-        separable = False,
-        rank = 1e-2,
-        bias = True):
+        gain=2.0,
+        operator_type="diagonal",
+        hidden_size_factor=2,
+        factorization=None,
+        separable=False,
+        rank=1e-2,
+        bias=True,
+    ):

         super(SpectralFilterLayer, self).__init__()

         if factorization is None:
-            self.filter = SpectralConvS2(forward_transform,
-                                         inverse_transform,
-                                         input_dim,
-                                         output_dim,
-                                         gain = gain,
-                                         operator_type = operator_type,
-                                         bias = bias)
+            self.filter = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain, operator_type=operator_type, bias=bias)
         elif factorization is not None:
-            self.filter = FactorizedSpectralConvS2(forward_transform,
-                                                   inverse_transform,
-                                                   input_dim,
-                                                   output_dim,
-                                                   gain = gain,
-                                                   operator_type = operator_type,
-                                                   rank = rank,
-                                                   factorization = factorization,
-                                                   separable = separable,
-                                                   bias = bias)
+            self.filter = FactorizedSpectralConvS2(
+                forward_transform,
+                inverse_transform,
+                input_dim,
+                output_dim,
+                gain=gain,
+                operator_type=operator_type,
+                rank=rank,
+                factorization=factorization,
+                separable=separable,
+                bias=bias,
+            )
         else:
-            raise(NotImplementedError)
+            raise (NotImplementedError)

     def forward(self, x):
         return self.filter(x)

 class SphericalFourierNeuralOperatorBlock(nn.Module):
     """
     Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
     """

     def __init__(
         self,
         forward_transform,
         inverse_transform,
         input_dim,
         output_dim,
-        operator_type = "driscoll-healy",
-        mlp_ratio = 2.,
-        drop_rate = 0.,
-        drop_path = 0.,
-        act_layer = nn.ReLU,
-        norm_layer = nn.Identity,
-        factorization = None,
-        separable = False,
-        rank = 128,
-        inner_skip = "linear",
-        outer_skip = None,
-        use_mlp = True):
+        operator_type="driscoll-healy",
+        mlp_ratio=2.0,
+        drop_rate=0.0,
+        drop_path=0.0,
+        act_layer=nn.ReLU,
+        norm_layer=nn.Identity,
+        factorization=None,
+        separable=False,
+        rank=128,
+        inner_skip="linear",
+        outer_skip=None,
+        use_mlp=True,
+    ):

         super(SphericalFourierNeuralOperatorBlock, self).__init__()

         if act_layer == nn.Identity:
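As the signature above shows, a `SpectralFilterLayer` is built from a forward/inverse transform pair plus channel dimensions. A minimal construction sketch using the spherical harmonic transforms from torch_harmonics; the grid size and channel counts are illustrative:

    import torch
    from torch_harmonics import RealSHT, InverseRealSHT

    nlat, nlon = 128, 256
    sht = RealSHT(nlat, nlon, grid="equiangular").float()
    isht = InverseRealSHT(nlat, nlon, grid="equiangular").float()

    # factorization=None selects the dense SpectralConvS2 branch
    filter_layer = SpectralFilterLayer(sht, isht, input_dim=256, output_dim=256, gain=2.0, operator_type="diagonal")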
@@ -119,21 +119,23 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
             gain_factor /= 2.0

         # convolution layer
-        self.filter = SpectralFilterLayer(forward_transform,
-                                          inverse_transform,
-                                          input_dim,
-                                          output_dim,
-                                          gain = gain_factor,
-                                          operator_type = operator_type,
-                                          hidden_size_factor = mlp_ratio,
-                                          factorization = factorization,
-                                          separable = separable,
-                                          rank = rank,
-                                          bias = True)
+        self.filter = SpectralFilterLayer(
+            forward_transform,
+            inverse_transform,
+            input_dim,
+            output_dim,
+            gain=gain_factor,
+            operator_type=operator_type,
+            hidden_size_factor=mlp_ratio,
+            factorization=factorization,
+            separable=separable,
+            rank=rank,
+            bias=True,
+        )

         if inner_skip == "linear":
             self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
-            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor/input_dim))
+            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
         elif inner_skip == "identity":
             assert input_dim == output_dim
             self.inner_skip = nn.Identity()
@@ -148,25 +150,21 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         self.norm0 = norm_layer()

         # dropout
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

         gain_factor = 1.0
         if outer_skip == "linear" or inner_skip == "identity":
-            gain_factor /= 2.
+            gain_factor /= 2.0

         if use_mlp == True:
             mlp_hidden_dim = int(output_dim * mlp_ratio)
-            self.mlp = MLP(in_features = output_dim,
-                           out_features = input_dim,
-                           hidden_features = mlp_hidden_dim,
-                           act_layer = act_layer,
-                           drop_rate = drop_rate,
-                           checkpointing = False,
-                           gain = gain_factor)
+            self.mlp = MLP(
+                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
+            )

         if outer_skip == "linear":
             self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
-            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor/input_dim))
+            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
         elif outer_skip == "identity":
             assert input_dim == output_dim
             self.outer_skip = nn.Identity()
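The gain halving above keeps initialization variance-preserving: when a skip branch is summed with the main branch, two independent unit-variance signals add to variance 2, so each branch is initialized with half the gain to restore the target scale. A quick numerical check of that reasoning (illustrative, not part of the commit):

    import torch

    a = torch.randn(1_000_000)  # main branch, unit variance
    b = torch.randn(1_000_000)  # skip branch, unit variance
    print((a + b).var())        # ~2.0, which is why gain_factor is divided by 2.0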
@@ -178,17 +176,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         # second normalisation layer
         self.norm1 = norm_layer()

-    # def init_weights(self, scale):
-    #     if hasattr(self, "inner_skip") and isinstance(self.inner_skip, nn.Conv2d):
-    #         gain_factor = 1.
-    #         scale = (gain_factor / embed_dim)**0.5
-    #         nn.init.normal_(self.inner_skip.weight, mean=0., std=scale)
-    #         self.filter.filter.init_weights(scale)
-    #     else:
-    #         gain_factor = 2.
-    #         scale = (gain_factor / embed_dim)**0.5
-    #         self.filter.filter.init_weights(scale)
-
     def forward(self, x):
         x, residual = self.filter(x)
@@ -213,6 +200,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         return x

+
 class SphericalFourierNeuralOperatorNet(nn.Module):
     """
     SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
@@ -281,29 +269,30 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     def __init__(
         self,
-        spectral_transform = "sht",
-        operator_type = "driscoll-healy",
-        img_size = (128, 256),
-        grid = "equiangular",
-        scale_factor = 3,
-        in_chans = 3,
-        out_chans = 3,
-        embed_dim = 256,
-        num_layers = 4,
-        activation_function = "relu",
-        encoder_layers = 1,
-        use_mlp = True,
-        mlp_ratio = 2.,
-        drop_rate = 0.,
-        drop_path_rate = 0.,
-        normalization_layer = "none",
-        hard_thresholding_fraction = 1.0,
-        use_complex_kernels = True,
-        big_skip = False,
-        factorization = None,
-        separable = False,
-        rank = 128,
-        pos_embed = False):
+        spectral_transform="sht",
+        operator_type="driscoll-healy",
+        img_size=(128, 256),
+        grid="equiangular",
+        scale_factor=3,
+        in_chans=3,
+        out_chans=3,
+        embed_dim=256,
+        num_layers=4,
+        activation_function="relu",
+        encoder_layers=1,
+        use_mlp=True,
+        mlp_ratio=2.0,
+        drop_rate=0.0,
+        drop_path_rate=0.0,
+        normalization_layer="none",
+        hard_thresholding_fraction=1.0,
+        use_complex_kernels=True,
+        big_skip=False,
+        factorization=None,
+        separable=False,
+        rank=128,
+        pos_embed=False,
+    ):

         super(SphericalFourierNeuralOperatorNet, self).__init__()
@@ -322,7 +311,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         self.encoder_layers = encoder_layers
         self.big_skip = big_skip
         self.factorization = factorization
-        self.separable = separable,
+        self.separable = (separable,)
         self.rank = rank

         # activation function
@@ -341,7 +330,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         self.w = self.img_size[1] // scale_factor

         # dropout
-        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
+        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

         # pick norm layer
@@ -357,7 +346,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         else:
             raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

-        if pos_embed == "latlon" or pos_embed==True:
+        if pos_embed == "latlon" or pos_embed == True:
             self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
             nn.init.constant_(self.pos_embed, 0.0)
         elif pos_embed == "lat":
@@ -369,35 +358,24 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         else:
             self.pos_embed = None

-        # # encoder
-        # encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        # encoder = MLP(in_features = self.in_chans,
-        #               out_features = self.embed_dim,
-        #               hidden_features = encoder_hidden_dim,
-        #               act_layer = self.activation_function,
-        #               drop_rate = drop_rate,
-        #               checkpointing = False)
-        # self.encoder = encoder
-
         # construct an encoder with num_encoder_layers
         num_encoder_layers = 1
         encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
         current_dim = self.in_chans
         encoder_layers = []
-        for l in range(num_encoder_layers-1):
+        for l in range(num_encoder_layers - 1):
             fc = nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True)
             # initialize the weights correctly
-            scale = math.sqrt(2. / current_dim)
-            nn.init.normal_(fc.weight, mean=0., std=scale)
+            scale = math.sqrt(2.0 / current_dim)
+            nn.init.normal_(fc.weight, mean=0.0, std=scale)
             if fc.bias is not None:
                 nn.init.constant_(fc.bias, 0.0)
             encoder_layers.append(fc)
             encoder_layers.append(self.activation_function())
             current_dim = encoder_hidden_dim
         fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)
-        scale = math.sqrt(1. / current_dim)
-        nn.init.normal_(fc.weight, mean=0., std=scale)
+        scale = math.sqrt(1.0 / current_dim)
+        nn.init.normal_(fc.weight, mean=0.0, std=scale)
         if fc.bias is not None:
             nn.init.constant_(fc.bias, 0.0)
         encoder_layers.append(fc)
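Because `num_encoder_layers = 1`, the loop above never runs and the encoder reduces to a single bias-free 1x1 convolution that lifts each grid point from `in_chans` to `embed_dim` channels. A condensed restatement of the pattern (the `nn.Sequential` wrapper is an assumption; the hunk ends before showing how `encoder_layers` is consumed):

    import math
    import torch.nn as nn

    in_chans, embed_dim = 3, 256
    fc = nn.Conv2d(in_chans, embed_dim, 1, bias=False)  # pointwise channel mixing, no spatial coupling
    nn.init.normal_(fc.weight, mean=0.0, std=math.sqrt(1.0 / in_chans))
    encoder = nn.Sequential(fc)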
@@ -407,7 +385,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):

         if self.spectral_transform == "sht":
             modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int(self.w//2 * self.hard_thresholding_fraction)
+            modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
             modes_lat = modes_lon = min(modes_lat, modes_lon)

             self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
@@ -426,13 +404,13 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()

         else:
-            raise(ValueError("Unknown spectral transform"))
+            raise (ValueError("Unknown spectral transform"))

         self.blocks = nn.ModuleList([])
         for i in range(self.num_layers):
             first_layer = i == 0
-            last_layer = i == self.num_layers-1
+            last_layer = i == self.num_layers - 1

             forward_transform = self.trans_down if first_layer else self.trans
             inverse_transform = self.itrans_up if last_layer else self.itrans
@@ -447,52 +425,45 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             else:
                 norm_layer = norm_layer1

-            block = SphericalFourierNeuralOperatorBlock(forward_transform,
-                                                        inverse_transform,
-                                                        self.embed_dim,
-                                                        self.embed_dim,
-                                                        operator_type = self.operator_type,
-                                                        mlp_ratio = mlp_ratio,
-                                                        drop_rate = drop_rate,
-                                                        drop_path = dpr[i],
-                                                        act_layer = self.activation_function,
-                                                        norm_layer = norm_layer,
-                                                        inner_skip = inner_skip,
-                                                        outer_skip = outer_skip,
-                                                        use_mlp = use_mlp,
-                                                        factorization = self.factorization,
-                                                        separable = self.separable,
-                                                        rank = self.rank)
+            block = SphericalFourierNeuralOperatorBlock(
+                forward_transform,
+                inverse_transform,
+                self.embed_dim,
+                self.embed_dim,
+                operator_type=self.operator_type,
+                mlp_ratio=mlp_ratio,
+                drop_rate=drop_rate,
+                drop_path=dpr[i],
+                act_layer=self.activation_function,
+                norm_layer=norm_layer,
+                inner_skip=inner_skip,
+                outer_skip=outer_skip,
+                use_mlp=use_mlp,
+                factorization=self.factorization,
+                separable=self.separable,
+                rank=self.rank,
+            )

             self.blocks.append(block)

-        # # decoder
-        # decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        # self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans,
-        #                    out_features = self.out_chans,
-        #                    hidden_features = decoder_hidden_dim,
-        #                    act_layer = self.activation_function,
-        #                    drop_rate = drop_rate,
-        #                    checkpointing = False)
-
         # construct an decoder with num_decoder_layers
         num_decoder_layers = 1
         decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        current_dim = self.embed_dim + self.big_skip*self.in_chans
+        current_dim = self.embed_dim + self.big_skip * self.in_chans
         decoder_layers = []
-        for l in range(num_decoder_layers-1):
+        for l in range(num_decoder_layers - 1):
             fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)
             # initialize the weights correctly
-            scale = math.sqrt(2. / current_dim)
-            nn.init.normal_(fc.weight, mean=0., std=scale)
+            scale = math.sqrt(2.0 / current_dim)
+            nn.init.normal_(fc.weight, mean=0.0, std=scale)
             if fc.bias is not None:
                 nn.init.constant_(fc.bias, 0.0)
             decoder_layers.append(fc)
             decoder_layers.append(self.activation_function())
             current_dim = decoder_hidden_dim
         fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False)
-        scale = math.sqrt(1. / current_dim)
-        nn.init.normal_(fc.weight, mean=0., std=scale)
+        scale = math.sqrt(1.0 / current_dim)
+        nn.init.normal_(fc.weight, mean=0.0, std=scale)
         if fc.bias is not None:
             nn.init.constant_(fc.bias, 0.0)
         decoder_layers.append(fc)
@@ -529,5 +500,3 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         x = self.decoder(x)

         return x
-
-
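End to end, the network maps a gridded input state to a gridded prediction of the same spatial shape. A minimal forward-pass sketch using the defaults from the constructor signature above:

    import torch

    model = SphericalFourierNeuralOperatorNet(img_size=(128, 256), in_chans=3, out_chans=3)
    x = torch.randn(2, 3, 128, 256)  # (batch, channels, nlat, nlon) on an equiangular grid
    y = model(x)                     # -> (2, 3, 128, 256)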
@@ -239,7 +239,7 @@ class ShallowWaterSolver(nn.Module):
         ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64

         # mach number relative to wave speed
-        llimit = mlimit = 80
+        llimit = mlimit = 120

         # hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
         # ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
...
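The raised bandlimit (`llimit = mlimit = 120` instead of 80) admits higher-wavenumber content in the solver's random initial conditions. The commit message also mentions power spectrum computation in the training script, which is not visible in this view; a common way to compute such a spectrum from spherical-harmonic coefficients, sketched here as an assumption rather than the commit's actual code:

    import torch
    from torch_harmonics import RealSHT

    nlat, nlon = 128, 256
    sht = RealSHT(nlat, nlon, grid="equiangular").float()

    field = torch.randn(nlat, nlon)  # some scalar field on the grid
    coeffs = sht(field)              # complex coefficients, shape (lmax, mmax)
    power = coeffs.abs() ** 2
    power[..., 1:] *= 2.0            # m > 0 modes count twice for a real field
    spectrum = power.sum(dim=-1)     # energy per total wavenumber l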