Commit d81fbd34 authored by Boris Bonev, committed by Boris Bonev

changing default normalization mode in DISCO

parent 96a2b546
@@ -430,21 +430,20 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         normalization_layer="none",
     )
-    # models[f"lsno_sc2_layers4_e32"] = partial(
-    #     LSNO,
-    #     spectral_transform="sht",
-    #     img_size=(nlat, nlon),
-    #     grid=grid,
-    #     num_layers=4,
-    #     scale_factor=2,
-    #     embed_dim=32,
-    #     operator_type="driscoll-healy",
-    #     activation_function="gelu",
-    #     big_skip=True,
-    #     pos_embed=False,
-    #     use_mlp=True,
-    #     normalization_layer="none",
-    # )
+    models[f"lsno_sc2_layers4_e32"] = partial(
+        LSNO,
+        img_size=(nlat, nlon),
+        grid=grid,
+        num_layers=4,
+        scale_factor=2,
+        embed_dim=32,
+        operator_type="driscoll-healy",
+        activation_function="gelu",
+        big_skip=False,
+        pos_embed=False,
+        use_mlp=True,
+        normalization_layer="none",
+    )
     # iterate over models and train each model
     root_path = os.path.dirname(__file__)
@@ -487,7 +486,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
         gscaler = amp.GradScaler(enabled=enable_amp)
         dataloader.dataset.nsteps = 2 * dt//dt_solver
-        train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=20, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
+        train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=10, loss_fn="l2", nfuture=nfuture, enable_amp=enable_amp, log_grads=log_grads)
         dataloader.dataset.nsteps = 1 * dt//dt_solver
         training_time = time.time() - start_time
...
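Note on the hunk above: the `models` dict maps a name to a `functools.partial` over the model class, so each entry is a zero-argument constructor that the training loop calls later. A minimal, self-contained sketch of that pattern; the registry contents below are placeholders, not code from this repository.

from functools import partial

import torch.nn as nn

# Placeholder registry illustrating the partial-based pattern used in the script.
models = {
    "linear_small": partial(nn.Linear, 16, 16),
    "linear_wide": partial(nn.Linear, 16, 64),
}

for name, constructor in models.items():
    model = constructor()  # instantiate only when the loop reaches this entry
    nparams = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name}: {nparams} trainable parameters")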
@@ -176,7 +176,7 @@
     "# activation_function = nn.ReLU,\n",
     "# bias = False):\n",
     "# super().__init__()\n",
-    "    \n",
+    "\n",
     "# current_dim = input_dim\n",
     "# layers = []\n",
     "# for l in range(num_layers-1):\n",
@@ -221,7 +221,7 @@
     "    loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)\n",
     "    if relative:\n",
     "        loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)\n",
-    "    \n",
+    "\n",
     "    if not squared:\n",
     "        loss = torch.sqrt(loss)\n",
     "    loss = loss.mean()\n",
@@ -515,7 +515,7 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3 (ipykernel)",
+  "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
@@ -531,12 +531,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
- "orig_nbformat": 4,
- "vscode": {
-  "interpreter": {
-   "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
-  }
- }
+ "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
...
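The notebook cell touched in the second hunk computes a geometric L2 loss on the sphere. Pieced together from the context lines above, it looks roughly like the sketch below; the function name and signature are assumptions, and `solver.integrate_grid` refers to the quadrature helper of the shallow-water solver used in the examples.

import torch

def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
    # integrate the pointwise squared error over the sphere and sum over channels
    loss = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
    if relative:
        # normalize by the integrated energy of the target to get a relative error
        loss = loss / solver.integrate_grid(tar ** 2, dimensionless=True).sum(dim=-1)

    if not squared:
        loss = torch.sqrt(loss)

    # average over the batch
    return loss.mean()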
@@ -57,7 +57,7 @@ except ImportError as err:
 def _normalize_convolution_tensor_s2(
-    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9
+    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
 ):
     """
     Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
@@ -135,7 +135,7 @@ def _precompute_convolution_tensor_s2(
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
     transpose_normalization=False,
-    basis_norm_mode="none",
+    basis_norm_mode="mean",
     merge_quadrature=False,
 ):
     """
@@ -297,7 +297,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "none",
+        basis_norm_mode: Optional[str] = "mean",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
@@ -403,7 +403,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "none",
+        basis_norm_mode: Optional[str] = "mean",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
...
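The hunks above change the default `basis_norm_mode` of the S2 DISCO convolutions from "none" to "mean". A short usage sketch, based only on the constructor arguments visible in this diff, showing how to pin the mode explicitly; shapes and channel counts are arbitrary example values.

import torch
from torch_harmonics import DiscreteContinuousConvS2

conv = DiscreteContinuousConvS2(
    3,                       # input channels
    8,                       # output channels
    in_shape=(32, 64),
    out_shape=(32, 64),
    kernel_shape=3,
    basis_type="piecewise linear",
    basis_norm_mode="mean",  # new default; pass "none" to recover the old behavior
    grid_in="equiangular",
    grid_out="equiangular",
)

x = torch.randn(1, 3, 32, 64)
y = conv(x)                  # shape (1, 8, 32, 64)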
@@ -76,7 +76,7 @@ def _precompute_distributed_convolution_tensor_s2(
     grid_out="equiangular",
     theta_cutoff=0.01 * math.pi,
     transpose_normalization=False,
-    basis_norm_mode="none",
+    basis_norm_mode="mean",
     merge_quadrature=False,
 ):
     """
@@ -208,7 +208,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "none",
+        basis_norm_mode: Optional[str] = "mean",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
@@ -348,7 +348,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         out_shape: Tuple[int],
         kernel_shape: Union[int, List[int]],
         basis_type: Optional[str] = "piecewise linear",
-        basis_norm_mode: Optional[str] = "none",
+        basis_norm_mode: Optional[str] = "mean",
         groups: Optional[int] = 1,
         grid_in: Optional[str] = "equiangular",
         grid_out: Optional[str] = "equiangular",
...
@@ -35,6 +35,7 @@ import torch.amp as amp
 from torch_harmonics import RealSHT, InverseRealSHT
 from torch_harmonics import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
+from torch_harmonics import ResampleS2
 from ._layers import *
@@ -44,7 +45,7 @@ from functools import partial
 class DiscreteContinuousEncoder(nn.Module):
     def __init__(
         self,
-        inp_shape=(721, 1440),
+        in_shape=(721, 1440),
         out_shape=(480, 960),
         grid_in="equiangular",
         grid_out="equiangular",
@@ -61,7 +62,7 @@ class DiscreteContinuousEncoder(nn.Module):
         self.conv = DiscreteContinuousConvS2(
             inp_chans,
             out_chans,
-            in_shape=inp_shape,
+            in_shape=in_shape,
             out_shape=out_shape,
             kernel_shape=kernel_shape,
             basis_type=basis_type,
@@ -69,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
             grid_out=grid_out,
             groups=groups,
             bias=bias,
-            theta_cutoff=math.sqrt(2) * torch.pi / float(out_shape[0] - 1),
+            theta_cutoff=math.sqrt(2.0) * torch.pi / float(out_shape[0] - 1),
         )

     def forward(self, x):
@@ -86,7 +87,7 @@ class DiscreteContinuousEncoder(nn.Module):
 class DiscreteContinuousDecoder(nn.Module):
     def __init__(
         self,
-        inp_shape=(480, 960),
+        in_shape=(480, 960),
         out_shape=(721, 1440),
         grid_in="equiangular",
         grid_out="equiangular",
@@ -99,12 +100,13 @@ class DiscreteContinuousDecoder(nn.Module):
     ):
         super().__init__()

-        # set up
-        self.sht = RealSHT(*inp_shape, grid=grid_in).float()
+        # # set up
+        self.sht = RealSHT(*in_shape, grid=grid_in).float()
         self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
+        # self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)

         # set up DISCO convolution
-        self.convt = DiscreteContinuousConvTransposeS2(
+        self.conv = DiscreteContinuousConvS2(
             inp_chans,
             out_chans,
             in_shape=out_shape,
@@ -115,21 +117,22 @@ class DiscreteContinuousDecoder(nn.Module):
             grid_out=grid_out,
             groups=groups,
             bias=False,
-            theta_cutoff=math.sqrt(2) * torch.pi / float(inp_shape[0] - 1),
+            theta_cutoff=math.sqrt(2.0) * torch.pi / float(in_shape[0] - 1),
         )
         # self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)

-    def _upscale_sht(self, x: torch.Tensor):
+    def upscale_sht(self, x: torch.Tensor):
         return self.isht(self.sht(x))

     def forward(self, x):
         dtype = x.dtype
+        # x = self.upscale(x)
         with amp.autocast(device_type="cuda", enabled=False):
             x = x.float()
-            x = self._upscale_sht(x)
-            x = self.convt(x)
+            x = self.upscale_sht(x)
+            x = self.conv(x)
         x = x.to(dtype=dtype)
         return x
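With this change the decoder no longer uses a transpose DISCO convolution: it first upsamples spectrally (SHT on the coarse grid, inverse SHT on the target grid) and then applies a regular DISCO convolution on the upsampled grid. Below is a condensed, self-contained sketch of that forward path using the torch_harmonics layers named in the diff; grid sizes, channel counts, and the kernel size are arbitrary example values, not the configuration used in the model.

import math
import torch
from torch_harmonics import RealSHT, InverseRealSHT, DiscreteContinuousConvS2

in_shape, out_shape = (32, 64), (64, 128)   # coarse grid -> fine grid
inp_chans, out_chans = 16, 4

# spectral upsampling: analysis on the coarse grid, synthesis on the fine grid
sht = RealSHT(*in_shape, grid="equiangular").float()
isht = InverseRealSHT(*out_shape, lmax=sht.lmax, mmax=sht.mmax, grid="equiangular").float()

# regular DISCO convolution acting on the upsampled grid
conv = DiscreteContinuousConvS2(
    inp_chans,
    out_chans,
    in_shape=out_shape,
    out_shape=out_shape,
    kernel_shape=3,
    bias=False,
    theta_cutoff=math.sqrt(2.0) * torch.pi / float(in_shape[0] - 1),
)

x = torch.randn(2, inp_chans, *in_shape)
y = conv(isht(sht(x)))   # shape (2, out_chans, 64, 128)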
@@ -182,7 +185,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
                 grid_in=forward_transform.grid,
                 grid_out=inverse_transform.grid,
                 bias=False,
-                theta_cutoff=4*math.sqrt(2) * torch.pi / float(inverse_transform.nlat - 1),
+                theta_cutoff=4 * math.sqrt(2.0) * torch.pi / float(inverse_transform.nlat - 1),
             )
         elif conv_type == "global":
             self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
@@ -272,8 +275,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
     Parameters
     -----------
-    spectral_transform : str, optional
-        Type of spectral transformation to use, by default "sht"
     operator_type : str, optional
         Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
     img_shape : tuple, optional
@@ -339,7 +340,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
     def __init__(
         self,
-        spectral_transform="sht",
         operator_type="driscoll-healy",
         img_size=(128, 256),
         grid="equiangular",
@@ -365,7 +365,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
     ):
         super().__init__()

-        self.spectral_transform = spectral_transform
         self.operator_type = operator_type
         self.img_size = img_size
         self.grid = grid
@@ -440,17 +439,13 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
             theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
         )

-        # prepare the spectral transform
-        if self.spectral_transform == "sht":
-            modes_lat = int(self.h * self.hard_thresholding_fraction)
-            modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
-            modes_lat = modes_lon = min(modes_lat, modes_lon)
-            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
-            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
-        else:
-            raise (ValueError("Unknown spectral transform"))
+        # prepare the SHT
+        modes_lat = int(self.h * self.hard_thresholding_fraction)
+        modes_lon = int(self.w // 2 * self.hard_thresholding_fraction)
+        modes_lat = modes_lon = min(modes_lat, modes_lon)
+        self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()
+        self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=grid_internal).float()

         self.blocks = nn.ModuleList([])
         for i in range(self.num_layers):
@@ -490,7 +485,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
         # decoder
         self.decoder = DiscreteContinuousDecoder(
-            inp_shape=(self.h, self.w),
+            in_shape=(self.h, self.w),
             out_shape=self.img_size,
             grid_in=grid_internal,
             grid_out=grid,
...
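Finally, the `spectral_transform` switch is removed from `LocalSphericalNeuralOperatorNet`: the SHT pair is now built unconditionally, truncated by `hard_thresholding_fraction` exactly as in the lines above. A small self-contained example of that truncation; the grid size and fraction are arbitrary example values.

import torch
from torch_harmonics import RealSHT, InverseRealSHT

h, w = 128, 256
hard_thresholding_fraction = 0.8

# keep only a fraction of the resolvable modes, and use a square truncation
modes_lat = int(h * hard_thresholding_fraction)        # 102
modes_lon = int(w // 2 * hard_thresholding_fraction)   # 102
modes_lat = modes_lon = min(modes_lat, modes_lon)

trans = RealSHT(h, w, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
itrans = InverseRealSHT(h, w, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()

x = torch.randn(1, 1, h, w)
x_filtered = itrans(trans(x))   # round trip through the truncated spectrum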