"git@developer.sourcefind.cn:OpenDAS/autoawq_kernels.git" did not exist on "e90b731a667aa1efae0edc64ac120a07a844ee2c"
Commit 9dc07e9b authored and committed by Boris Bonev

adding local SNO architecture and updating training script for SWE prediction with power spectrum computation
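The power spectrum computation referenced in the message lives in the collapsed training-script diff and is not shown here. As a rough sketch, a per-degree spectrum can be computed from the RealSHT coefficients used throughout this commit; the helper name and conventions below are assumptions, not the script's actual code:

```python
import torch
from torch_harmonics import RealSHT


def powerspectrum(field: torch.Tensor, grid: str = "equiangular") -> torch.Tensor:
    """Energy per spherical-harmonic degree l for a real field of shape (..., nlat, nlon)."""
    nlat, nlon = field.shape[-2:]
    sht = RealSHT(nlat, nlon, grid=grid).to(field.device)
    coeffs = sht(field)  # complex coefficients, shape (..., lmax, mmax), orders m >= 0
    power = coeffs.abs().pow(2)
    # for a real field, each m > 0 coefficient stands in for the +/-m pair
    power[..., 1:] = 2.0 * power[..., 1:]
    return power.sum(dim=-1)  # sum over orders m, leaving the spectrum over degrees l
```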
parent b91f517c
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
-#
+#
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
 #
@@ -31,3 +31,4 @@
 from .utils.pde_dataset import PdeDataset
 from .models.sfno import SphericalFourierNeuralOperatorNet
+from .models.local_sfno import LocalSphericalNeuralOperatorNet
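The newly exported LocalSphericalNeuralOperatorNet is defined in the collapsed local_sfno.py diff. A hypothetical usage, assuming the examples package path implied by this hunk and a constructor that mirrors the SphericalFourierNeuralOperatorNet signature shown later in this commit:

```python
import torch
from torch_harmonics.examples import LocalSphericalNeuralOperatorNet  # assumed export path

# constructor arguments assumed to match the SFNO defaults in this diff
model = LocalSphericalNeuralOperatorNet(img_size=(128, 256), in_chans=3, out_chans=3, embed_dim=256, num_layers=4)
out = model(torch.randn(1, 3, 128, 256))
```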
@@ -50,64 +50,64 @@ class SpectralFilterLayer(nn.Module):
         inverse_transform,
         input_dim,
         output_dim,
-        gain = 2.,
-        operator_type = "diagonal",
-        hidden_size_factor = 2,
-        factorization = None,
-        separable = False,
-        rank = 1e-2,
-        bias = True):
+        gain=2.0,
+        operator_type="diagonal",
+        hidden_size_factor=2,
+        factorization=None,
+        separable=False,
+        rank=1e-2,
+        bias=True,
+    ):
         super(SpectralFilterLayer, self).__init__()

         if factorization is None:
-            self.filter = SpectralConvS2(forward_transform,
-                                         inverse_transform,
-                                         input_dim,
-                                         output_dim,
-                                         gain = gain,
-                                         operator_type = operator_type,
-                                         bias = bias)
+            self.filter = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain, operator_type=operator_type, bias=bias)

         elif factorization is not None:
-            self.filter = FactorizedSpectralConvS2(forward_transform,
-                                                   inverse_transform,
-                                                   input_dim,
-                                                   output_dim,
-                                                   gain = gain,
-                                                   operator_type = operator_type,
-                                                   rank = rank,
-                                                   factorization = factorization,
-                                                   separable = separable,
-                                                   bias = bias)
+            self.filter = FactorizedSpectralConvS2(
+                forward_transform,
+                inverse_transform,
+                input_dim,
+                output_dim,
+                gain=gain,
+                operator_type=operator_type,
+                rank=rank,
+                factorization=factorization,
+                separable=separable,
+                bias=bias,
+            )

         else:
-            raise(NotImplementedError)
+            raise (NotImplementedError)

     def forward(self, x):
         return self.filter(x)


 class SphericalFourierNeuralOperatorBlock(nn.Module):
     """
     Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
     """

     def __init__(
-            self,
-            forward_transform,
-            inverse_transform,
-            input_dim,
-            output_dim,
-            operator_type = "driscoll-healy",
-            mlp_ratio = 2.,
-            drop_rate = 0.,
-            drop_path = 0.,
-            act_layer = nn.ReLU,
-            norm_layer = nn.Identity,
-            factorization = None,
-            separable = False,
-            rank = 128,
-            inner_skip = "linear",
-            outer_skip = None,
-            use_mlp = True):
+        self,
+        forward_transform,
+        inverse_transform,
+        input_dim,
+        output_dim,
+        operator_type="driscoll-healy",
+        mlp_ratio=2.0,
+        drop_rate=0.0,
+        drop_path=0.0,
+        act_layer=nn.ReLU,
+        norm_layer=nn.Identity,
+        factorization=None,
+        separable=False,
+        rank=128,
+        inner_skip="linear",
+        outer_skip=None,
+        use_mlp=True,
+    ):
         super(SphericalFourierNeuralOperatorBlock, self).__init__()

         if act_layer == nn.Identity:
@@ -119,21 +119,23 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
             gain_factor /= 2.0

         # convolution layer
-        self.filter = SpectralFilterLayer(forward_transform,
-                                          inverse_transform,
-                                          input_dim,
-                                          output_dim,
-                                          gain = gain_factor,
-                                          operator_type = operator_type,
-                                          hidden_size_factor = mlp_ratio,
-                                          factorization = factorization,
-                                          separable = separable,
-                                          rank = rank,
-                                          bias = True)
+        self.filter = SpectralFilterLayer(
+            forward_transform,
+            inverse_transform,
+            input_dim,
+            output_dim,
+            gain=gain_factor,
+            operator_type=operator_type,
+            hidden_size_factor=mlp_ratio,
+            factorization=factorization,
+            separable=separable,
+            rank=rank,
+            bias=True,
+        )

         if inner_skip == "linear":
             self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
-            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor/input_dim))
+            nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
         elif inner_skip == "identity":
             assert input_dim == output_dim
             self.inner_skip = nn.Identity()
@@ -148,25 +150,21 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         self.norm0 = norm_layer()

         # dropout
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

         gain_factor = 1.0
         if outer_skip == "linear" or inner_skip == "identity":
-            gain_factor /= 2.
+            gain_factor /= 2.0

         if use_mlp == True:
             mlp_hidden_dim = int(output_dim * mlp_ratio)
-            self.mlp = MLP(in_features = output_dim,
-                           out_features = input_dim,
-                           hidden_features = mlp_hidden_dim,
-                           act_layer = act_layer,
-                           drop_rate = drop_rate,
-                           checkpointing = False,
-                           gain = gain_factor)
+            self.mlp = MLP(
+                in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
+            )

         if outer_skip == "linear":
             self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
-            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor/input_dim))
+            torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
         elif outer_skip == "identity":
             assert input_dim == output_dim
             self.outer_skip = nn.Identity()
@@ -178,17 +176,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         # second normalisation layer
         self.norm1 = norm_layer()

-    # def init_weights(self, scale):
-    #     if hasattr(self, "inner_skip") and isinstance(self.inner_skip, nn.Conv2d):
-    #         gain_factor = 1.
-    #         scale = (gain_factor / embed_dim)**0.5
-    #         nn.init.normal_(self.inner_skip.weight, mean=0., std=scale)
-    #         self.filter.filter.init_weights(scale)
-    #     else:
-    #         gain_factor = 2.
-    #         scale = (gain_factor / embed_dim)**0.5
-    #         self.filter.filter.init_weights(scale)
-
     def forward(self, x):
         x, residual = self.filter(x)
@@ -213,6 +200,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
         return x

+
 class SphericalFourierNeuralOperatorNet(nn.Module):
     """
     SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
@@ -280,30 +268,31 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     """

     def __init__(
-            self,
-            spectral_transform = "sht",
-            operator_type = "driscoll-healy",
-            img_size = (128, 256),
-            grid = "equiangular",
-            scale_factor = 3,
-            in_chans = 3,
-            out_chans = 3,
-            embed_dim = 256,
-            num_layers = 4,
-            activation_function = "relu",
-            encoder_layers = 1,
-            use_mlp = True,
-            mlp_ratio = 2.,
-            drop_rate = 0.,
-            drop_path_rate = 0.,
-            normalization_layer = "none",
-            hard_thresholding_fraction = 1.0,
-            use_complex_kernels = True,
-            big_skip = False,
-            factorization = None,
-            separable = False,
-            rank = 128,
-            pos_embed = False):
+        self,
+        spectral_transform="sht",
+        operator_type="driscoll-healy",
+        img_size=(128, 256),
+        grid="equiangular",
+        scale_factor=3,
+        in_chans=3,
+        out_chans=3,
+        embed_dim=256,
+        num_layers=4,
+        activation_function="relu",
+        encoder_layers=1,
+        use_mlp=True,
+        mlp_ratio=2.0,
+        drop_rate=0.0,
+        drop_path_rate=0.0,
+        normalization_layer="none",
+        hard_thresholding_fraction=1.0,
+        use_complex_kernels=True,
+        big_skip=False,
+        factorization=None,
+        separable=False,
+        rank=128,
+        pos_embed=False,
+    ):
         super(SphericalFourierNeuralOperatorNet, self).__init__()
@@ -322,7 +311,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         self.encoder_layers = encoder_layers
         self.big_skip = big_skip
         self.factorization = factorization
-        self.separable = separable,
+        self.separable = (separable,)
         self.rank = rank

         # activation function
@@ -341,7 +330,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         self.w = self.img_size[1] // scale_factor

         # dropout
-        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
+        self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]

         # pick norm layer
@@ -357,7 +346,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         else:
             raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")

-        if pos_embed == "latlon" or pos_embed==True:
+        if pos_embed == "latlon" or pos_embed == True:
             self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
             nn.init.constant_(self.pos_embed, 0.0)
         elif pos_embed == "lat":
@@ -369,35 +358,24 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         else:
             self.pos_embed = None

-        # # encoder
-        # encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        # encoder = MLP(in_features = self.in_chans,
-        #               out_features = self.embed_dim,
-        #               hidden_features = encoder_hidden_dim,
-        #               act_layer = self.activation_function,
-        #               drop_rate = drop_rate,
-        #               checkpointing = False)
-        # self.encoder = encoder

         # construct an encoder with num_encoder_layers
         num_encoder_layers = 1
         encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
         current_dim = self.in_chans
         encoder_layers = []
-        for l in range(num_encoder_layers-1):
+        for l in range(num_encoder_layers - 1):
             fc = nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True)
             # initialize the weights correctly
-            scale = math.sqrt(2. / current_dim)
-            nn.init.normal_(fc.weight, mean=0., std=scale)
+            scale = math.sqrt(2.0 / current_dim)
+            nn.init.normal_(fc.weight, mean=0.0, std=scale)
             if fc.bias is not None:
                 nn.init.constant_(fc.bias, 0.0)
             encoder_layers.append(fc)
             encoder_layers.append(self.activation_function())
             current_dim = encoder_hidden_dim
         fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)
-        scale = math.sqrt(1. / current_dim)
-        nn.init.normal_(fc.weight, mean=0., std=scale)
+        scale = math.sqrt(1.0 / current_dim)
+        nn.init.normal_(fc.weight, mean=0.0, std=scale)
         if fc.bias is not None:
             nn.init.constant_(fc.bias, 0.0)
         encoder_layers.append(fc)
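With the default num_encoder_layers = 1 the loop above never executes, so the encoder collapses to a single bias-free pointwise convolution. A minimal sketch of that special case, using the default in_chans=3 and embed_dim=256 from this diff:

```python
import math

import torch.nn as nn

# num_encoder_layers = 1: only the final projection from in_chans to embed_dim remains
encoder = nn.Conv2d(3, 256, 1, bias=False)
nn.init.normal_(encoder.weight, mean=0.0, std=math.sqrt(1.0 / 3))
```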
@@ -407,13 +385,13 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         if self.spectral_transform == "sht":
             modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int(self.w//2 * self.hard_thresholding_fraction)
+            modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
             modes_lat = modes_lon = min(modes_lat, modes_lon)

             self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
-            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
-            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
-            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
+            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
+            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
+            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()

         elif self.spectral_transform == "fft":
@@ -421,18 +399,18 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

             self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
-            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
-            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
-            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
+            self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
+            self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
+            self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
         else:
-            raise(ValueError("Unknown spectral transform"))
+            raise (ValueError("Unknown spectral transform"))

         self.blocks = nn.ModuleList([])
         for i in range(self.num_layers):
             first_layer = i == 0
-            last_layer = i == self.num_layers-1
+            last_layer = i == self.num_layers - 1
             forward_transform = self.trans_down if first_layer else self.trans
             inverse_transform = self.itrans_up if last_layer else self.itrans
@@ -447,52 +425,45 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             else:
                 norm_layer = norm_layer1

-            block = SphericalFourierNeuralOperatorBlock(forward_transform,
-                                                        inverse_transform,
-                                                        self.embed_dim,
-                                                        self.embed_dim,
-                                                        operator_type = self.operator_type,
-                                                        mlp_ratio = mlp_ratio,
-                                                        drop_rate = drop_rate,
-                                                        drop_path = dpr[i],
-                                                        act_layer = self.activation_function,
-                                                        norm_layer = norm_layer,
-                                                        inner_skip = inner_skip,
-                                                        outer_skip = outer_skip,
-                                                        use_mlp = use_mlp,
-                                                        factorization = self.factorization,
-                                                        separable = self.separable,
-                                                        rank = self.rank)
+            block = SphericalFourierNeuralOperatorBlock(
+                forward_transform,
+                inverse_transform,
+                self.embed_dim,
+                self.embed_dim,
+                operator_type=self.operator_type,
+                mlp_ratio=mlp_ratio,
+                drop_rate=drop_rate,
+                drop_path=dpr[i],
+                act_layer=self.activation_function,
+                norm_layer=norm_layer,
+                inner_skip=inner_skip,
+                outer_skip=outer_skip,
+                use_mlp=use_mlp,
+                factorization=self.factorization,
+                separable=self.separable,
+                rank=self.rank,
+            )

             self.blocks.append(block)

-        # # decoder
-        # decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        # self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans,
-        #                    out_features = self.out_chans,
-        #                    hidden_features = decoder_hidden_dim,
-        #                    act_layer = self.activation_function,
-        #                    drop_rate = drop_rate,
-        #                    checkpointing = False)

         # construct an decoder with num_decoder_layers
         num_decoder_layers = 1
         decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
-        current_dim = self.embed_dim + self.big_skip*self.in_chans
+        current_dim = self.embed_dim + self.big_skip * self.in_chans
         decoder_layers = []
-        for l in range(num_decoder_layers-1):
+        for l in range(num_decoder_layers - 1):
             fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)
             # initialize the weights correctly
-            scale = math.sqrt(2. / current_dim)
-            nn.init.normal_(fc.weight, mean=0., std=scale)
+            scale = math.sqrt(2.0 / current_dim)
+            nn.init.normal_(fc.weight, mean=0.0, std=scale)
             if fc.bias is not None:
                 nn.init.constant_(fc.bias, 0.0)
             decoder_layers.append(fc)
             decoder_layers.append(self.activation_function())
             current_dim = decoder_hidden_dim
         fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False)
-        scale = math.sqrt(1. / current_dim)
-        nn.init.normal_(fc.weight, mean=0., std=scale)
+        scale = math.sqrt(1.0 / current_dim)
+        nn.init.normal_(fc.weight, mean=0.0, std=scale)
         if fc.bias is not None:
             nn.init.constant_(fc.bias, 0.0)
         decoder_layers.append(fc)
@@ -529,5 +500,3 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         x = self.decoder(x)

         return x
-
-
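A quick smoke test of the reformatted network using only the defaults visible in this diff; since the encoder and decoder are pointwise convolutions, the output resolution matches img_size:

```python
import torch

model = SphericalFourierNeuralOperatorNet()  # defaults: img_size=(128, 256), in_chans=3, out_chans=3
x = torch.randn(1, 3, 128, 256)  # (batch, channels, nlat, nlon)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 3, 128, 256])
```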
@@ -239,7 +239,7 @@ class ShallowWaterSolver(nn.Module):
         ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64

         # mach number relative to wave speed
-        llimit = mlimit = 80
+        llimit = mlimit = 120

         # hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
         # ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
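The constant change raises the spectral truncation of the solver's random initial condition from degree/order 80 to 120, retaining more small-scale structure. A minimal sketch of what such a truncation does to a coefficient tensor (the surrounding solver code is outside this hunk, and the helper below is illustrative, not the solver's actual code):

```python
import torch


def truncate_modes(coeffs: torch.Tensor, llimit: int = 120, mlimit: int = 120) -> torch.Tensor:
    """Zero all spherical-harmonic modes with degree l >= llimit or order m >= mlimit."""
    out = coeffs.clone()  # coeffs indexed as (..., l, m)
    out[..., llimit:, :] = 0.0  # drop high degrees
    out[..., :, mlimit:] = 0.0  # drop high orders
    return out
```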