Commit 5d00e2b4 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

Simplified the SFNO example by removing factorized versions

parent 3f125603
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
* Changing default grid in all SHT routines to `equiangular` * Changing default grid in all SHT routines to `equiangular`
* Hotfix to the numpy version requirements * Hotfix to the numpy version requirements
* Changing default grid in all SHT routines to `equiangular`, which makes it consistent with DISCO convolutions
* Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
* Reworked DISCO filter basis datastructure * Reworked DISCO filter basis datastructure
* Support for new filter basis types * Support for new filter basis types
......
...@@ -259,7 +259,7 @@ If you use `torch-harmonics` in an academic paper, please cite [1] ...@@ -259,7 +259,7 @@ If you use `torch-harmonics` in an academic paper, please cite [1]
<a id="1">[1]</a> <a id="1">[1]</a>
Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.; Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere; Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere;
arXiv 2306.0383, 2023. International Conference on Machine Learning, 2023. [arxiv link](https://arxiv.org/abs/2306.03838)
<a id="1">[2]</a> <a id="1">[2]</a>
Schaeffer N.; Schaeffer N.;
......
...@@ -161,6 +161,10 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10 ...@@ -161,6 +161,10 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
model.eval() model.eval()
# make output
if not os.path.isdir(path_root):
os.makedirs(path_root, exist_ok=True)
losses = np.zeros(nics) losses = np.zeros(nics)
fno_times = np.zeros(nics) fno_times = np.zeros(nics)
nwp_times = np.zeros(nics) nwp_times = np.zeros(nics)
...@@ -178,18 +182,24 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10 ...@@ -178,18 +182,24 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
prd = prd.unsqueeze(0) prd = prd.unsqueeze(0)
uspec = ic.clone() uspec = ic.clone()
# add IC to power spectrum series
prd_coeffs = [dataset.sht(prd[0, plot_channel]).detach().cpu().clone()]
ref_coeffs = [prd_coeffs[0].clone()]
# ML model # ML model
start_time = time.time() start_time = time.time()
for i in range(1, autoreg_steps + 1): for i in range(1, autoreg_steps + 1):
# evaluate the ML model # evaluate the ML model
prd = model(prd) prd = model(prd)
prd_coeffs.append(dataset.sht(prd[0, plot_channel]).detach().cpu().clone())
if iic == nics - 1 and nskip > 0 and i % nskip == 0: if iic == nics - 1 and nskip > 0 and i % nskip == 0:
# do plotting # do plotting
fig = plt.figure(figsize=(7.5, 6)) fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4) dataset.solver.plot_griddata(prd[0, plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(path_root + "_pred_" + str(i // nskip) + ".png") plt.savefig(os.path.join(path_root,'pred_'+str(i//nskip)+'.png'))
plt.close() plt.close()
fno_times[iic] = time.time() - start_time fno_times[iic] = time.time() - start_time
...@@ -201,21 +211,20 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10 ...@@ -201,21 +211,20 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
# advance classical model # advance classical model
uspec = dataset.solver.timestep(uspec, nsteps) uspec = dataset.solver.timestep(uspec, nsteps)
ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var) ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref_coeffs.append(dataset.sht(ref[plot_channel]).detach().cpu().clone())
if iic == nics - 1 and i % nskip == 0 and nskip > 0: if iic == nics - 1 and i % nskip == 0 and nskip > 0:
fig = plt.figure(figsize=(7.5, 6)) fig = plt.figure(figsize=(7.5, 6))
dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4) dataset.solver.plot_griddata(ref[plot_channel], fig, vmax=4, vmin=-4, projection="robinson")
plt.savefig(path_root + "_truth_" + str(i // nskip) + ".png") plt.savefig(os.path.join(path_root,'truth_'+str(i//nskip)+'.png'))
plt.close() plt.close()
nwp_times[iic] = time.time() - start_time nwp_times[iic] = time.time() - start_time
# compute power spectrum and add it to the buffers # compute power spectrum and add it to the buffers
prd_coeffs = dataset.solver.sht(prd[0, plot_channel]) prd_mean_coeffs.append(torch.stack(prd_coeffs, 0))
ref_coeffs = dataset.solver.sht(ref[plot_channel]) ref_mean_coeffs.append(torch.stack(ref_coeffs, 0))
prd_mean_coeffs.append(prd_coeffs)
ref_mean_coeffs.append(ref_coeffs)
# ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var) # ref = (dataset.solver.spec2grid(uspec) - inp_mean) / torch.sqrt(inp_var)
ref = dataset.solver.spec2grid(uspec) ref = dataset.solver.spec2grid(uspec)
...@@ -223,22 +232,30 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10 ...@@ -223,22 +232,30 @@ def autoregressive_inference(model, dataset, path_root, nsteps, autoreg_steps=10
losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item() losses[iic] = l2loss_sphere(dataset.solver, prd, ref, relative=True).item()
# compute the averaged powerspectra of prediction and reference # compute the averaged powerspectra of prediction and reference
prd_mean_coeffs = torch.stack(prd_mean_coeffs).abs().pow(2).mean(dim=0) with torch.no_grad():
ref_mean_coeffs = torch.stack(ref_mean_coeffs).abs().pow(2).mean(dim=0) prd_mean_coeffs = torch.stack(prd_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)
prd_mean_coeffs[..., 1:] *= 2.0 ref_mean_coeffs = torch.stack(ref_mean_coeffs, dim=0).abs().pow(2).mean(dim=0)
ref_mean_coeffs[..., 1:] *= 2.0
prd_mean_ps = prd_mean_coeffs.sum(dim=-1).detach().cpu() prd_mean_coeffs[..., 1:] *= 2.0
ref_mean_ps = ref_mean_coeffs.sum(dim=-1).detach().cpu() ref_mean_coeffs[..., 1:] *= 2.0
prd_mean_ps = prd_mean_coeffs.sum(dim=-1).contiguous()
ref_mean_ps = ref_mean_coeffs.sum(dim=-1).contiguous()
# split the stuff
prd_mean_ps = [x.squeeze() for x in list(torch.split(prd_mean_ps, 1, dim=0))]
ref_mean_ps = [x.squeeze() for x in list(torch.split(ref_mean_ps, 1, dim=0))]
# compute the averaged powerspectrum # compute the averaged powerspectrum
fig = plt.figure(figsize=(7.5, 6)) for step, (pps, rps) in enumerate(zip(prd_mean_ps, ref_mean_ps)):
plt.loglog(prd_mean_ps, label="prediction") fig = plt.figure(figsize=(7.5, 6))
plt.loglog(ref_mean_ps, label="reference") plt.semilogy(pps, label="prediction")
plt.xlabel("$l$") plt.semilogy(rps, label="reference")
plt.ylabel("powerspectrum") plt.xlabel("$l$")
plt.legend() plt.ylabel("powerspectrum")
plt.savefig(path_root + "_powerspectrum.png") plt.legend()
plt.close() plt.savefig(os.path.join(path_root,f'powerspectrum_{step}.png'))
fig.clf()
plt.close()
return losses, fno_times, nwp_times return losses, fno_times, nwp_times
...@@ -364,6 +381,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -364,6 +381,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
torch.manual_seed(333) torch.manual_seed(333)
torch.cuda.manual_seed(333) torch.cuda.manual_seed(333)
# set parameters
nfuture=0
# set device # set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -373,7 +393,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -373,7 +393,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt = 1 * 3600 dt = 1 * 3600
dt_solver = 150 dt_solver = 150
nsteps = dt // dt_solver nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, normalize=True) dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, grid="legendre-gauss", normalize=True)
dataset.sht = RealSHT(nlat=257, nlon=512, grid= "equiangular").to(device=device)
# There is still an issue with parallel dataloading. Do NOT use it at the moment # There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True) # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False) dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
...@@ -391,38 +412,54 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -391,38 +412,54 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
from torch_harmonics.examples.sfno import LocalSphericalNeuralOperatorNet as LSNO from torch_harmonics.examples.sfno import LocalSphericalNeuralOperatorNet as LSNO
# models["sfno_sc2_layers6_e32"] = partial( # models[f"sfno_sc2_layers4_e32_nomlp"] = partial(
# SFNO, # SFNO,
# spectral_transform="sht",
# img_size=(nlat, nlon), # img_size=(nlat, nlon),
# grid="equiangular", # grid="equiangular",
# num_layers=6, # # hard_thresholding_fraction=0.8,
# scale_factor=1, # num_layers=4,
# scale_factor=2,
# embed_dim=32, # embed_dim=32,
# operator_type="driscoll-healy", # operator_type="driscoll-healy",
# activation_function="gelu", # activation_function="gelu",
# big_skip=True, # big_skip=True,
# pos_embed=False, # pos_embed=False,
# use_mlp=True, # use_mlp=False,
# normalization_layer="none", # normalization_layer="none",
# ) # )
models["lsno_sc2_layers6_e32"] = partial( models[f"sfno_sc2_layers4_e32_nomlp_leggauss"] = partial(
LSNO, SFNO,
spectral_transform="sht",
img_size=(nlat, nlon), img_size=(nlat, nlon),
grid="equiangular", grid="legendre-gauss",
num_layers=6, # hard_thresholding_fraction=0.8,
scale_factor=1, num_layers=4,
scale_factor=2,
embed_dim=32, embed_dim=32,
operator_type="driscoll-healy", operator_type="driscoll-healy",
activation_function="gelu", activation_function="gelu",
big_skip=True, big_skip=False,
pos_embed=False, pos_embed=False,
use_mlp=True, use_mlp=False,
normalization_layer="none", normalization_layer="none",
) )
# models[f"lsno_sc1_layers4_e32_nomlp"] = partial(
# LSNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid="equiangular",
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=False,
# normalization_layer="none",
# )
# iterate over models and train each model # iterate over models and train each model
root_path = os.path.dirname(__file__) root_path = os.path.dirname(__file__)
for model_name, model_handle in models.items(): for model_name, model_handle in models.items():
...@@ -437,8 +474,12 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -437,8 +474,12 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
print(f"number of trainable params: {num_params}") print(f"number of trainable params: {num_params}")
metrics[model_name]["num_params"] = num_params metrics[model_name]["num_params"] = num_params
exp_dir = os.path.join(root_path, 'checkpoints', model_name)
if not os.path.isdir(exp_dir):
os.makedirs(exp_dir, exist_ok=True)
if load_checkpoint: if load_checkpoint:
model.load_state_dict(torch.load(os.path.join(root_path, "checkpoints/" + model_name), weights_only=True)) model.load_state_dict(torch.load(os.path.join(exp_dir, "checkpoint.pt")))
# run the training # run the training
if train: if train:
...@@ -454,27 +495,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -454,27 +495,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
print(f"Training {model_name}, single step") print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads) train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
# # multistep training if nfuture > 0:
# print(f'Training {model_name}, two step') print(f'Training {model_name}, {nfuture} step')
# optimizer = torch.optim.Adam(model.parameters(), lr=5E-5) optimizer = torch.optim.Adam(model.parameters(), lr=5E-5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# gscaler = torch.GradScaler(enabled=enable_amp) gscaler = amp.GradScaler(enabled=enable_amp)
# dataloader.dataset.nsteps = 2 * dt//dt_solver dataloader.dataset.nsteps = 2 * dt//dt_solver
# train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=5, nfuture=1, enable_amp=enable_amp) train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
# dataloader.dataset.nsteps = 1 * dt//dt_solver dataloader.dataset.nsteps = 1 * dt//dt_solver
training_time = time.time() - start_time training_time = time.time() - start_time
run.finish() run.finish()
torch.save(model.state_dict(), os.path.join(root_path, "checkpoints/" + model_name)) torch.save(model.state_dict(), os.path.join(exp_dir, 'checkpoint.pt'))
# set seed # set seed
torch.manual_seed(333) torch.manual_seed(333)
torch.cuda.manual_seed(333) torch.cuda.manual_seed(333)
with torch.inference_mode(): with torch.inference_mode():
losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(root_path, "figures/" + model_name), nsteps=nsteps, autoreg_steps=30) losses, fno_times, nwp_times = autoregressive_inference(model, dataset, os.path.join(exp_dir,'figures'), nsteps=nsteps, autoreg_steps=30, nics=50)
metrics[model_name]["loss_mean"] = np.mean(losses) metrics[model_name]["loss_mean"] = np.mean(losses)
metrics[model_name]["loss_std"] = np.std(losses) metrics[model_name]["loss_std"] = np.std(losses)
metrics[model_name]["fno_time_mean"] = np.mean(fno_times) metrics[model_name]["fno_time_mean"] = np.mean(fno_times)
...@@ -485,7 +526,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -485,7 +526,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
metrics[model_name]["training_time"] = training_time metrics[model_name]["training_time"] = training_time
df = pd.DataFrame(metrics) df = pd.DataFrame(metrics)
df.to_pickle(os.path.join(root_path, "output_data/metrics.pkl")) if not os.path.isdir(os.path.join(exp_dir, 'output_data',)):
os.makedirs(os.path.join(exp_dir, 'output_data'), exist_ok=True)
df.to_pickle(os.path.join(exp_dir, 'output_data', 'metrics.pkl'))
if __name__ == "__main__": if __name__ == "__main__":
......
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# #
__version__ = "0.7.3" __version__ = "0.7.3a"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
"""
Contains complex contractions wrapped into jit for harmonic layers
"""
@torch.jit.script
def contract_diagonal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Diagonal spectral contraction: per-mode complex multiply of the input
    with the weight, summed over input channels. Inputs are real views of
    complex tensors (trailing dimension of size 2 holds re/im)."""
    result = torch.einsum("bixy,kixy->bkxy", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(result)
@torch.jit.script
def contract_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Driscoll-Healy convolution contraction: the weight depends only on the
    latitudinal mode (no 'y' index), broadcast over longitudinal modes.
    Inputs are real views of complex tensors (trailing dim 2 = re/im)."""
    result = torch.einsum("bixy,kix->bkxy", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(result)
@torch.jit.script
def contract_blockdiag(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Block-diagonal spectral contraction: mixes longitudinal modes (y -> z)
    as well as channels. Inputs are real views of complex tensors."""
    result = torch.einsum("bixy,kixyz->bkxz", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(result)
# Helper routines for the non-linear FNOs (Attention-like)
@torch.jit.script
def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 1d channel mixing on real-view tensors: the trailing dim of a
    (index s) and of b (index r) hold the re/im parts, combined explicitly."""
    prod = torch.einsum("bixs,ior->srbox", a, b)
    # (re*re - im*im, im*re + re*im) = complex product in real arithmetic
    re = prod[0, 0, ...] - prod[1, 1, ...]
    im = prod[1, 0, ...] + prod[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_mul1d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 1d channel mixing using native complex arithmetic; inputs and
    output are real views (trailing dim 2 = re/im)."""
    out = torch.einsum("bix,io->box", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(out)
@torch.jit.script
def compl_muladd1d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex 1d multiply-add: compl_mul1d_fwd(a, b) plus bias/skip c."""
    return compl_mul1d_fwd(a, b) + c
@torch.jit.script
def compl_muladd1d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex 1d multiply-add in native complex arithmetic; c is a
    real-view complex tensor added after the channel mixing."""
    prod = torch.view_as_complex(compl_mul1d_fwd_c(a, b))
    return torch.view_as_real(prod + torch.view_as_complex(c))
# Helper routines for FFT MLPs
@torch.jit.script
def compl_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 2d channel mixing on real-view tensors: trailing dims s (input)
    and r (weight) hold re/im parts, combined explicitly."""
    prod = torch.einsum("bixys,ior->srboxy", a, b)
    # complex product assembled in real arithmetic
    re = prod[0, 0, ...] - prod[1, 1, ...]
    im = prod[1, 0, ...] + prod[0, 1, ...]
    return torch.stack([re, im], dim=-1)
@torch.jit.script
def compl_mul2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex 2d channel mixing using native complex arithmetic; inputs and
    output are real views (trailing dim 2 = re/im)."""
    out = torch.einsum("bixy,io->boxy", torch.view_as_complex(a), torch.view_as_complex(b))
    return torch.view_as_real(out)
@torch.jit.script
def compl_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex 2d multiply-add: compl_mul2d_fwd(a, b) plus bias/skip c."""
    return compl_mul2d_fwd(a, b) + c
@torch.jit.script
def compl_muladd2d_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused complex 2d multiply-add in native complex arithmetic; c is a
    real-view complex tensor added after the channel mixing."""
    prod = torch.view_as_complex(compl_mul2d_fwd_c(a, b))
    return torch.view_as_real(prod + torch.view_as_complex(c))
@torch.jit.script
def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Real-valued 2d channel mixing: contract input channels of a against
    the (in, out) weight matrix b."""
    return torch.einsum("bixy,io->boxy", a, b)
@torch.jit.script
def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Fused real-valued 2d multiply-add: channel-mix a with b, then add c.

    Bugfix: this previously delegated to the complex kernel
    compl_mul2d_fwd_c, which reinterprets the trailing dimension as (re, im)
    pairs via view_as_complex — wrong (and shape-incompatible) for real
    tensors. The contraction is now done directly in real arithmetic,
    matching real_mul2d_fwd.
    """
    return torch.einsum("bixy,io->boxy", a, b) + c
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors.core import FactorizedTensor
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
    """Contract x with a dense weight via an einsum built for operator_type.

    Parameters
    ----------
    x : tensor of shape (batch, in_channels, x, y, ...)
    weight : torch.Tensor or tensorly-torch FactorizedTensor; a factorized
        weight is reconstructed to a dense tensor before contraction.
    separable : bool
        If True the weight carries no output-channel mode and channels are
        mixed elementwise rather than contracted.
    operator_type : {'diagonal', 'block-diagonal', 'driscoll-healy'}
        Structure of the spectral weight.

    Returns
    -------
    tensor of shape (batch, out_channels, x, y, ...)
    """
    order = tl.ndim(x)
    # batch-size, in_channels, x, y...
    x_syms = list(einsum_symbols[:order])

    # in_channels, out_channels, x, y...
    weight_syms = list(x_syms[1:])  # no batch-size

    # batch-size, out_channels, x, y...
    if separable:
        out_syms = [x_syms[0]] + list(weight_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order + 1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'driscoll-healy':
        # weight has no mode for the last (longitudinal) dimension
        weight_syms.pop()
    else:
        # fixed typo in error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = ''.join(x_syms) + ',' + ''.join(weight_syms) + '->' + ''.join(out_syms)

    if not torch.is_tensor(weight):
        weight = weight.to_tensor()

    return tl.einsum(eq, x, weight)
def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
    """Contract x directly with the factors of a CP-factorized weight,
    without reconstructing the dense weight.

    Parameters mirror :func:`_contract_dense`; cp_weight is a tltorch CP
    tensor exposing ``.weights`` (rank weights) and ``.factors``.
    """
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    rank_sym = einsum_symbols[order]
    out_sym = einsum_symbols[order + 1]
    out_syms = list(x_syms)
    if separable:
        factor_syms = [einsum_symbols[1] + rank_sym]  # in only
    else:
        out_syms[1] = out_sym
        factor_syms = [einsum_symbols[1] + rank_sym, out_sym + rank_sym]  # in, out
    factor_syms += [xs + rank_sym for xs in x_syms[2:]]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        out_syms[-1] = einsum_symbols[order + 2]
        factor_syms += [out_syms[-1] + rank_sym]
    elif operator_type == 'driscoll-healy':
        # drop the factor for the last (longitudinal) dimension
        factor_syms.pop()
    else:
        # fixed typo in error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + rank_sym + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors)
def _contract_tucker(x, tucker_weight, separable=False, operator_type='diagonal'):
    """Contract x directly with a Tucker-factorized weight (core + factors),
    without reconstructing the dense weight.

    Parameters mirror :func:`_contract_dense`; tucker_weight is a tltorch
    Tucker tensor exposing ``.core`` and ``.factors``. Only the 'diagonal'
    operator type is supported for Tucker weights.
    """
    order = tl.ndim(x)

    x_syms = str(einsum_symbols[:order])
    out_sym = einsum_symbols[order]
    out_syms = list(x_syms)
    if separable:
        core_syms = einsum_symbols[order + 1:2 * order]
        # factor_syms = [einsum_symbols[1]+core_syms[0]] #in only
        factor_syms = [xs + rs for (xs, rs) in zip(x_syms[1:], core_syms)]  # x, y, ...
    else:
        core_syms = einsum_symbols[order + 1:2 * order + 1]
        out_syms[1] = out_sym
        # NOTE: order is (in, out), matching the core's first two modes
        factor_syms = [einsum_symbols[1] + core_syms[0], out_sym + core_syms[1]]
        factor_syms += [xs + rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])]  # x, y, ...

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker")
    elif operator_type == 'driscoll-healy':
        # explicit branch for clarity: previously this fell through to the
        # generic "unknown operator" error although the type itself is valid
        raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker")
    else:
        # fixed typo in error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    eq = x_syms + ',' + core_syms + ',' + ','.join(factor_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors)
def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
    """Contract x directly with a tensor-train factorized weight, chaining
    the TT cores through shared rank indices instead of reconstructing the
    dense weight.

    Parameters mirror :func:`_contract_dense`; tt_weight is a tltorch TT
    tensor exposing ``.factors``.
    """
    order = tl.ndim(x)

    x_syms = list(einsum_symbols[:order])
    weight_syms = list(x_syms[1:])  # no batch-size
    if not separable:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]
    else:
        out_syms = list(x_syms)

    if operator_type == 'diagonal':
        pass
    elif operator_type == 'block-diagonal':
        weight_syms.insert(-1, einsum_symbols[order + 1])
        out_syms[-1] = weight_syms[-2]
    elif operator_type == 'driscoll-healy':
        # weight has no mode for the last (longitudinal) dimension
        weight_syms.pop()
    else:
        # fixed typo in error message ("Unkonw")
        raise ValueError(f"Unknown operator type {operator_type}")

    # each TT core carries (left-rank, mode, right-rank) indices
    rank_syms = list(einsum_symbols[order + 2:])
    tt_syms = []
    for i, s in enumerate(weight_syms):
        tt_syms.append([rank_syms[i], s, rank_syms[i + 1]])
    eq = ''.join(x_syms) + ',' + ','.join(''.join(f) for f in tt_syms) + '->' + ''.join(out_syms)

    return tl.einsum(eq, x, *tt_weight.factors)
def get_contract_fun(weight, implementation='reconstructed', separable=False):
    """Generic ND implementation of Fourier Spectral Conv contraction

    Parameters
    ----------
    weight : tensorly-torch's FactorizedTensor
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or contract directly the factors of the factorized weight with the input (factorized)

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space
    """
    # reconstructed weights always go through the dense contraction
    if implementation == 'reconstructed':
        return _contract_dense
    if implementation != 'factorized':
        raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"')

    # a plain tensor needs no factorized kernel even in 'factorized' mode
    if torch.is_tensor(weight):
        return _contract_dense
    if not isinstance(weight, FactorizedTensor):
        raise ValueError(f'Got unexpected weight type of class {weight.__class__.__name__}')

    # dispatch on the factorization name
    handlers = {
        'complexdense': _contract_dense,
        'complextucker': _contract_tucker,
        'complextt': _contract_tt,
        'complexcp': _contract_cp,
    }
    contract_fn = handlers.get(weight.name.lower())
    if contract_fn is None:
        raise ValueError(f'Got unexpected factorized weight type {weight.name}')
    return contract_fn
...@@ -36,16 +36,8 @@ from torch.utils.checkpoint import checkpoint ...@@ -36,16 +36,8 @@ from torch.utils.checkpoint import checkpoint
import math import math
from torch_harmonics import * from torch_harmonics import *
from .contractions import *
from .activations import * from .activations import *
# # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl
# from tensorly.plugins import use_opt_einsum
# tl.set_backend("pytorch")
# use_opt_einsum("optimal")
from tltorch.factorized_tensors.core import FactorizedTensor
def _no_grad_trunc_normal_(tensor, mean, std, a, b): def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW # Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
...@@ -237,7 +229,7 @@ class SpectralConvS2(nn.Module): ...@@ -237,7 +229,7 @@ class SpectralConvS2(nn.Module):
operator_type = "driscoll-healy", operator_type = "driscoll-healy",
lr_scale_exponent = 0, lr_scale_exponent = 0,
bias = False): bias = False):
super(SpectralConvS2, self).__init__() super().__init__()
self.forward_transform = forward_transform self.forward_transform = forward_transform
self.inverse_transform = inverse_transform self.inverse_transform = inverse_transform
...@@ -258,123 +250,19 @@ class SpectralConvS2(nn.Module): ...@@ -258,123 +250,19 @@ class SpectralConvS2(nn.Module):
if self.operator_type == "diagonal": if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon] weight_shape += [self.modes_lat, self.modes_lon]
from .contractions import contract_diagonal as _contract self.contract_func = "...ilm,oilm->...olm"
elif self.operator_type == "block-diagonal": elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon] weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
from .contractions import contract_blockdiag as _contract self.contract_func = "...ilm,oilnm->...oln"
elif self.operator_type == "driscoll-healy": elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat] weight_shape += [self.modes_lat]
from .contractions import contract_dhconv as _contract self.contract_func = "...ilm,oil->...olm"
else: else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}") raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors # form weight tensors
scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat, 2)
scale[0] *= math.sqrt(2)
self.weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))
# get the right contraction function
self._contract = _contract
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, x):
dtype = x.dtype
x = x.float()
residual = x
with torch.autocast(device_type="cuda", enabled=False):
x = self.forward_transform(x)
if self.scale_residual:
residual = self.inverse_transform(x)
x = torch.view_as_real(x)
x = self._contract(x, self.weight)
x = torch.view_as_complex(x)
with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x)
if hasattr(self, "bias"):
x = x + self.bias
x = x.type(dtype)
return x, residual
class FactorizedSpectralConvS2(nn.Module):
"""
Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
"""
def __init__(self,
forward_transform,
inverse_transform,
in_channels,
out_channels,
gain = 2.,
operator_type = "driscoll-healy",
rank = 0.2,
factorization = None,
separable = False,
implementation = "factorized",
decomposition_kwargs=dict(),
bias = False):
super(SpectralConvS2, self).__init__()
self.forward_transform = forward_transform
self.inverse_transform = inverse_transform
self.modes_lat = self.inverse_transform.lmax
self.modes_lon = self.inverse_transform.mmax
self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
or (self.forward_transform.nlon != self.inverse_transform.nlon)
# Make sure we are using a Complex Factorized Tensor
if factorization is None:
factorization = "Dense" # No factorization
if not factorization.lower().startswith("complex"):
factorization = f"Complex{factorization}"
# remember factorization details
self.operator_type = operator_type
self.rank = rank
self.factorization = factorization
self.separable = separable
assert self.inverse_transform.lmax == self.modes_lat
assert self.inverse_transform.mmax == self.modes_lon
weight_shape = [out_channels]
if not self.separable:
weight_shape += [in_channels]
if self.operator_type == "diagonal":
weight_shape += [self.modes_lat, self.modes_lon]
elif self.operator_type == "block-diagonal":
weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
elif self.operator_type == "driscoll-healy":
weight_shape += [self.modes_lat]
else:
raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
# form weight tensors
self.weight = FactorizedTensor.new(weight_shape, rank=self.rank, factorization=factorization,
fixed_rank_modes=False, **decomposition_kwargs)
# initialization of weights
scale = math.sqrt(gain / in_channels) scale = math.sqrt(gain / in_channels)
self.weight.normal_(0, scale) self.weight = nn.Parameter(scale * torch.randn(*weight_shape, dtype=torch.complex64))
# get the right contraction function
from .factorizations import get_contract_fun
self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)
if bias: if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
...@@ -390,7 +278,7 @@ class FactorizedSpectralConvS2(nn.Module): ...@@ -390,7 +278,7 @@ class FactorizedSpectralConvS2(nn.Module):
if self.scale_residual: if self.scale_residual:
residual = self.inverse_transform(x) residual = self.inverse_transform(x)
x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) x = torch.einsum(self.contract_func, x, self.weight)
with torch.autocast(device_type="cuda", enabled=False): with torch.autocast(device_type="cuda", enabled=False):
x = self.inverse_transform(x) x = self.inverse_transform(x)
...@@ -399,4 +287,4 @@ class FactorizedSpectralConvS2(nn.Module): ...@@ -399,4 +287,4 @@ class FactorizedSpectralConvS2(nn.Module):
x = x + self.bias x = x + self.bias
x = x.type(dtype) x = x.type(dtype)
return x, residual return x, residual
\ No newline at end of file
...@@ -51,6 +51,7 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -51,6 +51,7 @@ class DiscreteContinuousEncoder(nn.Module):
inp_chans=2, inp_chans=2,
out_chans=2, out_chans=2,
kernel_shape=[3, 4], kernel_shape=[3, 4],
basis_type="piecewise linear",
groups=1, groups=1,
bias=False, bias=False,
): ):
...@@ -63,11 +64,12 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -63,11 +64,12 @@ class DiscreteContinuousEncoder(nn.Module):
in_shape=inp_shape, in_shape=inp_shape,
out_shape=out_shape, out_shape=out_shape,
kernel_shape=kernel_shape, kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_in, grid_in=grid_in,
grid_out=grid_out, grid_out=grid_out,
groups=groups, groups=groups,
bias=bias, bias=bias,
theta_cutoff=math.sqrt(2) * torch.pi / float(out_shape[0] - 1), theta_cutoff=4*math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
) )
def forward(self, x): def forward(self, x):
...@@ -91,6 +93,7 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -91,6 +93,7 @@ class DiscreteContinuousDecoder(nn.Module):
inp_chans=2, inp_chans=2,
out_chans=2, out_chans=2,
kernel_shape=[3, 4], kernel_shape=[3, 4],
basis_type="piecewise linear",
groups=1, groups=1,
bias=False, bias=False,
): ):
...@@ -107,11 +110,12 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -107,11 +110,12 @@ class DiscreteContinuousDecoder(nn.Module):
in_shape=out_shape, in_shape=out_shape,
out_shape=out_shape, out_shape=out_shape,
kernel_shape=kernel_shape, kernel_shape=kernel_shape,
basis_type=basis_type,
grid_in=grid_out, grid_in=grid_out,
grid_out=grid_out, grid_out=grid_out,
groups=groups, groups=groups,
bias=False, bias=False,
theta_cutoff=math.sqrt(2) * torch.pi / float(inp_shape[0] - 1), theta_cutoff=4*math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
) )
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False) # self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
...@@ -131,58 +135,6 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -131,58 +135,6 @@ class DiscreteContinuousDecoder(nn.Module):
return x return x
class SpectralFilterLayer(nn.Module):
"""
Fourier layer. Contains the convolution part of the FNO/SFNO
"""
def __init__(
self,
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=2.0,
operator_type="diagonal",
hidden_size_factor=2,
factorization=None,
separable=False,
rank=1e-2,
bias=True,
):
super(SpectralFilterLayer, self).__init__()
if factorization is None:
self.filter = SpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain,
operator_type=operator_type,
bias=bias,
)
elif factorization is not None:
self.filter = FactorizedSpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain,
operator_type=operator_type,
rank=rank,
factorization=factorization,
separable=separable,
bias=bias,
)
else:
raise (NotImplementedError)
def forward(self, x):
return self.filter(x)
class SphericalNeuralOperatorBlock(nn.Module): class SphericalNeuralOperatorBlock(nn.Module):
""" """
...@@ -202,13 +154,11 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -202,13 +154,11 @@ class SphericalNeuralOperatorBlock(nn.Module):
drop_path=0.0, drop_path=0.0,
act_layer=nn.ReLU, act_layer=nn.ReLU,
norm_layer=nn.Identity, norm_layer=nn.Identity,
factorization=None,
separable=False,
rank=128,
inner_skip="None", inner_skip="None",
outer_skip="linear", outer_skip="linear",
use_mlp=True, use_mlp=True,
disco_kernel_shape=[2, 4], disco_kernel_shape=[2, 4],
disco_basis_type="piecewise linear",
): ):
super().__init__() super().__init__()
...@@ -228,25 +178,14 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -228,25 +178,14 @@ class SphericalNeuralOperatorBlock(nn.Module):
in_shape=(forward_transform.nlat, forward_transform.nlon), in_shape=(forward_transform.nlat, forward_transform.nlon),
out_shape=(inverse_transform.nlat, inverse_transform.nlon), out_shape=(inverse_transform.nlat, inverse_transform.nlon),
kernel_shape=disco_kernel_shape, kernel_shape=disco_kernel_shape,
basis_type=disco_basis_type,
grid_in=forward_transform.grid, grid_in=forward_transform.grid,
grid_out=inverse_transform.grid, grid_out=inverse_transform.grid,
bias=False, bias=False,
theta_cutoff=(disco_kernel_shape[0] + 1) * torch.pi / float(forward_transform.nlat - 1) / math.sqrt(2), theta_cutoff=4*math.sqrt(2) * torch.pi / float(inverse_transform.nlat - 1),
) )
elif conv_type == "global": elif conv_type == "global":
self.global_conv = SpectralFilterLayer( self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain_factor,
operator_type=operator_type,
hidden_size_factor=mlp_ratio,
factorization=factorization,
separable=separable,
rank=rank,
bias=False,
)
else: else:
raise ValueError(f"Unknown convolution type {conv_type}") raise ValueError(f"Unknown convolution type {conv_type}")
...@@ -261,8 +200,6 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -261,8 +200,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
else: else:
raise ValueError(f"Unknown skip connection type {inner_skip}") raise ValueError(f"Unknown skip connection type {inner_skip}")
self.act_layer = act_layer()
# first normalisation layer # first normalisation layer
self.norm0 = norm_layer() self.norm0 = norm_layer()
...@@ -313,9 +250,6 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -313,9 +250,6 @@ class SphericalNeuralOperatorBlock(nn.Module):
if hasattr(self, "inner_skip"): if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual) x = x + self.inner_skip(residual)
if hasattr(self, "act_layer"):
x = self.act_layer(x)
if hasattr(self, "mlp"): if hasattr(self, "mlp"):
x = self.mlp(x) x = self.mlp(x)
...@@ -331,11 +265,13 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -331,11 +265,13 @@ class SphericalNeuralOperatorBlock(nn.Module):
class LocalSphericalNeuralOperatorNet(nn.Module): class LocalSphericalNeuralOperatorNet(nn.Module):
""" """
SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO, LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
both linear and non-linear variants. operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
as well as in the encoder and decoders.
Parameters Parameters
---------- -----------
spectral_transform : str, optional spectral_transform : str, optional
Type of spectral transformation to use, by default "sht" Type of spectral transformation to use, by default "sht"
operator_type : str, optional operator_type : str, optional
...@@ -373,17 +309,11 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -373,17 +309,11 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True Whether to add a single large skip connection, by default True
rank : float, optional rank : float, optional
Rank of the approximation, by default 1.0 Rank of the approximation, by default 1.0
factorization : Any, optional
Type of factorization to use, by default None
separable : bool, optional
Whether to use separable convolutions, by default False
rank : (int, Tuple[int]), optional
If a factorization is used, which rank to use. Argument is passed to tensorly
pos_embed : bool, optional pos_embed : bool, optional
Whether to use positional embedding, by default True Whether to use positional embedding, by default True
Example: Example
-------- -----------
>>> model = SphericalFourierNeuralOperatorNet( >>> model = SphericalFourierNeuralOperatorNet(
... img_shape=(128, 256), ... img_shape=(128, 256),
... scale_factor=4, ... scale_factor=4,
...@@ -394,6 +324,17 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -394,6 +324,17 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
... use_mlp=True,) ... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape >>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256]) torch.Size([1, 2, 128, 256])
References
-----------
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845.
.. [2] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
""" """
def __init__( def __init__(
...@@ -402,6 +343,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -402,6 +343,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
operator_type="driscoll-healy", operator_type="driscoll-healy",
img_size=(128, 256), img_size=(128, 256),
grid="equiangular", grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=4, scale_factor=4,
in_chans=3, in_chans=3,
out_chans=3, out_chans=3,
...@@ -410,6 +352,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -410,6 +352,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
activation_function="relu", activation_function="relu",
kernel_shape=[3, 4], kernel_shape=[3, 4],
encoder_kernel_shape=[3, 4], encoder_kernel_shape=[3, 4],
disco_basis_type="piecewise linear",
use_mlp=True, use_mlp=True,
mlp_ratio=2.0, mlp_ratio=2.0,
drop_rate=0.0, drop_rate=0.0,
...@@ -418,9 +361,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -418,9 +361,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
hard_thresholding_fraction=1.0, hard_thresholding_fraction=1.0,
use_complex_kernels=True, use_complex_kernels=True,
big_skip=False, big_skip=False,
factorization=None,
separable=False,
rank=128,
pos_embed=False, pos_embed=False,
): ):
super().__init__() super().__init__()
...@@ -429,6 +369,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -429,6 +369,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.operator_type = operator_type self.operator_type = operator_type
self.img_size = img_size self.img_size = img_size
self.grid = grid self.grid = grid
self.grid_internal = grid_internal
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.in_chans = in_chans self.in_chans = in_chans
self.out_chans = out_chans self.out_chans = out_chans
...@@ -439,9 +380,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -439,9 +380,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.normalization_layer = normalization_layer self.normalization_layer = normalization_layer
self.use_mlp = use_mlp self.use_mlp = use_mlp
self.big_skip = big_skip self.big_skip = big_skip
self.factorization = factorization
self.separable = (separable,)
self.rank = rank
# activation function # activation function
if activation_function == "relu": if activation_function == "relu":
...@@ -455,7 +393,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -455,7 +393,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
raise ValueError(f"Unknown activation function {activation_function}") raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles # compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor self.w = self.img_size[1] // scale_factor
# dropout # dropout
...@@ -494,9 +432,10 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -494,9 +432,10 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.img_size, self.img_size,
(self.h, self.w), (self.h, self.w),
self.encoder_kernel_shape, self.encoder_kernel_shape,
basis_type=disco_basis_type,
groups=1, groups=1,
grid_in=grid, grid_in=grid,
grid_out="legendre-gauss", grid_out=grid_internal,
bias=False, bias=False,
theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1), theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
) )
...@@ -506,7 +445,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -506,7 +445,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
# inp_shape=self.img_size, # inp_shape=self.img_size,
# out_shape=(self.h, self.w), # out_shape=(self.h, self.w),
# grid_in=grid, # grid_in=grid,
# grid_out="legendre-gauss", # grid_out=grid_internal,
# inp_chans=self.in_chans, # inp_chans=self.in_chans,
# out_chans=self.embed_dim, # out_chans=self.embed_dim,
# kernel_shape=self.encoder_kernel_shape, # kernel_shape=self.encoder_kernel_shape,
...@@ -520,8 +459,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -520,8 +459,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction) modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon) modes_lat = modes_lon = min(modes_lat, modes_lon)
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float() self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float() self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
else: else:
raise (ValueError("Unknown spectral transform")) raise (ValueError("Unknown spectral transform"))
...@@ -556,10 +495,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -556,10 +495,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
inner_skip=inner_skip, inner_skip=inner_skip,
outer_skip=outer_skip, outer_skip=outer_skip,
use_mlp=use_mlp, use_mlp=use_mlp,
factorization=self.factorization,
separable=self.separable,
rank=self.rank,
disco_kernel_shape=kernel_shape, disco_kernel_shape=kernel_shape,
disco_basis_type=disco_basis_type,
) )
self.blocks.append(block) self.blocks.append(block)
...@@ -582,11 +519,12 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -582,11 +519,12 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.decoder = DiscreteContinuousDecoder( self.decoder = DiscreteContinuousDecoder(
inp_shape=(self.h, self.w), inp_shape=(self.h, self.w),
out_shape=self.img_size, out_shape=self.img_size,
grid_in="legendre-gauss", grid_in=grid_internal,
grid_out=grid, grid_out=grid,
inp_chans=self.embed_dim, inp_chans=self.embed_dim,
out_chans=self.out_chans, out_chans=self.out_chans,
kernel_shape=self.encoder_kernel_shape, kernel_shape=self.encoder_kernel_shape,
basis_type=disco_basis_type,
groups=1, groups=1,
bias=False, bias=False,
) )
......
...@@ -39,51 +39,6 @@ from .layers import * ...@@ -39,51 +39,6 @@ from .layers import *
from functools import partial from functools import partial
class SpectralFilterLayer(nn.Module):
"""
Fourier layer. Contains the convolution part of the FNO/SFNO
"""
def __init__(
self,
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=2.0,
operator_type="diagonal",
hidden_size_factor=2,
factorization=None,
separable=False,
rank=1e-2,
bias=True,
):
super(SpectralFilterLayer, self).__init__()
if factorization is None:
self.filter = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain, operator_type=operator_type, bias=bias)
elif factorization is not None:
self.filter = FactorizedSpectralConvS2(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain,
operator_type=operator_type,
rank=rank,
factorization=factorization,
separable=separable,
bias=bias,
)
else:
raise (NotImplementedError)
def forward(self, x):
return self.filter(x)
class SphericalFourierNeuralOperatorBlock(nn.Module): class SphericalFourierNeuralOperatorBlock(nn.Module):
""" """
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks. Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
...@@ -108,7 +63,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -108,7 +63,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
outer_skip=None, outer_skip=None,
use_mlp=True, use_mlp=True,
): ):
super(SphericalFourierNeuralOperatorBlock, self).__init__() super().__init__()
if act_layer == nn.Identity: if act_layer == nn.Identity:
gain_factor = 1.0 gain_factor = 1.0
...@@ -118,20 +73,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -118,20 +73,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
if inner_skip == "linear" or inner_skip == "identity": if inner_skip == "linear" or inner_skip == "identity":
gain_factor /= 2.0 gain_factor /= 2.0
# convolution layer self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
self.filter = SpectralFilterLayer(
forward_transform,
inverse_transform,
input_dim,
output_dim,
gain=gain_factor,
operator_type=operator_type,
hidden_size_factor=mlp_ratio,
factorization=factorization,
separable=separable,
rank=rank,
bias=True,
)
if inner_skip == "linear": if inner_skip == "linear":
self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1) self.inner_skip = nn.Conv2d(input_dim, output_dim, 1, 1)
...@@ -144,8 +86,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -144,8 +86,6 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else: else:
raise ValueError(f"Unknown skip connection type {inner_skip}") raise ValueError(f"Unknown skip connection type {inner_skip}")
self.act_layer = act_layer()
# first normalisation layer # first normalisation layer
self.norm0 = norm_layer() self.norm0 = norm_layer()
...@@ -178,16 +118,13 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -178,16 +118,13 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
def forward(self, x): def forward(self, x):
x, residual = self.filter(x) x, residual = self.global_conv(x)
x = self.norm0(x) x = self.norm0(x)
if hasattr(self, "inner_skip"): if hasattr(self, "inner_skip"):
x = x + self.inner_skip(residual) x = x + self.inner_skip(residual)
if hasattr(self, "act_layer"):
x = self.act_layer(x)
if hasattr(self, "mlp"): if hasattr(self, "mlp"):
x = self.mlp(x) x = self.mlp(x)
...@@ -203,13 +140,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module): ...@@ -203,13 +140,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
class SphericalFourierNeuralOperatorNet(nn.Module): class SphericalFourierNeuralOperatorNet(nn.Module):
""" """
SphericalFourierNeuralOperator module. Can use both FFTs and SHTs to represent either FNO or SFNO, SphericalFourierNeuralOperator module. Implements the 'linear' variant of the Spherical Fourier Neural Operator
both linear and non-linear variants. as presented in [1]. Spherical convolutions are applied via spectral transforms to apply a geometrically consistent
and approximately equivariant architecture.
Parameters Parameters
---------- ----------
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy" Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
img_shape : tuple, optional img_shape : tuple, optional
...@@ -244,12 +180,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -244,12 +180,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True Whether to add a single large skip connection, by default True
rank : float, optional rank : float, optional
Rank of the approximation, by default 1.0 Rank of the approximation, by default 1.0
factorization : Any, optional
Type of factorization to use, by default None
separable : bool, optional
Whether to use separable convolutions, by default False
rank : (int, Tuple[int]), optional
If a factorization is used, which rank to use. Argument is passed to tensorly
pos_embed : bool, optional pos_embed : bool, optional
Whether to use positional embedding, by default True Whether to use positional embedding, by default True
...@@ -265,14 +195,20 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -265,14 +195,20 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
... use_mlp=True,) ... use_mlp=True,)
>>> model(torch.randn(1, 2, 128, 256)).shape >>> model(torch.randn(1, 2, 128, 256)).shape
torch.Size([1, 2, 128, 256]) torch.Size([1, 2, 128, 256])
References
-----------
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
""" """
def __init__( def __init__(
self, self,
spectral_transform="sht",
operator_type="driscoll-healy", operator_type="driscoll-healy",
img_size=(128, 256), img_size=(128, 256),
grid="equiangular", grid="equiangular",
grid_internal="legendre-gauss",
scale_factor=3, scale_factor=3,
in_chans=3, in_chans=3,
out_chans=3, out_chans=3,
...@@ -288,18 +224,15 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -288,18 +224,15 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
hard_thresholding_fraction=1.0, hard_thresholding_fraction=1.0,
use_complex_kernels=True, use_complex_kernels=True,
big_skip=False, big_skip=False,
factorization=None,
separable=False,
rank=128,
pos_embed=False, pos_embed=False,
): ):
super(SphericalFourierNeuralOperatorNet, self).__init__() super().__init__()
self.spectral_transform = spectral_transform
self.operator_type = operator_type self.operator_type = operator_type
self.img_size = img_size self.img_size = img_size
self.grid = grid self.grid = grid
self.grid_internal = grid_internal
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.in_chans = in_chans self.in_chans = in_chans
self.out_chans = out_chans self.out_chans = out_chans
...@@ -310,9 +243,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -310,9 +243,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
self.use_mlp = use_mlp self.use_mlp = use_mlp
self.encoder_layers = encoder_layers self.encoder_layers = encoder_layers
self.big_skip = big_skip self.big_skip = big_skip
self.factorization = factorization
self.separable = (separable,)
self.rank = rank
# activation function # activation function
if activation_function == "relu": if activation_function == "relu":
...@@ -326,7 +256,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -326,7 +256,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
raise ValueError(f"Unknown activation function {activation_function}") raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size. We assume that the latitude-grid includes both poles # compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor self.h = (self.img_size[0] - 1) // scale_factor + 1
self.w = self.img_size[1] // scale_factor self.w = self.img_size[1] // scale_factor
# dropout # dropout
...@@ -381,30 +311,17 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -381,30 +311,17 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
encoder_layers.append(fc) encoder_layers.append(fc)
self.encoder = nn.Sequential(*encoder_layers) self.encoder = nn.Sequential(*encoder_layers)
# prepare the spectral transform # compute the modes for the sht
if self.spectral_transform == "sht": modes_lat = self.h
# due to some spectral artifacts with cufft, we substract one mode here
modes_lat = int(self.h * self.hard_thresholding_fraction) modes_lon = (self.w // 2 + 1) -1
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon)
self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float() modes_lat = modes_lon = int(min(modes_lat, modes_lon) * self.hard_thresholding_fraction)
self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
elif self.spectral_transform == "fft": self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid=self.grid).float()
modes_lat = int(self.h * self.hard_thresholding_fraction) self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction) self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.trans_down = RealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
else:
raise (ValueError("Unknown spectral transform"))
self.blocks = nn.ModuleList([]) self.blocks = nn.ModuleList([])
for i in range(self.num_layers): for i in range(self.num_layers):
...@@ -439,9 +356,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -439,9 +356,6 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
inner_skip=inner_skip, inner_skip=inner_skip,
outer_skip=outer_skip, outer_skip=outer_skip,
use_mlp=use_mlp, use_mlp=use_mlp,
factorization=self.factorization,
separable=self.separable,
rank=self.rank,
) )
self.blocks.append(block) self.blocks.append(block)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause # SPDX-License-Identifier: BSD-3-Clause
# #
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met: # modification, are permitted provided that the following conditions are met:
# #
...@@ -35,10 +35,23 @@ from math import ceil ...@@ -35,10 +35,23 @@ from math import ceil
from ...shallow_water_equations import ShallowWaterSolver from ...shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset): class PdeDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data""" """Custom Dataset class for PDE training data"""
def __init__(self, dt, nsteps, dims=(384, 768), pde='shallow water equations', initial_condition='random',
num_examples=32, device=torch.device('cpu'), normalize=True, stream=None): def __init__(
self,
dt,
nsteps,
dims=(384, 768),
grid="equiangular",
pde="shallow water equations",
initial_condition="random",
num_examples=32,
device=torch.device("cpu"),
normalize=True,
stream=None,
):
self.num_examples = num_examples self.num_examples = num_examples
self.device = device self.device = device
self.stream = stream self.stream = stream
...@@ -50,11 +63,11 @@ class PdeDataset(torch.utils.data.Dataset): ...@@ -50,11 +63,11 @@ class PdeDataset(torch.utils.data.Dataset):
self.nsteps = nsteps self.nsteps = nsteps
self.normalize = normalize self.normalize = normalize
if pde == 'shallow water equations': if pde == "shallow water equations":
lmax = ceil(self.nlat/3) lmax = ceil(self.nlat / 3)
mmax = lmax mmax = lmax
dt_solver = dt / float(self.nsteps) dt_solver = dt / float(self.nsteps)
self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid='equiangular').to(self.device).float() self.solver = ShallowWaterSolver(self.nlat, self.nlon, dt_solver, lmax=lmax, mmax=mmax, grid=grid).to(self.device).float()
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -66,25 +79,25 @@ class PdeDataset(torch.utils.data.Dataset): ...@@ -66,25 +79,25 @@ class PdeDataset(torch.utils.data.Dataset):
self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1) self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
def __len__(self): def __len__(self):
length = self.num_examples if self.ictype == 'random' else 1 length = self.num_examples if self.ictype == "random" else 1
return length return length
def set_initial_condition(self, ictype='random'): def set_initial_condition(self, ictype="random"):
self.ictype = ictype self.ictype = ictype
def set_num_examples(self, num_examples=32): def set_num_examples(self, num_examples=32):
self.num_examples = num_examples self.num_examples = num_examples
def _get_sample(self): def _get_sample(self):
if self.ictype == 'random': if self.ictype == "random":
inp = self.solver.random_initial_condition(mach=0.2) inp = self.solver.random_initial_condition(mach=0.2)
elif self.ictype == 'galewsky': elif self.ictype == "galewsky":
inp = self.solver.galewsky_initial_condition() inp = self.solver.galewsky_initial_condition()
# solve pde for n steps to return the target # solve pde for n steps to return the target
tar = self.solver.timestep(inp, self.nsteps) tar = self.solver.timestep(inp, self.nsteps)
inp = self.solver.spec2grid(inp) inp = self.solver.spec2grid(inp)
tar = self.solver.spec2grid(tar) tar = self.solver.spec2grid(tar)
return inp, tar return inp, tar
......
...@@ -367,6 +367,21 @@ class ShallowWaterSolver(nn.Module): ...@@ -367,6 +367,21 @@ class ShallowWaterSolver(nn.Module):
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin) im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
plt.title(title, y=1.05) plt.title(title, y=1.05)
elif projection == 'robinson':
import cartopy.crs as ccrs
proj = ccrs.Robinson(central_longitude=0.0)
#ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj)
Lons = Lons*180/np.pi
Lats = Lats*180/np.pi
# contour data over the map.
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
plt.title(title, y=1.05)
else: else:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment