"vscode:/vscode.git/clone" did not exist on "5221a3883c53a6dd8fb75e3a6ddf52e71e4a5127"
Commit 5d00e2b4 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

simplified SFNO example by removing factorized versions

parent 3f125603
......@@ -6,6 +6,8 @@
* Changing default grid in all SHT routines to `equiangular`
* Hotfix to the numpy version requirements
* Changing default grid in all SHT routines to `equiangular`, which makes it consistent with DISCO convolutions
* Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
* Reworked DISCO filter basis datastructure
* Support for new filter basis types
......
......@@ -259,7 +259,7 @@ If you use `torch-harmonics` in an academic paper, please cite [1]
<a id="1">[1]</a>
Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere;
arXiv 2306.0383, 2023.
International Conference on Machine Learning, 2023. [arxiv link](https://arxiv.org/abs/2306.03838)
<a id="1">[2]</a>
Schaeffer N.;
......
......@@ -161,6 +161,10 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
model.eval()
# make output
if not os.path.isdir(path_root):
os.makedirs(path_root, exist_ok=True)
losses = np.zeros(nics)
fno_times = np.zeros(nics)
nwp_times = np.zeros(nics)
......@@ -178,18 +182,24 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
prd = prd.unsqueeze(0)
uspec = ic.clone()
# add IC to power spectrum series
prd_coeffs = [dataset.sht(prd[0, plot_channel]).detach().cpu().clone()]
ref_coeffs = [prd_coeffs[0].clone()]
# ML model
start_time = time.time()
for i in range(1, autoreg_steps + 1):
# evaluate the ML model
prd = model(prd)
prd_coeffs.append(dataset.sht(prd[0, plot_channel]).detach().cpu().clone())
if iic == nics - 1 and nskip > 0 and i % nskip == 0:
# do plotting
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root + "_pred_" + str(i // nskip) + ".png")
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(os.path.join(path_root,'pred_'+str(i//nskip)+'.png'))
plt.close()
fno_times[iic] = time.time() - start_time
......@@ -201,21 +211,20 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
# advance classical model
uspec = dataset.solver.timestep(uspec, nsteps)
ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref_coeffs.append(dataset.sht(ref[plot_channel]).detach().cpu().clone())
if iic == nics - 1 and i % nskip == 0 and nskip > 0:
fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4)
plt.savefig(path_root + "_truth_" + str(i // nskip) + ".png")
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(os.path.join(path_root,'truth_'+str(i//nskip)+'.png'))
plt.close()
nwp_times[iic] = time.time() - start_time
# compute power spectrum and add it to the buffers
prd_coeffs = dataset.solver.sht(prd[0, plot_channel])
ref_coeffs = dataset.solver.sht(ref[plot_channel])
prd_mean_coeffs.append(prd_coeffs)
ref_mean_coeffs.append(ref_coeffs)
prd_mean_coeffs.append(torch.stack(prd_coeffs, 0))
ref_mean_coeffs.append(torch.stack(ref_coeffs, 0))
# ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref = dataset.solver.spec2grid(uspec)
......@@ -223,21 +232,29 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()
# compute the averaged powerspectra of prediction and reference
prd_mean_coeffs = torch.stack(prd_mean_coeffs).abs().pow(2).mean(dim=0)
ref_mean_coeffs = torch.stack(ref_mean_coeffs).abs().pow(2).mean(dim=0)
with torch.no_grad():
prd_mean_coeffs = torch.stack(prd_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)
ref_mean_coeffs = torch.stack(ref_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)
prd_mean_coeffs[..., 1:] *= 2.0
ref_mean_coeffs[..., 1:] *= 2.0
prd_mean_ps = prd_mean_coeffs.sum(dim=-1).detach().cpu()
ref_mean_ps = ref_mean_coeffs.sum(dim=-1).detach().cpu()
prd_mean_ps = prd_mean_coeffs.sum(dim=-1).contiguous()
ref_mean_ps = ref_mean_coeffs.sum(dim=-1).contiguous()
# split the stuff
prd_mean_ps = [x.squeeze() for x in list(torch.split(prd_mean_ps, 1, dim=0))]
ref_mean_ps = [x.squeeze() for x in list(torch.split(ref_mean_ps, 1, dim=0))]
# compute the averaged powerspectrum
for step, (pps, rps) in enumerate(zip(prd_mean_ps, ref_mean_ps)):
fig = plt.figure(figsize=(7.5, 6))
plt.loglog(prd_mean_ps, label="prediction")
plt.loglog(ref_mean_ps, label="reference")
plt.semilogy(pps, label="prediction")
plt.semilogy(rps, label="reference")
plt.xlabel("$l$")
plt.ylabel("powerspectrum")
plt.legend()
plt.savefig(path_root + "_powerspectrum.png")
plt.savefig(os.path.join(path_root,f'powerspectrum_{step}.png'))
fig.clf()
plt.close()
return losses, fno_times, nwp_times
......@@ -364,6 +381,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
torch.manual_seed(333)
torch.cuda.manual_seed(333)
# set parameters
nfuture=0
# set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
......@@ -373,7 +393,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt = 1 * 3600
dt_solver = 150
nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, normalize=True)
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, grid="legendre-gauss", normalize=True)
dataset.sht = RealSHT(nlat=257, nlon=512, grid= "equiangular").to(device=device)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
......@@ -391,38 +412,54 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
from torch_harmonics.examples.sfno import LocalSphericalNeuralOperatorNet as LSNO
# models["sfno_sc2_layers6_e32"] = partial(
# models[f"sfno_sc2_layers4_e32_nomlp"] = partial(
# SFNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid="equiangular",
# num_layers=6,
# scale_factor=1,
# # hard_thresholding_fraction=0.8,
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=True,
# use_mlp=False,
# normalization_layer="none",
# )
models["lsno_sc2_layers6_e32"] = partial(
LSNO,
spectral_transform="sht",
models[f"sfno_sc2_layers4_e32_nomlp_leggauss"] = partial(
SFNO,
img_size=(nlat, nlon),
grid="equiangular",
num_layers=6,
scale_factor=1,
grid="legendre-gauss",
# hard_thresholding_fraction=0.8,
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=True,
big_skip=False,
pos_embed=False,
use_mlp=True,
use_mlp=False,
normalization_layer="none",
)
# models[f"lsno_sc1_layers4_e32_nomlp"] = partial(
# LSNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid="equiangular",
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=False,
# normalization_layer="none",
# )
# iterate over models and train each model
root_path = os.path.dirname(__file__)
for model_name, model_handle in models.items():
......@@ -437,8 +474,12 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
print(f"number of trainable params: {num_params}")
metrics[model_name]["num_params"] = num_params
exp_dir = os.path.join(root_path, 'checkpoints', model_name)
if not os.path.isdir(exp_dir):
os.makedirs(exp_dir, exist_ok=True)
if load_checkpoint:
model.load_state_dict(torch.load(os.path.join(root_path, "checkpoints/" + model_name), weights_only=True))
model.load_state_dict(torch.load(os.path.join(exp_dir, "checkpoint.pt")))
# run the training
if train:
......@@ -454,27 +495,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
# # multistep training
# print(f'Training {model_name}, two step')
# optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# gscaler = torch.GradScaler(enabled=enable_amp)
# dataloader.dataset.nsteps = 2 * dt//dt_solver
# train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=5, nfuture=1, enable_amp=enable_amp)
# dataloader.dataset.nsteps = 1 * dt//dt_solver
if nfuture > 0:
print(f'Training {model_name}, {nfuture} step')
optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
gscaler = amp.GradScaler(enabled=enable_amp)
dataloader.dataset.nsteps = 2 * dt//dt_solver
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
dataloader.dataset.nsteps = 1 * dt//dt_solver
training_time = time.time() - start_time
run.finish()
torch.save(model.state_dict(), os.path.join(root_path, "checkpoints/" + model_name))
torch.save(model.state_dict(), os.path.join(exp_dir, 'checkpoint.pt'))
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
with torch.inference_mode():
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path, "figures/" + model_name), nsteps=nsteps, autoreg_steps=30)
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(exp_dir,'figures'), nsteps=nsteps, autoreg_steps=30, nics=50)
metrics[model_name]["loss_mean"] = np.mean(losses)
metrics[model_name]["loss_std"] = np.std(losses)
metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
......@@ -485,7 +526,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
metrics[model_name]["training_time"] = training_time
df = pd.DataFrame(metrics)
df.to_pickle(os.path.join(root_path, "output_data/metrics.pkl"))
if not os.path.isdir(os.path.join(exp_dir, 'output_data',)):
os.makedirs(os.path.join(exp_dir, 'output_data'), exist_ok=True)
df.to_pickle(os.path.join(exp_dir, 'output_data', 'metrics.pkl'))
if __name__ == "__main__":
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = "0.7.3"
__version__ = "0.7.3a"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
"""
Contains complex contractions wrapped into jit for harmonic layers
"""
@torch.jit.script
def contract_diagonal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Diagonal spectral contraction.

    Interprets the trailing dimension of size 2 as (real, imag), multiplies
    every spectral mode (x, y) by its own per-mode weight, and sums over the
    input channel dimension i. Returns the result in the same real view.
    """
    result = torch.einsum("bixy,kixy->bkxy", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(result)
@torch.jit.script
def contract_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Driscoll-Healy style contraction.

    Like the diagonal contraction, but the weight carries no y (order) index:
    a single complex weight per degree x is shared across all orders y.
    """
    result = torch.einsum("bixy,kix->bkxy", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(result)
@torch.jit.script
def contract_blockdiag(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Block-diagonal spectral contraction.

    The weight mixes modes along the last spectral dimension (y -> z) in
    addition to mixing channels (i -> k); inputs and output use the real view
    with a trailing (real, imag) dimension of size 2.
    """
    result = torch.einsum("bixy,kixyz->bkxz", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(result)
# Helper routines for the non-linear FNOs (Attention-like)
@torch.jit.script
def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 1d channel mixing on real-view tensors (trailing dim = (re, im)).

    Performs (a_re + i a_im) * (b_re + i b_im), contracting the input channel
    index i against the weight's (i, o) channel pair, without ever forming a
    complex dtype tensor.
    """
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    out_re = torch.einsum("bix,io->box", a_re, b_re) - torch.einsum("bix,io->box", a_im, b_im)
    out_im = torch.einsum("bix,io->box", a_im, b_re) + torch.einsum("bix,io->box", a_re, b_im)
    return torch.stack([out_re, out_im], dim=-1)
@torch.jit.script
def compl_mul1d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 1d channel mixing using the native complex dtype.

    Views both real-view inputs as complex, contracts input channels i against
    the (i, o) weight, and returns the product in the real view.
    """
    product = torch.einsum("bix,io->box", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(product)
@torch.jit.script
def compl_muladd1d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex multiply-add in the real view: compl_mul1d_fwd(a, b) + c."""
    return compl_mul1d_fwd(a, b) + c
@torch.jit.script
def compl_muladd1d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex multiply-add: a * b + c in complex arithmetic.

    Complex addition acts elementwise on the (re, im) real view, so adding c
    directly to the real-view product is identical to the round trip through
    view_as_complex / view_as_real.
    """
    return compl_mul1d_fwd_c(a, b) + c
# Helper routines for FFT MLPs
@torch.jit.script
def compl_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 2d channel mixing on real-view tensors (trailing dim = (re, im)).

    Same as compl_mul1d_fwd but with two spatial/spectral dimensions (x, y);
    works purely in real arithmetic without forming complex dtype tensors.
    """
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]
    out_re = torch.einsum("bixy,io->boxy", a_re, b_re) - torch.einsum("bixy,io->boxy", a_im, b_im)
    out_im = torch.einsum("bixy,io->boxy", a_im, b_re) + torch.einsum("bixy,io->boxy", a_re, b_im)
    return torch.stack([out_re, out_im], dim=-1)
@torch.jit.script
def compl_mul2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 2d channel mixing using the native complex dtype.

    Views both real-view inputs as complex, mixes channels with the (i, o)
    weight over both spectral dimensions, and returns the real view.
    """
    product = torch.einsum("bixy,io->boxy", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(product)
@torch.jit.script
def compl_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex multiply-add in the real view: compl_mul2d_fwd(a, b) + c."""
    return compl_mul2d_fwd(a, b) + c
@torch.jit.script
def compl_muladd2d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex multiply-add: a * b + c in complex arithmetic.

    Complex addition acts elementwise on the (re, im) real view, so adding c
    directly to the real-view product matches the original round trip through
    view_as_complex / view_as_real.
    """
    return compl_mul2d_fwd_c(a, b) + c
@torch.jit.script
def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Real-valued 2d channel mixing: contract input channels i with the (i, o) weight."""
    return torch.einsum("bixy,io->boxy", a, b)
@torch.jit.script
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Real-valued fused multiply-add: channel mixing of a with weight b, plus c.

    Bugfix: this previously delegated to compl_mul2d_fwd_c, which reinterprets
    the trailing dimension as (re, im) via view_as_complex and therefore fails
    (or silently misbehaves) for plain real tensors. Real inputs must use the
    real einsum, matching real_mul2d_fwd.
    """
    return torch.einsum("bixy,io->boxy", a, b) + c
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors.core import FactorizedTensor
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
    """Contract x with a dense (unfactorized) weight in spectral space.

    Builds an einsum equation from the tensor order and the operator type:
      - 'diagonal':       one weight entry per spectral mode (x, y, ...)
      - 'block-diagonal': weight additionally mixes modes along the last dim
      - 'driscoll-healy': weight depends only on the leading mode index

    Raises ValueError for an unknown operator type.
    """
    order = tl.ndim(x)
    # batch-size, in_channels, x, y...
    x_syms = list(einsum_symbols[:order])
    # in_channels, out_channels, x, y...
    weight_syms = list(x_syms[1:])  # no batch-size
    # batch-size, out_channels, x, y...
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order + 1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'driscoll-healy':
        weight_syms.pop()
    else:
        # fixed typo in the error message ("Unkonw" -> "Unknown")
        raise ValueError(f"Unknown operator type {operator_type}")
    eq = ''.join(x_syms) + ',' + ''.join(weight_syms) + '->' + ''.join(out_syms)
    # factorized weights expose to_tensor() to reconstruct the dense tensor
    if not torch.is_tensor(weight):
        weight = weight.to_tensor()
    return tl.einsum(eq, x, weight)
def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
    """Contract x directly with the factors of a CP-factorized weight.

    Avoids reconstructing the dense weight: the einsum contracts x with the
    CP rank weights and each factor matrix. See _contract_dense for the
    operator-type semantics. Raises ValueError for an unknown operator type.
    """
    order = tl.ndim(x)
    x_syms = str(einsum_symbols[:order])
    rank_sym = einsum_symbols[order]
    out_sym = einsum_symbols[order+1]
    out_syms = list(x_syms)
    if separable:
        factor_syms = [einsum_symbols[1]+rank_sym]  # in only
    else:
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1]+rank_sym, out_sym+rank_sym]  # in, out
    factor_syms += [xs+rank_sym for xs in x_syms[2:]]  # x, y, ...
    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        out_syms[-1] = einsum_symbols[order+2]
        factor_syms += [out_syms[-1] + rank_sym]
    elif operator_type == 'driscoll-healy':
        factor_syms.pop()
    else:
        # fixed typo in the error message ("Unkonw" -> "Unknown")
        raise ValueError(f"Unknown operator type {operator_type}")
    eq = x_syms + ',' + rank_sym + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)
    return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)
def _contract_tucker(x, tucker_weight, separable=False, operator_type='diagonal'):
    """Contract x directly with the core and factors of a Tucker-factorized weight.

    See _contract_dense for the operator-type semantics. 'block-diagonal'
    raises NotImplementedError here; any other non-'diagonal' type (including
    'driscoll-healy', which is not supported for Tucker) raises ValueError.
    """
    order = tl.ndim(x)
    x_syms = str(einsum_symbols[:order])
    out_sym = einsum_symbols[order]
    out_syms = list(x_syms)
    if separable:
        core_syms = einsum_symbols[order+1:2*order]
        # factor_syms = [einsum_symbols[1]+core_syms[0]] #in only
        factor_syms = [xs+rs for (xs, rs) in zip(x_syms[1:], core_syms)]  # x, y, ...
    else:
        core_syms = einsum_symbols[order+1:2*order+1]
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1]+core_syms[0], out_sym+core_syms[1]]  # out, in
        factor_syms += [xs+rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])]  # x, y, ...
    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker")
    else:
        # fixed typo in the error message ("Unkonw" -> "Unknown");
        # note 'driscoll-healy' intentionally falls through to this error
        raise ValueError(f"Unknown operator type {operator_type}")
    eq = x_syms + ',' + core_syms + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)
    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)
def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
    """Contract x directly with the factors of a tensor-train (TT) factorized weight.

    Each TT factor carries (left-rank, mode, right-rank) indices; the einsum
    chains them together. See _contract_dense for the operator-type semantics.
    Raises ValueError for an unknown operator type.
    """
    order = tl.ndim(x)
    x_syms = list(einsum_symbols[:order])
    weight_syms = list(x_syms[1:])  # no batch-size
    if not separable:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    else:
        out_syms = list(x_syms)
    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order+1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'driscoll-healy':
        weight_syms.pop()
    else:
        # fixed typo in the error message ("Unkonw" -> "Unknown")
        raise ValueError(f"Unknown operator type {operator_type}")
    rank_syms = list(einsum_symbols[order+2:])
    tt_syms = []
    for i, s in enumerate(weight_syms):
        tt_syms.append([rank_syms[i], s, rank_syms[i+1]])
    eq = ''.join(x_syms) + ',' + ','.join(''.join(f) for f in tt_syms) + '->' + ''.join(out_syms)
    return tl.einsum(eq, x, *tt_weight.factors)
def get_contract_fun(weight, implementation='reconstructed', separable=False):
    """Generic ND implementation of Fourier Spectral Conv contraction

    Parameters
    ----------
    weight : tensorly-torch's FactorizedTensor
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or contract directly the factors of the factorized weight with the input (factorized)

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space
    """
    if implementation == 'reconstructed':
        return _contract_dense
    if implementation != 'factorized':
        raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"')
    # 'factorized' path: plain tensors still use the dense contraction
    if torch.is_tensor(weight):
        return _contract_dense
    if not isinstance(weight, FactorizedTensor):
        raise ValueError(f'Got unexpected weight type of class {weight.__class__.__name__}')
    # dispatch on the factorization name instead of an if/elif chain
    handlers = {
        'complexdense': _contract_dense,
        'complextucker': _contract_tucker,
        'complextt': _contract_tt,
        'complexcp': _contract_cp,
    }
    key = weight.name.lower()
    if key not in handlers:
        raise ValueError(f'Got unexpected factorized weight type {weight.name}')
    return handlers[key]
......@@ -36,16 +36,8 @@ from torch.utils.checkpoint import checkpoint
import math
from torch_harmonics import *
from .contractions import *
from .activations import *
# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
# tl.set_backend("pytorch")
# use_opt_einsum("optimal")
from tltorch.factorized_tensors.core import FactorizedTensor
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
......@@ -237,7 +229,7 @@ class SpectralConvS2(nn.Module):
operator_type = "driscoll-healy",
lr_scale_exponent = 0,
bias = False):
super(SpectralConvS2, self).__init__()
super().__init__()
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
......@@ -258,123 +250,19 @@ class SpectralConvS2(nn.Module):
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
from .contractions import contract_diagonal as _contract
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
from .contractions import contract_blockdiag as _contract
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
from .contractions import contract_dhconv as _contract
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat, 2)
scale[0] *= math.sqrt(2)
self.weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))
# get the right contraction function
self._contract = _contract
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
x = torch.view_as_real(x)
x = self._contract(x, self.weight)
x = torch.view_as_complex(x)
with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)
if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
return x, residual
class FactorizedSpectralConvS2(nn.Module):
"""
Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
"""
def __init__(self,
forward_transform,
inverse_transform,
in_channels,
out_channels,
gain = 2.,
operator_type = "driscoll-healy",
rank = 0.2,
factorization = None,
separable = False,
implementation = "factorized",
decomposition_kwargs=dict(),
bias = False):
super(SpectralConvS2, self).__init__()
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
# Make sure we are using a Complex Factorized Tensor
if factorization is None:
factorization = "Dense" # No factorization
if not factorization.lower().startswith("complex"):
factorization = f"Complex{factorization}"
# remember factorization details
self.operator_type = operator_type
self.rank = rank
self.factorization = factorization
self.separable = separable
assert self.inverse_transform.lmax == self.modes_lat
assert self.inverse_transform.mmax == self.modes_lon
weight_shape = [out_channels]
if not self.separable:
weight_shape += [in_channels]
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
self.contract_func = "...ilm,oilm->...olm"
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
self.contract_func = "...ilm,oilnm->...oln"
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
self.contract_func = "...ilm,oil->...olm"
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
fixed_rank_modes=False, **decomposition_kwargs)
# initialization of weights
scale = math.sqrt(gain / in_channels)
self.weight.normal_(0, scale)
# get the right contraction function
from .factorizations import get_contract_fun
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)
self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
......@@ -390,7 +278,7 @@ class FactorizedSpectralConvS2(nn.Module):
if self.scale_residual:
residual = self.inverse_transform(x)
x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type)
x = torch.einsum(self.contract_func, x, self.weight)
with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)
......
......@@ -51,6 +51,7 @@ class DiscreteContinuousEncoder(nn.Module):
inp_chans=2,
out_chans=2,
kernel_shape=[3, 4],
basis_type="piecewise linear",
groups=1,
bias=False,
):
......@@ -63,11 +64,12 @@ class DiscreteContinuousEncoder(nn.Module):
in_shape=inp_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in,
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
theta_cutoff=4*math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
)
def forward(self, x):
......@@ -91,6 +93,7 @@ class DiscreteContinuousDecoder(nn.Module):
inp_chans=2,
out_chans=2,
kernel_shape=[3, 4],
basis_type="piecewise linear",
groups=1,
bias=False,
):
......@@ -107,11 +110,12 @@ class DiscreteContinuousDecoder(nn.Module):
in_shape=out_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_out,
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
theta_cutoff=4*math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
)
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
......@@ -131,58 +135,6 @@ class DiscreteContinuousDecoder(nn.Module):
return x
class SpectralFilterLayer(nn.Module):
"""
Fourier layer. Contains the convolution part of the FNO/SFNO
"""
def __init__(
self,
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=2.0,
operator_type="diagonal",
hidden_size_factor=2,
factorization=None,
separable=False,
rank=1e-2,
bias=True,
):
super(SpectralFilterLayer, self).__init__()
if factorization is None:
self.filter = SpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain,
operator_type=operator_type,
bias=bias,
)
elif factorization is not None:
self.filter = FactorizedSpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain,
operator_type=operator_type,
rank=rank,
factorization=factorization,
separable=separable,
bias=bias,
)
else:
raise (NotImplementedError)
def forward(self, x):
return self.filter(x)
class SphericalNeuralOperatorBlock(nn.Module):
"""
......@@ -202,13 +154,11 @@ class SphericalNeuralOperatorBlock(nn.Module):
drop_path=0.0,
act_layer=nn.ReLU,
norm_layer=nn.Identity,
factorization=None,
separable=False,
rank=128,
inner_skip="None",
outer_skip="linear",
use_mlp=True,
disco_kernel_shape=[2, 4],
disco_basis_type="piecewise linear",
):
super().__init__()
......@@ -228,25 +178,14 @@ class SphericalNeuralOperatorBlock(nn.Module):
in_shape=(forward_transform.nlat, forward_transform.nlon),
out_shape=(inverse_transform.nlat, inverse_transform.nlon),
kernel_shape=disco_kernel_shape,
basis_type=disco_basis_type,
grid_in=forward_transform.grid,
grid_out=inverse_transform.grid,
bias=False,
theta_cutoff=(disco_kernel_shape[0] + 1) * torch.pi / float(forward_transform.nlat - 1) / math.sqrt(2),
theta_cutoff=4*math.sqrt(2) * torch.pi / float(inverse_transform.nlat - 1),
)
elif conv_type == "global":
self.global_conv = SpectralFilterLayer(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain_factor,
operator_type=operator_type,
hidden_size_factor=mlp_ratio,
factorization=factorization,
separable=separable,
rank=rank,
bias=False,
)
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
else:
raise ValueError(f"Unknown convolution type {conv_type}")
......@@ -261,8 +200,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
self.act_layer = act_layer()
# first normalisation layer
self.norm0 = norm_layer()
......@@ -313,9 +250,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual)
if hasattr(self, "act_layer"):
x = self.act_layer(x)
if hasattr(self, "mlp"):
x = self.mlp(x)
......@@ -331,11 +265,13 @@ class SphericalNeuralOperatorBlock(nn.Module):
class LocalSphericalNeuralOperatorNet(nn.Module):
"""
SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
both linear and non-linear variants.
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
operators to accurately model both types of solution operators [1]. The architecture is based on the Spherical
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
as well as in the encoder and decoders.
Parameters
----------
-----------
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
......@@ -373,17 +309,11 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True
rank : float, optional
Rank of the approximation, by default 1.0
factorization : Any, optional
Type of factorization to use, by default None
separable : bool, optional
Whether to use separable convolutions, by default False
rank : (int, Tuple[int]), optional
If a factorization is used, which rank to use. Argument is passed to tensorly
pos_embed : bool, optional
Whether to use positional embedding, by default True
Example:
--------
Example
-----------
>>> model = SphericalFourierNeuralOperatorNet(
... img_shape=(128, 256),
... scale_factor=4,
......@@ -394,6 +324,17 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
References
-----------
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845.
.. [2] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
"""
def __init__(
......@@ -402,6 +343,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
operator_type="driscoll-healy",
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=4,
in_chans=3,
out_chans=3,
......@@ -410,6 +352,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
activation_function="relu",
kernel_shape=[3, 4],
encoder_kernel_shape=[3, 4],
disco_basis_type="piecewise linear",
use_mlp=True,
mlp_ratio=2.0,
drop_rate=0.0,
......@@ -418,9 +361,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
hard_thresholding_fraction=1.0,
use_complex_kernels=True,
big_skip=False,
factorization=None,
separable=False,
rank=128,
pos_embed=False,
):
super().__init__()
......@@ -429,6 +369,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.operator_type = operator_type
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
......@@ -439,9 +380,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.normalization_layer = normalization_layer
self.use_mlp = use_mlp
self.big_skip = big_skip
self.factorization = factorization
self.separable = (separable,)
self.rank = rank
# activation function
if activation_function == "relu":
......@@ -455,7 +393,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor
self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor
# dropout
......@@ -494,9 +432,10 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.img_size,
(self.h, self.w),
self.encoder_kernel_shape,
basis_type=disco_basis_type,
groups=1,
grid_in=grid,
grid_out="legendre-gauss",
grid_out=grid_internal,
bias=False,
theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
)
......@@ -506,7 +445,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
# inp_shape=self.img_size,
# out_shape=(self.h, self.w),
# grid_in=grid,
# grid_out="legendre-gauss",
# grid_out=grid_internal,
# inp_chans=self.in_chans,
# out_chans=self.embed_dim,
# kernel_shape=self.encoder_kernel_shape,
......@@ -520,8 +459,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon)
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
else:
raise (ValueError("Unknown spectral transform"))
......@@ -556,10 +495,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp,
factorization=self.factorization,
separable=self.separable,
rank=self.rank,
disco_kernel_shape=kernel_shape,
disco_basis_type=disco_basis_type,
)
self.blocks.append(block)
......@@ -582,11 +519,12 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.decoder = DiscreteContinuousDecoder(
inp_shape=(self.h, self.w),
out_shape=self.img_size,
grid_in="legendre-gauss",
grid_in=grid_internal,
grid_out=grid,
inp_chans=self.embed_dim,
out_chans=self.out_chans,
kernel_shape=self.encoder_kernel_shape,
basis_type=disco_basis_type,
groups=1,
bias=False,
)
......
......@@ -39,51 +39,6 @@ from .layers import *
from functools import partial
class SpectralFilterLayer(nn.Module):
    """
    Fourier layer. Contains the convolution part of the FNO/SFNO.

    Wraps either a dense spherical spectral convolution or its factorized
    (tensorly-backed) variant, selected via the ``factorization`` argument.
    """

    def __init__(
        self,
        forward_transform,
        inverse_transform,
        input_dim,
        output_dim,
        gain=2.0,
        operator_type="diagonal",
        hidden_size_factor=2,
        factorization=None,
        separable=False,
        rank=1e-2,
        bias=True,
    ):
        """
        Parameters
        ----------
        forward_transform : nn.Module
            Transform into spectral space (e.g. a RealSHT instance).
        inverse_transform : nn.Module
            Transform back into physical space.
        input_dim : int
            Number of input channels.
        output_dim : int
            Number of output channels.
        gain : float, optional
            Initialization gain forwarded to the convolution, by default 2.0
        operator_type : str, optional
            Type of spectral operator, by default "diagonal"
        hidden_size_factor : int, optional
            Unused here; kept for interface compatibility, by default 2
        factorization : Any, optional
            Tensor factorization to use; ``None`` selects the dense convolution
        separable : bool, optional
            Whether the factorized convolution is separable, by default False
        rank : float or int, optional
            Rank of the factorization (forwarded to tensorly), by default 1e-2
        bias : bool, optional
            Whether the convolution carries a bias term, by default True
        """
        super().__init__()

        # NOTE: the previous `if x is None / elif x is not None / else raise`
        # chain had an unreachable else branch; a plain if/else is exhaustive.
        if factorization is None:
            self.filter = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain, operator_type=operator_type, bias=bias)
        else:
            self.filter = FactorizedSpectralConvS2(
                forward_transform,
                inverse_transform,
                input_dim,
                output_dim,
                gain=gain,
                operator_type=operator_type,
                rank=rank,
                factorization=factorization,
                separable=separable,
                bias=bias,
            )

    def forward(self, x):
        # delegate to whichever convolution was selected at construction time
        return self.filter(x)
class SphericalFourierNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
......@@ -108,7 +63,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
outer_skip=None,
use_mlp=True,
):
super(SphericalFourierNeuralOperatorBlock, self).__init__()
super().__init__()
if act_layer == nn.Identity:
gain_factor = 1.0
......@@ -118,20 +73,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0
# convolution layer
self.filter = SpectralFilterLayer(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain_factor,
operator_type=operator_type,
hidden_size_factor=mlp_ratio,
factorization=factorization,
separable=separable,
rank=rank,
bias=True,
)
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
......@@ -144,8 +86,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else:
raise ValueError(f"Unknown skip connection type {inner_skip}")
self.act_layer = act_layer()
# first normalisation layer
self.norm0 = norm_layer()
......@@ -178,16 +118,13 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
def forward(self, x):
x, residual = self.filter(x)
x, residual = self.global_conv(x)
x = self.norm0(x)
if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual)
if hasattr(self, "act_layer"):
x = self.act_layer(x)
if hasattr(self, "mlp"):
x = self.mlp(x)
......@@ -203,13 +140,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
class SphericalFourierNeuralOperatorNet(nn.Module):
"""
SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO,
both linear and non-linear variants.
SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
and approximately equivariant architecture.
Parameters
----------
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
img_shape : tuple, optional
......@@ -244,12 +180,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True
rank : float, optional
Rank of the approximation, by default 1.0
factorization : Any, optional
Type of factorization to use, by default None
separable : bool, optional
Whether to use separable convolutions, by default False
rank : (int, Tuple[int]), optional
If a factorization is used, which rank to use. Argument is passed to tensorly
pos_embed : bool, optional
Whether to use positional embedding, by default True
......@@ -265,14 +195,20 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256])
References
-----------
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
"""
def __init__(
self,
spectral_transform="sht",
operator_type="driscoll-healy",
img_size=(128, 256),
grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=3,
in_chans=3,
out_chans=3,
......@@ -288,18 +224,15 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
hard_thresholding_fraction=1.0,
use_complex_kernels=True,
big_skip=False,
factorization=None,
separable=False,
rank=128,
pos_embed=False,
):
super(SphericalFourierNeuralOperatorNet, self).__init__()
super().__init__()
self.spectral_transform = spectral_transform
self.operator_type = operator_type
self.img_size = img_size
self.grid = grid
self.grid_internal = grid_internal
self.scale_factor = scale_factor
self.in_chans = in_chans
self.out_chans = out_chans
......@@ -310,9 +243,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.use_mlp = use_mlp
self.encoder_layers = encoder_layers
self.big_skip = big_skip
self.factorization = factorization
self.separable = (separable,)
self.rank = rank
# activation function
if activation_function == "relu":
......@@ -326,7 +256,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor
self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor
# dropout
......@@ -381,30 +311,17 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
encoder_layers.append(fc)
self.encoder = nn.Sequential(*encoder_layers)
# prepare the spectral transform
if self.spectral_transform == "sht":
# compute the modes for the sht
modes_lat = self.h
# due to some spectral artifacts with cufft, we subtract one mode here
modes_lon = (self.w // 2 + 1) -1
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon)
modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
elif self.spectral_transform == "fft":
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
else:
raise (ValueError("Unknown spectral transform"))
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
......@@ -439,9 +356,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
inner_skip=inner_skip,
outer_skip=outer_skip,
use_mlp=use_mlp,
factorization=self.factorization,
separable=self.separable,
rank=self.rank,
)
self.blocks.append(block)
......
......@@ -35,10 +35,23 @@ from math import ceil
from ...shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data"""
def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random',
num_examples=32, device=torch.device('cpu'), normalize=True, stream=None):
def __init__(
self,
dt,
nsteps,
dims=(384, 768),
grid="equiangular",
pde="shallow water equations",
initial_condition="random",
num_examples=32,
device=torch.device("cpu"),
normalize=True,
stream=None,
):
self.num_examples = num_examples
self.device = device
self.stream = stream
......@@ -50,11 +63,11 @@ class PdeDataset(torch.utils.data.Dataset):
self.nsteps = nsteps
self.normalize = normalize
if pde == 'shallow water equations':
lmax = ceil(self.nlat/3)
if pde == "shallow water equations":
lmax = ceil(self.nlat / 3)
mmax = lmax
dt_solver = dt / float(self.nsteps)
self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float()
self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid=grid).to(self.device).float()
else:
raise NotImplementedError
......@@ -66,19 +79,19 @@ class PdeDataset(torch.utils.data.Dataset):
self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
def __len__(self):
    """Return the number of samples in the dataset.

    For random initial conditions this is ``num_examples``; a deterministic
    initial condition (e.g. 'galewsky') yields exactly one sample.
    """
    # drop the stale duplicated assignment left over from the diff; keep the
    # canonical double-quoted comparison
    return self.num_examples if self.ictype == "random" else 1
def set_initial_condition(self, ictype="random"):
    """Select the initial-condition type.

    Parameters
    ----------
    ictype : str, optional
        Either "random" or "galewsky" (see ``_get_sample``), by default "random"
    """
    # the span previously contained two consecutive `def` lines (diff residue),
    # which is a syntax error; keep the updated double-quoted signature
    self.ictype = ictype
def set_num_examples(self, num_examples=32):
    """Set how many samples the dataset generates/reports (used with random ICs)."""
    self.num_examples = num_examples
def _get_sample(self):
if self.ictype == 'random':
if self.ictype == "random":
inp = self.solver.random_initial_condition(mach=0.2)
elif self.ictype == 'galewsky':
elif self.ictype == "galewsky":
inp = self.solver.galewsky_initial_condition()
# solve pde for n steps to return the target
......
......@@ -367,6 +367,21 @@ class ShallowWaterSolver(nn.Module):
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
plt.title(title, y=1.05)
elif projection == 'robinson':
import cartopy.crs as ccrs
proj = ccrs.Robinson(central_longitude=0.0)
#ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj)
Lons = Lons*180/np.pi
Lats = Lats*180/np.pi
# contour data over the map.
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
plt.title(title, y=1.05)
else:
raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment