Commit 9de3131b authored by Boris Bonev's avatar Boris Bonev
Browse files

Moved SFNO examples into the library to make them available

parent 45b371b7
...@@ -45,15 +45,13 @@ import pandas as pd ...@@ -45,15 +45,13 @@ import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torch_harmonics.examples.sfno import PdeDataset
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
# wandb logging # wandb logging
import wandb import wandb
wandb.login() wandb.login()
import sys
sys.path.append(os.path.join(os.path.dirname( __file__), "../"))
from pde_sphere import SphereSolver
def l2loss_sphere(solver, prd, tar, relative=False, squared=False): def l2loss_sphere(solver, prd, tar, relative=False, squared=False):
loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1) loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)
if relative: if relative:
...@@ -166,9 +164,6 @@ def main(train=True, load_checkpoint=False, enable_amp=False): ...@@ -166,9 +164,6 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(device.index) torch.cuda.set_device(device.index)
# dataset
from utils.pde_dataset import PdeDataset
# 1 hour prediction steps # 1 hour prediction steps
dt = 1*3600 dt = 1*3600
dt_solver = 150 dt_solver = 150
...@@ -337,8 +332,6 @@ def main(train=True, load_checkpoint=False, enable_amp=False): ...@@ -337,8 +332,6 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
# from models.unet import UNet # from models.unet import UNet
# models['unet_baseline'] = partial(UNet) # models['unet_baseline'] = partial(UNet)
# SFNO and FNO models
from models.sfno import SphericalFourierNeuralOperatorNet as SFNO
# SFNO models # SFNO models
models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon), models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector') num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector')
......
images/sfno.gif

132 Bytes

This image diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -38,12 +38,6 @@ ...@@ -38,12 +38,6 @@
"\n", "\n",
"import time\n", "import time\n",
"\n", "\n",
"import sys\n",
"sys.path.append(\"../../examples\")\n",
"\n",
"# from torch_harmonics.sht import *\n",
"from pde_sphere import SphereSolver\n",
"\n",
"cmap='twilight_shifted'" "cmap='twilight_shifted'"
] ]
}, },
...@@ -153,7 +147,7 @@ ...@@ -153,7 +147,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from models.sfno import SphericalFourierNeuralOperatorNet as SFNO" "from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO"
] ]
}, },
{ {
......
...@@ -34,3 +34,4 @@ __version__ = '0.6.0' ...@@ -34,3 +34,4 @@ __version__ = '0.6.0'
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature from . import quadrature
from . import random_fields from . import random_fields
import examples
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from .pde_sphere import SphereSolver
from .shallow_water_equations import ShallowWaterSolver
\ No newline at end of file
...@@ -40,8 +40,6 @@ import torch_harmonics as harmonics ...@@ -40,8 +40,6 @@ import torch_harmonics as harmonics
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.tri as mtri
try: try:
import cartopy.crs as ccrs import cartopy.crs as ccrs
...@@ -170,7 +168,7 @@ class SphereSolver(nn.Module): ...@@ -170,7 +168,7 @@ class SphereSolver(nn.Module):
if ccrs is None: if ccrs is None:
raise ImportError("Couldn't import Cartopy") raise ImportError("Couldn't import Cartopy")
proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=45.0) proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0)
#ax = plt.gca(projection=proj, frameon=True) #ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj) ax = fig.add_subplot(projection=proj)
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from .models.sfno import SphericalFourierNeuralOperatorNet
from .utils.pde_dataset import PdeDataset
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
\ No newline at end of file
...@@ -47,10 +47,10 @@ from typing import Optional ...@@ -47,10 +47,10 @@ from typing import Optional
import math import math
from torch_harmonics import * from torch_harmonics import *
from models.contractions import * from .contractions import *
from models.activations import * from .activations import *
from models.factorizations import get_contract_fun from .factorizations import get_contract_fun
# # import FactorizedTensor from tensorly for tensorized operations # # import FactorizedTensor from tensorly for tensorized operations
# import tensorly as tl # import tensorly as tl
......
...@@ -35,7 +35,7 @@ from apex.normalization import FusedLayerNorm ...@@ -35,7 +35,7 @@ from apex.normalization import FusedLayerNorm
from torch_harmonics import * from torch_harmonics import *
from models.layers import * from .layers import *
class SpectralFilterLayer(nn.Module): class SpectralFilterLayer(nn.Module):
""" """
......
...@@ -365,7 +365,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -365,7 +365,7 @@ class ShallowWaterSolver(nn.Module):
if ccrs is None: if ccrs is None:
raise ImportError("Couldn't import Cartopy") raise ImportError("Couldn't import Cartopy")
proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=-45.0) proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0)
#ax = plt.gca(projection=proj, frameon=True) #ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj) ax = fig.add_subplot(projection=proj)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment