Commit 9dc07e9b authored by Boris Bonev, committed by Boris Bonev

adding local SNO architecture and updating training script for SWE prediction with power spectrum computation
parent b91f517c
@@ -45,13 +45,16 @@ import pandas as pd
import matplotlib.pyplot as plt
from torch_harmonics.examples.sfno import PdeDataset
from torch_harmonics import RealSHT
# wandb logging
import wandb
wandb.login()
def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
loss = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
if relative:
loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)
@@ -61,18 +64,19 @@ def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
return loss
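# Note: l2loss_sphere is a quadrature-based squared L2 norm on the sphere,
#   loss = sum_c \int_{S^2} (prd_c - tar_c)^2 dOmega,
# summed over channels via integrate_grid; with relative=True it is normalized
# by the same integral of the target before the batch mean is taken.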
def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=True):
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1,-2))
loss = torch.sum(norm2, dim=(-1, -2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0]**2 + tar_coeffs[..., 1]**2
tar_coeffs = tar_coeffs[..., 0] ** 2 + tar_coeffs[..., 1] ** 2
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1,-2))
tar_norm2 = torch.sum(tar_norm2, dim=(-1, -2))
loss = loss / tar_norm2
if not squared:
@@ -81,25 +85,26 @@ def spectral_l2loss_sphere(solver, prd, tar, relative=False, squared=True):
return loss
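# The m=0 / 2*(m>0) folding above is Parseval's identity for the real SHT,
# which stores only orders m >= 0:
#   ||u||^2_{L^2(S^2)} = sum_l ( |a_{l,0}|^2 + 2 * sum_{m>0} |a_{l,m}|^2 ).
# A self-contained sketch of the reduction (random coefficients stand in for
# the output of solver.sht):
#
#     coeffs = torch.view_as_real(torch.randn(2, 3, 8, 8, dtype=torch.complex64))
#     power = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2   # (batch, chans, lmax, mmax)
#     norm2 = power[..., :, 0] + 2 * torch.sum(power[..., :, 1:], dim=-1)
#     sqnorm = torch.sum(norm2, dim=(-1, -2))             # (batch,)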
def spectral_loss_sphere(solver, prd, tar, relative=False, squared=True):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls*(ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
coeffs = spectral_weights * coeffs
norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
loss = torch.sum(norm2, dim=(-1,-2))
loss = torch.sum(norm2, dim=(-1, -2))
if relative:
tar_coeffs = torch.view_as_real(solver.sht(tar))
tar_coeffs = tar_coeffs[..., 0]**2 + tar_coeffs[..., 1]**2
tar_coeffs = tar_coeffs[..., 0] ** 2 + tar_coeffs[..., 1] ** 2
tar_coeffs = spectral_weights * tar_coeffs
tar_norm2 = tar_coeffs[..., :, 0] + 2 * torch.sum(tar_coeffs[..., :, 1:], dim=-1)
tar_norm2 = torch.sum(tar_norm2, dim=(-1,-2))
tar_norm2 = torch.sum(tar_norm2, dim=(-1, -2))
loss = loss / tar_norm2
if not squared:
@@ -108,20 +113,21 @@ def spectral_loss_sphere(solver, prd, tar, relative=False, squared=True):
return loss
def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
# gradient weighting factors
lmax = solver.sht.lmax
ls = torch.arange(lmax).float()
spectral_weights = (ls*(ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
spectral_weights = (ls * (ls + 1)).reshape(1, 1, -1, 1).to(prd.device)
# compute coefficients
coeffs = torch.view_as_real(solver.sht(prd - tar))
coeffs = coeffs[..., 0]**2 + coeffs[..., 1]**2
coeffs = coeffs[..., 0] ** 2 + coeffs[..., 1] ** 2
h1_coeffs = spectral_weights * coeffs
h1_norm2 = h1_coeffs[..., :, 0] + 2 * torch.sum(h1_coeffs[..., :, 1:], dim=-1)
l2_norm2 = coeffs[..., :, 0] + 2 * torch.sum(coeffs[..., :, 1:], dim=-1)
h1_loss = torch.sum(h1_norm2, dim=(-1,-2))
l2_loss = torch.sum(l2_norm2, dim=(-1,-2))
h1_loss = torch.sum(h1_norm2, dim=(-1, -2))
l2_loss = torch.sum(l2_norm2, dim=(-1, -2))
# strictly speaking this is not exactly h1 loss
if not squared:
@@ -134,30 +140,24 @@ def h1loss_sphere(solver, prd, tar, relative=False, squared=True):
loss = loss.mean()
return loss
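# The l(l+1) factors above are the eigenvalues of the negative Laplace-Beltrami
# operator, -Delta Y_lm = l(l+1) Y_lm, so the weighted sum is the squared H1
# seminorm ||grad u||^2 while the unweighted sum is the plain L2 norm; the two
# are then combined into the (approximate) H1 loss. The weights in isolation:
#
#     ls = torch.arange(8).float()
#     ls * (ls + 1)   # tensor([ 0.,  2.,  6., 12., 20., 30., 42., 56.])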
def fluct_l2loss_sphere(solver, prd, tar, inp, relative=False, polar_opt=0):
# compute the weighting factor first
fluct = solver.integrate_grid((tar - inp)**2, dimensionless=True, polar_opt=polar_opt)
fluct = solver.integrate_grid((tar - inp) ** 2, dimensionless=True, polar_opt=polar_opt)
weight = fluct / torch.sum(fluct, dim=-1, keepdim=True)
# weight = weight.reshape(*weight.shape, 1, 1)
loss = weight * solver.integrate_grid((prd - tar)**2, dimensionless=True, polar_opt=polar_opt)
loss = weight * solver.integrate_grid((prd - tar) ** 2, dimensionless=True, polar_opt=polar_opt)
if relative:
loss = loss / (weight * solver.integrate_grid(tar**2, dimensionless=True, polar_opt=polar_opt))
loss = torch.mean(loss)
return loss
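# fluct_l2loss_sphere reweights each channel by how much the target actually
# departs from the input, so nearly static channels do not dominate the loss.
# Sketch of the weighting with hypothetical per-channel fluctuation integrals:
#
#     fluct = torch.tensor([[4.0, 1.0, 0.5]])           # \int (tar - inp)^2 per channel
#     weight = fluct / fluct.sum(dim=-1, keepdim=True)  # tensor([[0.7273, 0.1818, 0.0909]])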
# rolls out the FNO and compares to the classical solver
def autoregressive_inference(model,
dataset,
path_root,
nsteps,
autoreg_steps=10,
nskip=1,
plot_channel=0,
nics=20):
def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10, nskip=1, plot_channel=0, nics=50):
model.eval()
@@ -165,6 +165,10 @@ def autoregressive_inference(model,
fno_times = np.zeros(nics)
nwp_times = np.zeros(nics)
# accumulation buffers for the power spectrum
prd_mean_coeffs = []
ref_mean_coeffs = []
for iic in range(nics):
ic = dataset.solver.random_initial_condition(mach=0.2)
inp_mean = dataset.inp_mean
@@ -176,45 +180,69 @@ def autoregressive_inference(model,
# ML model
start_time = time.time()
for i in range(1, autoreg_steps+1):
for i in range(1, autoreg_steps + 1):
# evaluate the ML model
prd = model(prd)
if iic == nics-1 and nskip > 0 and i % nskip == 0:
if iic == nics - 1 and nskip > 0 and i % nskip == 0:
# do plotting
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root+'_pred_'+str(i//nskip)+'.png')
plt.clf()
plt.savefig(path_root + "_pred_" + str(i // nskip) + ".png")
plt.close()
fno_times[iic] = time.time() - start_time
# classical model
start_time = time.time()
for i in range(1, autoreg_steps+1):
for i in range(1, autoreg_steps + 1):
# advance classical model
uspec = dataset.solver.timestep(uspec, nsteps)
if iic == nics-1 and i % nskip == 0 and nskip > 0:
ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
if iic == nics - 1 and i % nskip == 0 and nskip > 0:
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root+'_truth_'+str(i//nskip)+'.png')
plt.clf()
plt.savefig(path_root + "_truth_" + str(i // nskip) + ".png")
plt.close()
nwp_times[iic] = time.time() - start_time
# compute power spectrum and add it to the buffers
prd_coeffs = dataset.solver.sht(prd[0, plot_channel])
ref_coeffs = dataset.solver.sht(ref[plot_channel])
prd_mean_coeffs.append(prd_coeffs)
ref_mean_coeffs.append(ref_coeffs)
# ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref = dataset.solver.spec2grid(uspec)
prd = prd * torch.sqrt(inp_var) + inp_mean
losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()
# compute the averaged power spectra of prediction and reference
prd_mean_coeffs = torch.stack(prd_mean_coeffs).abs().pow(2).mean(dim=0)
ref_mean_coeffs = torch.stack(ref_mean_coeffs).abs().pow(2).mean(dim=0)
prd_mean_coeffs[..., 1:] *= 2.0
ref_mean_coeffs[..., 1:] *= 2.0
prd_mean_ps = prd_mean_coeffs.sum(dim=-1).detach().cpu()
ref_mean_ps = ref_mean_coeffs.sum(dim=-1).detach().cpu()
# plot the averaged power spectrum
fig = plt.figure(figsize=(7.5, 6))
plt.loglog(prd_mean_ps, label="prediction")
plt.loglog(ref_mean_ps, label="reference")
plt.xlabel("$l$")
plt.ylabel("powerspectrum")
plt.legend()
plt.savefig(path_root + "_powerspectrum.png")
plt.close()
return losses, fno_times, nwp_times
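# The averaged spectrum accumulated above is E_ic[|a_{l,m}|^2] folded over m
# (the factor 2 again accounts for the +/-m symmetry of the real SHT), then
# summed over m to give power per degree l. Self-contained sketch with random
# coefficients standing in for solver.sht output:
#
#     coeff_list = [torch.randn(8, 8, dtype=torch.complex64) for _ in range(5)]
#     mean_power = torch.stack(coeff_list).abs().pow(2).mean(dim=0)  # (lmax, mmax)
#     mean_power[..., 1:] *= 2.0
#     spectrum = mean_power.sum(dim=-1)                              # power per degree l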
# convenience function for logging weights and gradients
def log_weights_and_grads(model, iters=1):
"""
@@ -225,25 +253,15 @@ def log_weights_and_grads(model, iters=1):
weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
print(weights_and_grads_fname)
weights_dict = {k:v for k,v in model.named_parameters()}
grad_dict = {k:v.grad for k,v in model.named_parameters()}
weights_dict = {k: v for k, v in model.named_parameters()}
grad_dict = {k: v.grad for k, v in model.named_parameters()}
store_dict = {'iteration': iters, 'grads': grad_dict, 'weights': weights_dict}
store_dict = {"iteration": iters, "grads": grad_dict, "weights": weights_dict}
torch.save(store_dict, weights_and_grads_fname)
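# The resulting archive can be inspected offline, e.g. (file name hypothetical):
#
#     store = torch.load("weights_and_grads_step001.tar")
#     print(store["iteration"], list(store["weights"].keys())[:5])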
# training function
def train_model(model,
dataloader,
optimizer,
gscaler,
scheduler=None,
nepochs=20,
nfuture=0,
num_examples=256,
num_valid=8,
loss_fn='l2',
enable_amp=False,
log_grads=0):
def train_model(model, dataloader, optimizer, gscaler, scheduler=None, nepochs=20, nfuture=0, num_examples=256, num_valid=8, loss_fn="l2", enable_amp=False, log_grads=0):
train_start = time.time()
@@ -255,7 +273,7 @@ def train_model(model,
# time each epoch
epoch_start = time.time()
dataloader.dataset.set_initial_condition('random')
dataloader.dataset.set_initial_condition("random")
dataloader.dataset.set_num_examples(num_examples)
# get the solver for its convenience functions
@@ -273,18 +291,18 @@ def train_model(model,
for _ in range(nfuture):
prd = model(prd)
if loss_fn == 'l2':
if loss_fn == "l2":
loss = l2loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == 'spectral l2':
elif loss_fn == "spectral l2":
loss = spectral_l2loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == 'h1':
elif loss_fn == "h1":
loss = h1loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == 'spectral':
elif loss_fn == "spectral":
loss = spectral_loss_sphere(solver, prd, tar, relative=False)
elif loss_fn == 'fluct':
elif loss_fn == "fluct":
loss = fluct_l2loss_sphere(solver, prd, tar, inp, relative=True)
else:
raise NotImplementedError(f'Unknown loss function {loss_fn}')
raise NotImplementedError(f"Unknown loss function {loss_fn}")
acc_loss += loss.item() * inp.size(0)
@@ -301,7 +319,7 @@ def train_model(model,
acc_loss = acc_loss / len(dataloader.dataset)
dataloader.dataset.set_initial_condition('random')
dataloader.dataset.set_initial_condition("random")
dataloader.dataset.set_num_examples(num_valid)
# perform validation
@@ -323,23 +341,23 @@ def train_model(model,
epoch_time = time.time() - epoch_start
print(f'--------------------------------------------------------------------------------')
print(f'Epoch {epoch} summary:')
print(f'time taken: {epoch_time}')
print(f'accumulated training loss: {acc_loss}')
print(f'relative validation loss: {valid_loss}')
print(f"--------------------------------------------------------------------------------")
print(f"Epoch {epoch} summary:")
print(f"time taken: {epoch_time}")
print(f"accumulated training loss: {acc_loss}")
print(f"relative validation loss: {valid_loss}")
if wandb.run is not None:
current_lr = optimizer.param_groups[0]['lr']
current_lr = optimizer.param_groups[0]["lr"]
wandb.log({"loss": acc_loss, "validation loss": valid_loss, "learning rate": current_lr})
train_time = time.time() - train_start
print(f'--------------------------------------------------------------------------------')
print(f'done. Training took {train_time}.')
print(f"--------------------------------------------------------------------------------")
print(f"done. Training took {train_time}.")
return valid_loss
def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
# set seed
@@ -347,14 +365,14 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
torch.cuda.manual_seed(333)
# set device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.cuda.set_device(device.index)
# 1 hour prediction steps
dt = 1*3600
dt = 1 * 3600
dt_solver = 150
nsteps = dt//dt_solver
nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
@@ -371,27 +389,39 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
metrics = {}
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
models["sfno_sc3_layer4_e16_linskip_nomlp"] = partial(SFNO, spectral_transform='sht', img_size=(nlat, nlon), grid="equiangular",
num_layers=4, scale_factor=3, embed_dim=16, operator_type='driscoll-healy',
big_skip=False, pos_embed=False, use_mlp=False, normalization_layer="none")
# models["sfno_sc3_layer4_e256_noskip_mlp"] = partial(SFNO, spectral_transform='sht', img_size=(nlat, nlon), grid="equiangular",
# num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy',
# big_skip=False, pos_embed=False, use_mlp=True, normalization_layer="none")
# from torch_harmonics.examples.sfno.models.unet import UNet
# models['unet_baseline'] = partial(UNet)
# # U-Net if installed
# from models.unet import UNet
# models['unet_baseline'] = partial(UNet)
# SFNO models
# models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', img_size=(nlat, nlon), grid="equiangular",
# num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')
# # FNO models
# models['fno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='fft', img_size=(nlat, nlon), grid="equiangular",
# num_layers=4, scale_factor=3, embed_dim=256, operator_type='diagonal')
from torch_harmonics.examples.sfno import LocalSphericalNeuralOperatorNet as LSNO
# models["sfno_sc2_layers6_e32"] = partial(
# SFNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid="equiangular",
# num_layers=6,
# scale_factor=1,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=True,
# normalization_layer="none",
# )
models["lsno_sc2_layers6_e32"] = partial(
LSNO,
spectral_transform="sht",
img_size=(nlat, nlon),
grid="equiangular",
num_layers=6,
scale_factor=1,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=True,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
)
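# Each registry entry is a functools.partial factory: the model is presumably
# instantiated lazily (model_handle() in the loop below) and the bound keywords
# double as the wandb config via model_handle.keywords. The pattern in isolation:
#
#     registry = {"tiny": partial(torch.nn.Linear, 4, 8)}
#     model = registry["tiny"]()   # built only when needed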
# iterate over models and train each model
root_path = os.path.dirname(__file__)
@@ -404,61 +434,64 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
metrics[model_name] = {}
num_params = count_parameters(model)
print(f'number of trainable params: {num_params}')
metrics[model_name]['num_params'] = num_params
print(f"number of trainable params: {num_params}")
metrics[model_name]["num_params"] = num_params
if load_checkpoint:
model.load_state_dict(torch.load(os.path.join(root_path, 'checkpoints/'+model_name)))
model.load_state_dict(torch.load(os.path.join(root_path, "checkpoints/" + model_name), weights_only=True))
# run the training
if train:
run = wandb.init(project="sfno ablations spherical swe", group=model_name, name=model_name + '_' + str(time.time()), config=model_handle.keywords)
run = wandb.init(project="sfno ablations spherical swe", group=model_name, name=model_name + "_" + str(time.time()), config=model_handle.keywords)
# optimizer:
optimizer = torch.optim.Adam(model.parameters(), lr=3E-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
gscaler = torch.GradScaler("cuda", enabled=enable_amp)
start_time = time.time()
print(f'Training {model_name}, single step')
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=10, loss_fn='l2', enable_amp=enable_amp, log_grads=log_grads)
print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
# # multistep training
# print(f'Training {model_name}, two step')
# optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# gscaler = amp.GradScaler(enabled=enable_amp)
# gscaler = torch.GradScaler(enabled=enable_amp)
# dataloader.dataset.nsteps = 2 * dt//dt_solver
# train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, nfuture=1, enable_amp=enable_amp)
# train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=5, nfuture=1, enable_amp=enable_amp)
# dataloader.dataset.nsteps = 1 * dt//dt_solver
training_time = time.time() - start_time
run.finish()
torch.save(model.state_dict(), os.path.join(root_path, 'checkpoints/'+model_name))
torch.save(model.state_dict(), os.path.join(root_path, "checkpoints/" + model_name))
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
with torch.inference_mode():
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path,'figures/'+model_name), nsteps=nsteps, autoreg_steps=10)
metrics[model_name]['loss_mean'] = np.mean(losses)
metrics[model_name]['loss_std'] = np.std(losses)
metrics[model_name]['fno_time_mean'] = np.mean(fno_times)
metrics[model_name]['fno_time_std'] = np.std(fno_times)
metrics[model_name]['nwp_time_mean'] = np.mean(nwp_times)
metrics[model_name]['nwp_time_std'] = np.std(nwp_times)
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path, "figures/" + model_name), nsteps=nsteps, autoreg_steps=30)
metrics[model_name]["loss_mean"] = np.mean(losses)
metrics[model_name]["loss_std"] = np.std(losses)
metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
metrics[model_name]["fno_time_std"] = np.std(fno_times)
metrics[model_name]["nwp_time_mean"] = np.mean(nwp_times)
metrics[model_name]["nwp_time_std"] = np.std(nwp_times)
if train:
metrics[model_name]['training_time'] = training_time
metrics[model_name]["training_time"] = training_time
df = pd.DataFrame(metrics)
df.to_pickle(os.path.join(root_path, 'output_data/metrics.pkl'))
df.to_pickle(os.path.join(root_path, "output_data/metrics.pkl"))
if __name__ == "__main__":
import torch.multiprocessing as mp
mp.set_start_method('forkserver', force=True)
mp.set_start_method("forkserver", force=True)
# main(train=False, load_checkpoint=True, enable_amp=False, log_grads=0)
main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0)
@@ -31,3 +31,4 @@
from .utils.pde_dataset import PdeDataset
from .models.sfno import SphericalFourierNeuralOperatorNet
from .models.local_sfno import LocalSphericalNeuralOperatorNet
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import math

import torch
import torch.nn as nn
import torch.amp as amp
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .layers import *
from functools import partial
class DiscreteContinuousEncoder(nn.Module):
def __init__(
self,
inp_shape=(721, 1440),
out_shape=(480, 960),
grid_in="equiangular",
grid_out="equiangular",
inp_chans=2,
out_chans=2,
kernel_shape=[3, 4],
groups=1,
bias=False,
):
super().__init__()
# set up local convolution
self.conv = DiscreteContinuousConvS2(
inp_chans,
out_chans,
in_shape=inp_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
grid_in=grid_in,
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
)
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self.conv(x)
x = x.to(dtype=dtype)
return x
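# Disabling autocast forces the DISCO convolution to run in float32 even when
# the surrounding network uses AMP; casting back afterwards keeps the module
# transparent to mixed precision. The same guard in isolation
# (precision_sensitive_op is a hypothetical stand-in):
#
#     with amp.autocast(device_type="cuda", enabled=False):
#         y = precision_sensitive_op(x.float())
#     y = y.to(dtype=x.dtype)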
class DiscreteContinuousDecoder(nn.Module):
def __init__(
self,
inp_shape=(480, 960),
out_shape=(721, 1440),
grid_in="equiangular",
grid_out="equiangular",
inp_chans=2,
out_chans=2,
kernel_shape=[3, 4],
groups=1,
bias=False,
):
super().__init__()
# set up the SHT-based upsampling transforms
self.sht = RealSHT(*inp_shape, grid=grid_in).float()
self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
# set up DISCO convolution
self.convt = DiscreteContinuousConvTransposeS2(
inp_chans,
out_chans,
in_shape=out_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
grid_in=grid_out,
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
)
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
def _upscale_sht(self, x: torch.Tensor):
return self.isht(self.sht(x))
def forward(self, x):
dtype = x.dtype
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self._upscale_sht(x)
x = self.convt(x)
x = x.to(dtype=dtype)
return x
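# The decoder upsamples by a spherical-harmonic round trip: analysis on the
# coarse grid, synthesis on the fine grid with the same (lmax, mmax), i.e.
# band-limited spectral interpolation, after which the transpose DISCO
# convolution refines the result locally. The resampling step in isolation
# (grid sizes illustrative):
#
#     sht = RealSHT(60, 120, grid="equiangular").float()
#     isht = InverseRealSHT(120, 240, lmax=sht.lmax, mmax=sht.mmax, grid="equiangular").float()
#     x_up = isht(sht(torch.randn(1, 60, 120)))   # -> (1, 120, 240)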
class SpectralFilterLayer(nn.Module):
"""
Fourier layer. Contains the convolution part of the FNO/SFNO
"""
def __init__(
self,
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=2.0,
operator_type="diagonal",
hidden_size_factor=2,
factorization=None,
separable=False,
rank=1e-2,
bias=True,
):
super(SpectralFilterLayer, self).__init__()
if factorization is None:
self.filter = SpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain,
operator_type=operator_type,
bias=bias,
)
elif factorization is not None:
self.filter = FactorizedSpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain,
operator_type=operator_type,
rank=rank,
factorization=factorization,
separable=separable,
bias=bias,
)
else:
raise NotImplementedError
def forward(self, x):
return self.filter(x)
class SphericalNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
"""
def __init__(
self,
forward_transform,
inverse_transform,
input_dim,
output_dim,
conv_type="local",
operator_type="driscoll-healy",
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.ReLU,
norm_layer=nn.Identity,
factorization=None,
separable=False,
rank=128,
inner_skip="none",
outer_skip="linear",
use_mlp=True,
disco_kernel_shape=[2, 4],
):
super().__init__()
if act_layer == nn.Identity:
gain_factor = 1.0
else:
gain_factor = 2.0
if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0
# convolution layer
if conv_type == "local":
self.local_conv = DiscreteContinuousConvS2(
input_dim,
output_dim,
in_shape=(forward_transform.nlat, forward_transform.nlon),
out_shape=(inverse_transform.nlat, inverse_transform.nlon),
kernel_shape=disco_kernel_shape,
grid_in=forward_transform.grid,
grid_out=inverse_transform.grid,
bias=False,
theta_cutoff=(disco_kernel_shape[0] + 1) * torch.pi / float(forward_transform.nlat - 1) / math.sqrt(2),
)
elif conv_type == "global":
self.global_conv = SpectralFilterLayer(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain_factor,
operator_type=operator_type,
hidden_size_factor=mlp_ratio,
factorization=factorization,
separable=separable,
rank=rank,
bias=False,
)
else:
raise ValueError(f"Unknown convolution type {conv_type}")
if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
elif inner_skip == "identity":
assert input_dim == output_dim
self.inner_skip = nn.Identity()
elif inner_skip == "none":
pass
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
self.act_layer = act_layer()
# first normalisation layer
self.norm0 = norm_layer()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
gain_factor = 1.0
if outer_skip == "linear" or outer_skip == "identity":
gain_factor /= 2.0
if use_mlp:
mlp_hidden_dim = int(output_dim * mlp_ratio)
self.mlp = MLP(
in_features=output_dim,
out_features=input_dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop_rate=drop_rate,
checkpointing=False,
gain=gain_factor,
)
if outer_skip == "linear":
self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
elif outer_skip == "identity":
assert input_dim == output_dim
self.outer_skip = nn.Identity()
elif outer_skip == "none":
pass
else:
raise ValueError(f"Unknown skip connection type {outer_skip}")
# second normalisation layer
self.norm1 = norm_layer()
def forward(self, x):
residual = x
if hasattr(self, "global_conv"):
x, _ = self.global_conv(x)
elif hasattr(self, "local_conv"):
x = self.local_conv(x)
x = self.norm0(x)
if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual)
if hasattr(self, "act_layer"):
x = self.act_layer(x)
if hasattr(self, "mlp"):
x = self.mlp(x)
x = self.norm1(x)
x = self.drop_path(x)
if hasattr(self, "outer_skip"):
x = x + self.outer_skip(residual)
return x
class LocalSphericalNeuralOperatorNet(nn.Module):
"""
Local Spherical Neural Operator module. Alternates global spectral (SFNO-style) blocks with
local discrete-continuous (DISCO) convolutions on the sphere, with a DISCO encoder and an
SHT-upsampling DISCO decoder.
Parameters
----------
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
img_size : tuple, optional
Shape of the input grid (nlat, nlon), by default (128, 256)
kernel_shape : tuple, optional
Kernel shape of the local DISCO convolutions in the operator blocks, by default [3, 4]
scale_factor : int, optional
Scale factor to use, by default 3
in_chans : int, optional
Number of input channels, by default 3
out_chans : int, optional
Number of output channels, by default 3
embed_dim : int, optional
Dimension of the embeddings, by default 256
num_layers : int, optional
Number of layers in the network, by default 4
activation_function : str, optional
Activation function to use, by default "gelu"
encoder_kernel_shape : tuple, optional
Kernel shape of the encoder and decoder convolutions, by default [3, 4]
use_mlp : int, optional
Whether to use MLPs in the SFNO blocks, by default True
mlp_ratio : int, optional
Ratio of MLP to use, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path_rate : float, optional
Dropout path rate, by default 0.0
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm"
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
big_skip : bool, optional
Whether to add a single large skip connection, by default True
factorization : Any, optional
Type of factorization to use, by default None
separable : bool, optional
Whether to use separable convolutions, by default False
rank : (int, Tuple[int]), optional
If a factorization is used, which rank to use. Argument is passed to tensorly
pos_embed : bool, optional
Whether to use positional embedding, by default True
Example:
--------
>>> model = LocalSphericalNeuralOperatorNet(
... img_size=(128, 256),
... scale_factor=4,
... in_chans=2,
... out_chans=2,
... embed_dim=16,
... num_layers=4,
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
"""
def __init__(
self,
spectral_transform="sht",
operator_type="driscoll-healy",
img_size=(128, 256),
grid="equiangular",
scale_factor=4,
in_chans=3,
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="relu",
kernel_shape=[3, 4],
encoder_kernel_shape=[3, 4],
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
hard_thresholding_fraction=1.0,
use_complex_kernels=True,
big_skip=False,
factorization=None,
separable=False,
rank=128,
pos_embed=False,
):
super().__init__()
self.spectral_transform = spectral_transform
self.operator_type = operator_type
self.img_size = img_size
self.grid = grid
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
self.embed_dim = embed_dim
self.num_layers = num_layers
self.encoder_kernel_shape = encoder_kernel_shape
self.hard_thresholding_fraction = hard_thresholding_fraction
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.big_skip = big_skip
self.factorization = factorization
self.separable = separable
self.rank = rank
# activation function
if activation_function == "relu":
self.activation_function = nn.ReLU
elif activation_function == "gelu":
self.activation_function = nn.GELU
# for debugging purposes
elif activation_function == "identity":
self.activation_function = nn.Identity
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size
self.h = self.img_size[0] // scale_factor
self.w = self.img_size[1] // scale_factor
# dropout
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
# pick norm layer
if self.normalization_layer == "layer_norm":
norm_layer0 = partial(nn.LayerNorm, normalized_shape=(self.img_size[0], self.img_size[1]), eps=1e-6)
norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h, self.w), eps=1e-6)
elif self.normalization_layer == "instance_norm":
norm_layer0 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
norm_layer1 = partial(nn.InstanceNorm2d, num_features=self.embed_dim, eps=1e-6, affine=True, track_running_stats=False)
elif self.normalization_layer == "none":
norm_layer0 = nn.Identity
norm_layer1 = norm_layer0
else:
raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
if pos_embed == "latlon" or pos_embed == True:
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.h, self.w))
nn.init.constant_(self.pos_embed, 0.0)
elif pos_embed == "lat":
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.h, 1))
nn.init.constant_(self.pos_embed, 0.0)
elif pos_embed == "const":
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
nn.init.constant_(self.pos_embed, 0.0)
else:
self.pos_embed = None
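# Positional embedding variants (all zero-initialized, learned):
#   "latlon" / True -> (1, embed_dim, h, w): full per-gridpoint embedding
#   "lat"           -> (1, embed_dim, h, 1): latitude-only (zonally symmetric)
#   "const"         -> (1, embed_dim, 1, 1): single global offset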
# encoder
self.encoder = DiscreteContinuousConvS2(
self.in_chans,
self.embed_dim,
self.img_size,
(self.h, self.w),
self.encoder_kernel_shape,
groups=1,
grid_in=grid,
grid_out="legendre-gauss",
bias=False,
theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
)
# # encoder
# self.encoder = DiscreteContinuousEncoder(
# inp_shape=self.img_size,
# out_shape=(self.h, self.w),
# grid_in=grid,
# grid_out="legendre-gauss",
# inp_chans=self.in_chans,
# out_chans=self.embed_dim,
# kernel_shape=self.encoder_kernel_shape,
# groups=1,
# bias=False,
# )
# prepare the spectral transform
if self.spectral_transform == "sht":
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon)
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
else:
raise ValueError("Unknown spectral transform")
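# hard_thresholding_fraction f caps the retained bandwidth at
# lmax = mmax = min(int(f * h), int(f * (w // 2))), a hard spectral cutoff.
# E.g. for a downsampled grid h, w = 128, 256 and f = 0.8:
#
#     min(int(128 * 0.8), int(256 // 2 * 0.8))   # -> 102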
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
first_layer = i == 0
last_layer = i == self.num_layers - 1
inner_skip = "none"
outer_skip = "identity"
if first_layer:
norm_layer = norm_layer1
elif last_layer:
norm_layer = norm_layer0
else:
norm_layer = norm_layer1
block = SphericalNeuralOperatorBlock(
self.trans,
self.itrans,
self.embed_dim,
self.embed_dim,
conv_type="global" if i % 2 == 0 else "local",
operator_type=self.operator_type,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=norm_layer,
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp,
factorization=self.factorization,
separable=self.separable,
rank=self.rank,
disco_kernel_shape=kernel_shape,
)
self.blocks.append(block)
# # decoder
# self.decoder = DiscreteContinuousConvTransposeS2(
# self.embed_dim,
# self.out_chans,
# (self.h, self.w),
# self.img_size,
# self.encoder_kernel_shape,
# groups=1,
# grid_in="legendre-gauss",
# grid_out=grid,
# bias=False,
# theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
# )
# decoder
self.decoder = DiscreteContinuousDecoder(
inp_shape=(self.h, self.w),
out_shape=self.img_size,
grid_in="legendre-gauss",
grid_out=grid,
inp_chans=self.embed_dim,
out_chans=self.out_chans,
kernel_shape=self.encoder_kernel_shape,
groups=1,
bias=False,
)
# # residual prediction
# if self.big_skip:
# self.residual_transform = nn.Conv2d(self.out_chans, self.in_chans, 1, bias=False)
# self.residual_transform.weight.is_shared_mp = ["spatial"]
# self.residual_transform.weight.sharded_dims_mp = [None, None, None, None]
# scale = math.sqrt(0.5 / self.in_chans)
# nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
def forward_features(self, x):
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
return x
def forward(self, x):
if self.big_skip:
residual = x
x = self.encoder(x)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.forward_features(x)
x = self.decoder(x)
if self.big_skip:
# x = x + self.residual_transform(residual)
x = x + residual
return x
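# Quick smoke test of the assembled model (a sketch mirroring the docstring
# example; sizes are arbitrary):
#
#     model = LocalSphericalNeuralOperatorNet(img_size=(64, 128), scale_factor=2,
#                                             in_chans=3, out_chans=3,
#                                             embed_dim=16, num_layers=4)
#     out = model(torch.randn(1, 3, 64, 128))   # torch.Size([1, 3, 64, 128])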
@@ -50,64 +50,64 @@ class SpectralFilterLayer(nn.Module):
inverse_transform,
input_dim,
output_dim,
gain = 2.,
operator_type = "diagonal",
hidden_size_factor = 2,
factorization = None,
separable = False,
rank = 1e-2,
bias = True):
gain=2.0,
operator_type="diagonal",
hidden_size_factor=2,
factorization=None,
separable=False,
rank=1e-2,
bias=True,
):
super(SpectralFilterLayer, self).__init__()
if factorization is None:
self.filter = SpectralConvS2(forward_transform,
inverse_transform,
input_dim,
output_dim,
gain = gain,
operator_type = operator_type,
bias = bias)
self.filter = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain, operator_type=operator_type, bias=bias)
elif factorization is not None:
self.filter = FactorizedSpectralConvS2(forward_transform,
self.filter = FactorizedSpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain = gain,
operator_type = operator_type,
rank = rank,
factorization = factorization,
separable = separable,
bias = bias)
gain=gain,
operator_type=operator_type,
rank=rank,
factorization=factorization,
separable=separable,
bias=bias,
)
else:
raise(NotImplementedError)
raise (NotImplementedError)
def forward(self, x):
return self.filter(x)
class SphericalFourierNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
"""
def __init__(
self,
forward_transform,
inverse_transform,
input_dim,
output_dim,
operator_type = "driscoll-healy",
mlp_ratio = 2.,
drop_rate = 0.,
drop_path = 0.,
act_layer = nn.ReLU,
norm_layer = nn.Identity,
factorization = None,
separable = False,
rank = 128,
inner_skip = "linear",
outer_skip = None,
use_mlp = True):
operator_type="driscoll-healy",
mlp_ratio=2.0,
drop_rate=0.0,
drop_path=0.0,
act_layer=nn.ReLU,
norm_layer=nn.Identity,
factorization=None,
separable=False,
rank=128,
inner_skip="linear",
outer_skip=None,
use_mlp=True,
):
super(SphericalFourierNeuralOperatorBlock, self).__init__()
if act_layer == nn.Identity:
@@ -119,21 +119,23 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
gain_factor /= 2.0
# convolution layer
self.filter = SpectralFilterLayer(forward_transform,
self.filter = SpectralFilterLayer(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain = gain_factor,
operator_type = operator_type,
hidden_size_factor = mlp_ratio,
factorization = factorization,
separable = separable,
rank = rank,
bias = True)
gain=gain_factor,
operator_type=operator_type,
hidden_size_factor=mlp_ratio,
factorization=factorization,
separable=separable,
rank=rank,
bias=True,
)
if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor/input_dim))
nn.init.normal_(self.inner_skip.weight, std=math.sqrt(gain_factor / input_dim))
elif inner_skip == "identity":
assert input_dim == output_dim
self.inner_skip = nn.Identity()
@@ -148,25 +150,21 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
self.norm0 = norm_layer()
# dropout
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
gain_factor = 1.0
if outer_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.
gain_factor /= 2.0
if use_mlp == True:
mlp_hidden_dim = int(output_dim * mlp_ratio)
self.mlp = MLP(in_features = output_dim,
out_features = input_dim,
hidden_features = mlp_hidden_dim,
act_layer = act_layer,
drop_rate = drop_rate,
checkpointing = False,
gain = gain_factor)
self.mlp = MLP(
in_features=output_dim, out_features=input_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop_rate=drop_rate, checkpointing=False, gain=gain_factor
)
if outer_skip == "linear":
self.outer_skip = nn.Conv2d(input_dim, input_dim, 1, 1)
torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor/input_dim))
torch.nn.init.normal_(self.outer_skip.weight, std=math.sqrt(gain_factor / input_dim))
elif outer_skip == "identity":
assert input_dim == output_dim
self.outer_skip = nn.Identity()
@@ -178,17 +176,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
# second normalisation layer
self.norm1 = norm_layer()
# def init_weights(self, scale):
# if hasattr(self, "inner_skip") and isinstance(self.inner_skip, nn.Conv2d):
# gain_factor = 1.
# scale = (gain_factor / embed_dim)**0.5
# nn.init.normal_(self.inner_skip.weight, mean=0., std=scale)
# self.filter.filter.init_weights(scale)
# else:
# gain_factor = 2.
# scale = (gain_factor / embed_dim)**0.5
# self.filter.filter.init_weights(scale)
def forward(self, x):
x, residual = self.filter(x)
@@ -213,6 +200,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
return x
class SphericalFourierNeuralOperatorNet(nn.Module):
"""
SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
@@ -281,29 +269,30 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
def __init__(
self,
spectral_transform = "sht",
operator_type = "driscoll-healy",
img_size = (128, 256),
grid = "equiangular",
scale_factor = 3,
in_chans = 3,
out_chans = 3,
embed_dim = 256,
num_layers = 4,
activation_function = "relu",
encoder_layers = 1,
use_mlp = True,
mlp_ratio = 2.,
drop_rate = 0.,
drop_path_rate = 0.,
normalization_layer = "none",
hard_thresholding_fraction = 1.0,
use_complex_kernels = True,
big_skip = False,
factorization = None,
separable = False,
rank = 128,
pos_embed = False):
spectral_transform="sht",
operator_type="driscoll-healy",
img_size=(128, 256),
grid="equiangular",
scale_factor=3,
in_chans=3,
out_chans=3,
embed_dim=256,
num_layers=4,
activation_function="relu",
encoder_layers=1,
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
drop_path_rate=0.0,
normalization_layer="none",
hard_thresholding_fraction=1.0,
use_complex_kernels=True,
big_skip=False,
factorization=None,
separable=False,
rank=128,
pos_embed=False,
):
super(SphericalFourierNeuralOperatorNet, self).__init__()
@@ -322,7 +311,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.encoder_layers = encoder_layers
self.big_skip = big_skip
self.factorization = factorization
self.separable = separable,
self.separable = (separable,)
self.rank = rank
# activation function
@@ -341,7 +330,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.w = self.img_size[1] // scale_factor
# dropout
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0. else nn.Identity()
self.pos_drop = nn.Dropout(p=drop_rate) if drop_rate > 0.0 else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)]
# pick norm layer
@@ -357,7 +346,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
else:
raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
if pos_embed == "latlon" or pos_embed==True:
if pos_embed == "latlon" or pos_embed == True:
self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
nn.init.constant_(self.pos_embed, 0.0)
elif pos_embed == "lat":
@@ -369,35 +358,24 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
else:
self.pos_embed = None
# # encoder
# encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
# encoder = MLP(in_features = self.in_chans,
# out_features = self.embed_dim,
# hidden_features = encoder_hidden_dim,
# act_layer = self.activation_function,
# drop_rate = drop_rate,
# checkpointing = False)
# self.encoder = encoder
# construct an encoder with num_encoder_layers
num_encoder_layers = 1
encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
current_dim = self.in_chans
encoder_layers = []
for l in range(num_encoder_layers-1):
for l in range(num_encoder_layers - 1):
fc = nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True)
# initialize the weights correctly
scale = math.sqrt(2. / current_dim)
nn.init.normal_(fc.weight, mean=0., std=scale)
scale = math.sqrt(2.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
nn.init.constant_(fc.bias, 0.0)
encoder_layers.append(fc)
encoder_layers.append(self.activation_function())
current_dim = encoder_hidden_dim
fc = nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)
scale = math.sqrt(1. / current_dim)
nn.init.normal_(fc.weight, mean=0., std=scale)
scale = math.sqrt(1.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
nn.init.constant_(fc.bias, 0.0)
encoder_layers.append(fc)
@@ -407,7 +385,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
if self.spectral_transform == "sht":
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int(self.w//2 * self.hard_thresholding_fraction)
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon)
self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
@@ -426,13 +404,13 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
else:
raise(ValueError("Unknown spectral transform"))
raise (ValueError("Unknown spectral transform"))
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
first_layer = i == 0
last_layer = i == self.num_layers-1
last_layer = i == self.num_layers - 1
forward_transform = self.trans_down if first_layer else self.trans
inverse_transform = self.itrans_up if last_layer else self.itrans
@@ -447,52 +425,45 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
else:
norm_layer = norm_layer1
block = SphericalFourierNeuralOperatorBlock(forward_transform,
block = SphericalFourierNeuralOperatorBlock(
forward_transform,
inverse_transform,
self.embed_dim,
self.embed_dim,
operator_type = self.operator_type,
mlp_ratio = mlp_ratio,
drop_rate = drop_rate,
drop_path = dpr[i],
act_layer = self.activation_function,
norm_layer = norm_layer,
inner_skip = inner_skip,
outer_skip = outer_skip,
use_mlp = use_mlp,
factorization = self.factorization,
separable = self.separable,
rank = self.rank)
operator_type=self.operator_type,
mlp_ratio=mlp_ratio,
drop_rate=drop_rate,
drop_path=dpr[i],
act_layer=self.activation_function,
norm_layer=norm_layer,
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp,
factorization=self.factorization,
separable=self.separable,
rank=self.rank,
)
self.blocks.append(block)
# # decoder
# decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
# self.decoder = MLP(in_features = self.embed_dim + self.big_skip*self.in_chans,
# out_features = self.out_chans,
# hidden_features = decoder_hidden_dim,
# act_layer = self.activation_function,
# drop_rate = drop_rate,
# checkpointing = False)
# construct an decoder with num_decoder_layers
num_decoder_layers = 1
decoder_hidden_dim = int(self.embed_dim * mlp_ratio)
current_dim = self.embed_dim + self.big_skip*self.in_chans
current_dim = self.embed_dim + self.big_skip * self.in_chans
decoder_layers = []
for l in range(num_decoder_layers-1):
for l in range(num_decoder_layers - 1):
fc = nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)
# initialize the weights correctly
scale = math.sqrt(2. / current_dim)
nn.init.normal_(fc.weight, mean=0., std=scale)
scale = math.sqrt(2.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
nn.init.constant_(fc.bias, 0.0)
decoder_layers.append(fc)
decoder_layers.append(self.activation_function())
current_dim = decoder_hidden_dim
fc = nn.Conv2d(current_dim, self.out_chans, 1, bias=False)
scale = math.sqrt(1. / current_dim)
nn.init.normal_(fc.weight, mean=0., std=scale)
scale = math.sqrt(1.0 / current_dim)
nn.init.normal_(fc.weight, mean=0.0, std=scale)
if fc.bias is not None:
nn.init.constant_(fc.bias, 0.0)
decoder_layers.append(fc)
@@ -529,5 +500,3 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
x = self.decoder(x)
return x
@@ -239,7 +239,7 @@ class ShallowWaterSolver(nn.Module):
ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
# mach number relative to wave speed
llimit = mlimit = 80
llimit = mlimit = 120
# hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
# ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)