Unverified Commit b3816ebc authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

Setting imaginary parts of DCT and Nyquist frequencies to 0 in IRFFT (#72)

* setting imaginary parts of DCT and Nyquist frequencies to zero in IRSHT variants
parent dca116b5
...@@ -50,8 +50,6 @@ from torch_harmonics import RealSHT ...@@ -50,8 +50,6 @@ from torch_harmonics import RealSHT
# wandb logging # wandb logging
import wandb import wandb
wandb.login()
def l2loss_sphere(solver, prd, tar, relative=False, squared=True): def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
loss = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1) loss = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
...@@ -265,7 +263,7 @@ def log_weights_and_grads(model, iters=1): ...@@ -265,7 +263,7 @@ def log_weights_and_grads(model, iters=1):
""" """
Helper routine intended for debugging purposes Helper routine intended for debugging purposes
""" """
root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads") root_path = os.path.join(os.getcwd(), "weights_and_grads")
weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar") weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
print(weights_and_grads_fname) print(weights_and_grads_fname)
...@@ -381,6 +379,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -381,6 +379,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
torch.manual_seed(333) torch.manual_seed(333)
torch.cuda.manual_seed(333) torch.cuda.manual_seed(333)
# login
wandb.login()
# set parameters # set parameters
nfuture=0 nfuture=0
...@@ -394,7 +395,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -394,7 +395,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt_solver = 150 dt_solver = 150
nsteps = dt // dt_solver nsteps = dt // dt_solver
grid = "legendre-gauss" grid = "legendre-gauss"
nlat, nlon =(181, 360) nlat, nlon = (257, 512)
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True) dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True)
dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid= grid).to(device=device) dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid= grid).to(device=device)
# There is still an issue with parallel dataloading. Do NOT use it at the moment # There is still an issue with parallel dataloading. Do NOT use it at the moment
...@@ -441,8 +442,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -441,8 +442,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
pos_embed=False, pos_embed=False,
use_mlp=True, use_mlp=True,
normalization_layer="none", normalization_layer="none",
kernel_shape=[2, 2], kernel_shape=(2, 2),
encoder_kernel_shape=[2, 2], encoder_kernel_shape=(2, 2),
filter_basis_type="morlet", filter_basis_type="morlet",
upsample_sht = True, upsample_sht = True,
) )
...@@ -459,14 +460,14 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -459,14 +460,14 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
pos_embed=False, pos_embed=False,
use_mlp=True, use_mlp=True,
normalization_layer="none", normalization_layer="none",
kernel_shape=[4], kernel_shape=(4),
encoder_kernel_shape=[4], encoder_kernel_shape=(4),
filter_basis_type="zernike", filter_basis_type="zernike",
upsample_sht = True, upsample_sht = True,
) )
# iterate over models and train each model # iterate over models and train each model
root_path = os.path.dirname(__file__) root_path = os.getcwd()
for model_name, model_handle in models.items(): for model_name, model_handle in models.items():
model = model_handle().to(device) model = model_handle().to(device)
...@@ -498,7 +499,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -498,7 +499,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
start_time = time.time() start_time = time.time()
print(f"Training {model_name}, single step") print(f"Training {model_name}, single step")
train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=1, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads) train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
if nfuture > 0: if nfuture > 0:
print(f'Training {model_name}, {nfuture} step') print(f'Training {model_name}, {nfuture} step')
......
...@@ -248,9 +248,6 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -248,9 +248,6 @@ class DistributedInverseRealSHT(nn.Module):
# einsum # einsum
xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous() xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
#rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
#im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
#xs = torch.stack((rl, im), -1).contiguous()
# inverse FFT # inverse FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
...@@ -263,6 +260,11 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -263,6 +260,11 @@ class DistributedInverseRealSHT(nn.Module):
if self.comm_size_azimuth > 1: if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes) x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)
# set DCT and nyquist frequencies to 0:
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
...@@ -528,6 +530,11 @@ class DistributedInverseRealVectorSHT(nn.Module): ...@@ -528,6 +530,11 @@ class DistributedInverseRealVectorSHT(nn.Module):
if self.comm_size_azimuth > 1: if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes) x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)
# set DCT and nyquist frequencies to zero
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch_harmonics as harmonics import torch_harmonics as harmonics
from torch_harmonics.quadrature import * from torch_harmonics.quadrature import _precompute_longitudes
import math import math
import numpy as np import numpy as np
......
...@@ -195,13 +195,18 @@ class InverseRealSHT(nn.Module): ...@@ -195,13 +195,18 @@ class InverseRealSHT(nn.Module):
# Evaluate associated Legendre functions on the output nodes # Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x) x = torch.view_as_real(x)
xs = torch.einsum("...lmr, mlk->...kmr", x, self.pct.to(x.dtype)).contiguous()
rl = torch.einsum("...lm, mlk->...km", x[..., 0], self.pct.to(x.dtype))
im = torch.einsum("...lm, mlk->...km", x[..., 1], self.pct.to(x.dtype))
xs = torch.stack((rl, im), -1)
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
# ensure that imaginary part of 0 and nyquist components are zero
# this is important because not all backend algorithms provided through the
# irfft interface ensure that
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
return x return x
...@@ -395,6 +400,14 @@ class InverseRealVectorSHT(nn.Module): ...@@ -395,6 +400,14 @@ class InverseRealVectorSHT(nn.Module):
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
# ensure that imaginary part of 0 and nyquist components are zero
# this is important because not all backend algorithms provided through the
# irfft interface ensure that
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
return x return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment