Commit d81fbd34 authored by Boris Bonev, committed by Boris Bonev

changing default normalization mode in DISCO

parent 96a2b546
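
The substantive change here is that the DISCO convolutions now default to basis_norm_mode="mean" instead of "none", for the regular and transposed variants as well as their distributed counterparts (see the hunks below). As a hedged illustration of what this means for user code, the sketch below constructs a DiscreteContinuousConvS2 once with the new default and once with the argument spelled out; the channel counts, grid sizes, and kernel_shape are illustrative assumptions, while the keyword names and the theta_cutoff formula are taken from the hunks in this commit.

import math
import torch
from torch_harmonics import DiscreteContinuousConvS2

# After this commit the two layers below are configured identically:
# basis_norm_mode defaults to "mean"; pass "none" to restore the old behavior.
conv_default = DiscreteContinuousConvS2(
    4, 4,                                  # in/out channels (illustrative)
    in_shape=(61, 120),                    # illustrative equiangular grids
    out_shape=(31, 60),
    kernel_shape=3,                        # illustrative
    grid_in="equiangular",
    grid_out="equiangular",
    bias=False,
    theta_cutoff=math.sqrt(2.0) * torch.pi / float(31 - 1),
)
conv_explicit = DiscreteContinuousConvS2(
    4, 4,
    in_shape=(61, 120),
    out_shape=(31, 60),
    kernel_shape=3,
    basis_norm_mode="mean",                # new default normalization of the filter basis
    grid_in="equiangular",
    grid_out="equiangular",
    bias=False,
    theta_cutoff=math.sqrt(2.0) * torch.pi / float(31 - 1),
)

x = torch.randn(2, 4, 61, 120)             # (batch, channels, nlat, nlon)
y = conv_default(x)                        # expected shape: (2, 4, 31, 60)
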
......@@ -430,21 +430,20 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer="none",
)
# models[f"lsno_sc2_layers4_e32"] = partial(
# LSNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid=grid,
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=True,
# normalization_layer="none",
# )
models[f"lsno_sc2_layers4_e32"] = partial(
LSNO,
img_size=(nlat, nlon),
grid=grid,
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=False,
pos_embed=False,
use_mlp=True,
normalization_layer="none",
)
# iterate over models and train each model
root_path = os.path.dirname(__file__)
......@@ -487,7 +486,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
gscaler = amp.GradScaler(enabled=enable_amp)
dataloader.dataset.nsteps = 2 * dt//dt_solver
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=10, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
dataloader.dataset.nsteps = 1 * dt//dt_solver
training_time = time.time() - start_time
......
......@@ -176,7 +176,7 @@
"# activation_function = nn.ReLU,\n",
"# bias = False):\n",
"# super().__init__()\n",
" \n",
"\n",
"# current_dim = input_dim\n",
"# layers = []\n",
"# for l in range(num_layers-1):\n",
......@@ -221,7 +221,7 @@
" loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)\n",
" if relative:\n",
" loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)\n",
" \n",
"\n",
" if not squared:\n",
" loss = torch.sqrt(loss)\n",
" loss = loss.mean()\n",
......@@ -515,7 +515,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
......@@ -531,12 +531,7 @@
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
......
......@@ -57,7 +57,7 @@ except ImportError as err:
def _normalize_convolution_tensor_s2(
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9
psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
"""
Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
......@@ -135,7 +135,7 @@ def _precompute_convolution_tensor_s2(
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
basis_norm_mode="none",
basis_norm_mode="mean",
merge_quadrature=False,
):
"""
......@@ -297,7 +297,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "none",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -403,7 +403,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "none",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......
......@@ -76,7 +76,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
transpose_normalization=False,
basis_norm_mode="none",
basis_norm_mode="mean",
merge_quadrature=False,
):
"""
......@@ -208,7 +208,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "none",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......@@ -348,7 +348,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "none",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
......
......@@ -35,6 +35,7 @@ import torch.amp as amp
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics import ResampleS2
from ._layers import *
......@@ -44,7 +45,7 @@ from functools import partial
class DiscreteContinuousEncoder(nn.Module):
def __init__(
self,
inp_shape=(721, 1440),
in_shape=(721, 1440),
out_shape=(480, 960),
grid_in="equiangular",
grid_out="equiangular",
......@@ -61,7 +62,7 @@ class DiscreteContinuousEncoder(nn.Module):
self.conv = DiscreteContinuousConvS2(
inp_chans,
out_chans,
in_shape=inp_shape,
in_shape=in_shape,
out_shape=out_shape,
kernel_shape=kernel_shape,
basis_type=basis_type,
......@@ -69,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
theta_cutoff=math.sqrt(2.0) * torch.pi / float(out_shape[0] - 1),
)
def forward(self, x):
......@@ -86,7 +87,7 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
def __init__(
self,
inp_shape=(480, 960),
in_shape=(480, 960),
out_shape=(721, 1440),
grid_in="equiangular",
grid_out="equiangular",
......@@ -99,12 +100,13 @@ class DiscreteContinuousDecoder(nn.Module):
):
super().__init__()
# set up
self.sht = RealSHT(*inp_shape, grid=grid_in).float()
# # set up
self.sht = RealSHT(*in_shape, grid=grid_in).float()
self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
# self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
# set up DISCO convolution
self.convt = DiscreteContinuousConvTransposeS2(
self.conv = DiscreteContinuousConvS2(
inp_chans,
out_chans,
in_shape=out_shape,
......@@ -115,21 +117,22 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
theta_cutoff=math.sqrt(2.0) * torch.pi / float(in_shape[0] - 1),
)
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
def _upscale_sht(self, x: torch.Tensor):
def upscale_sht(self, x: torch.Tensor):
return self.isht(self.sht(x))
def forward(self, x):
dtype = x.dtype
# x = self.upscale(x)
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
x = self._upscale_sht(x)
x = self.convt(x)
x = self.upscale_sht(x)
x = self.conv(x)
x = x.to(dtype=dtype)
return x
......@@ -182,7 +185,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
grid_in=forward_transform.grid,
grid_out=inverse_transform.grid,
bias=False,
theta_cutoff=4*math.sqrt(2) * torch.pi / float(inverse_transform.nlat - 1),
theta_cutoff=4 * math.sqrt(2.0) * torch.pi / float(inverse_transform.nlat - 1),
)
elif conv_type == "global":
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
......@@ -272,8 +275,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Parameters
-----------
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
img_shape : tuple, optional
......@@ -339,7 +340,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
def __init__(
self,
spectral_transform="sht",
operator_type="driscoll-healy",
img_size=(128, 256),
grid="equiangular",
......@@ -365,7 +365,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
):
super().__init__()
self.spectral_transform = spectral_transform
self.operator_type = operator_type
self.img_size = img_size
self.grid = grid
......@@ -440,8 +439,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
)
# prepare the spectral transform
if self.spectral_transform == "sht":
# prepare the SHT
modes_lat = int(self.h * self.hard_thresholding_fraction)
modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
modes_lat = modes_lon = min(modes_lat, modes_lon)
......@@ -449,9 +447,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
else:
raise (ValueError("Unknown spectral transform"))
self.blocks = nn.ModuleList([])
for i in range(self.num_layers):
first_layer = i == 0
......@@ -490,7 +485,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
# decoder
self.decoder = DiscreteContinuousDecoder(
inp_shape=(self.h, self.w),
in_shape=(self.h, self.w),
out_shape=self.img_size,
grid_in=grid_internal,
grid_out=grid,
......
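
Beyond the renaming of inp_shape to in_shape, the model-side hunks above rework the decoder: instead of a transposed DISCO convolution, it now upsamples spectrally (forward SHT on the coarse grid, inverse SHT on the target grid) and then applies a regular DiscreteContinuousConvS2 at the output resolution. A minimal, self-contained sketch of that pattern follows; the class name, channel counts, grid sizes, and kernel_shape are illustrative assumptions, while the SHT/ISHT wiring, the theta_cutoff formula, and the placement of the convolution mirror the diff.

import math
import torch
import torch.nn as nn
from torch_harmonics import RealSHT, InverseRealSHT, DiscreteContinuousConvS2

class SketchDecoder(nn.Module):
    # illustrative stand-in for DiscreteContinuousDecoder after this commit:
    # bandlimited upsampling via SHT/ISHT, then a regular DISCO convolution
    # on the upsampled grid (replacing the previous transposed convolution)
    def __init__(self, inp_chans=32, out_chans=3, in_shape=(60, 120), out_shape=(120, 240)):
        super().__init__()
        self.sht = RealSHT(*in_shape, grid="equiangular").float()
        self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid="equiangular").float()
        self.conv = DiscreteContinuousConvS2(
            inp_chans,
            out_chans,
            in_shape=out_shape,            # the convolution acts on the upsampled grid
            out_shape=out_shape,
            kernel_shape=3,                # illustrative
            bias=False,
            theta_cutoff=math.sqrt(2.0) * torch.pi / float(in_shape[0] - 1),
        )

    def forward(self, x):
        x = self.isht(self.sht(x.float()))  # upsample on the sphere
        return self.conv(x)

dec = SketchDecoder()
out = dec(torch.randn(1, 32, 60, 120))      # expected shape: (1, 3, 120, 240)
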