Unverified commit b3816ebc authored by Thorsten Kurth, committed by GitHub

Setting imaginary parts of DC and Nyquist frequencies to 0 in IRFFT (#72)

* setting imaginary parts of the DC and Nyquist frequencies to zero in the InverseRealSHT variants
parent dca116b5
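
For context on the fix: `torch.fft.irfft` reconstructs a real signal from a one-sided spectrum under the Hermitian-symmetry assumption that the DC bin (index 0) and, for even output length `n`, the Nyquist bin (index `n//2`) are purely real. Not every FFT backend reachable through the `irfft` interface enforces this, so the commit zeroes those imaginary parts explicitly. A minimal standalone sketch of the invariant (plain PyTorch, independent of torch_harmonics):

```python
import torch

# For a real signal of even length n, the rfft spectrum has n//2 + 1 bins,
# and Hermitian symmetry forces bins 0 (DC) and n//2 (Nyquist) to be real.
n = 8
x = torch.randn(n, dtype=torch.float64)
X = torch.fft.rfft(x)
print(X[0].imag.item(), X[n // 2].imag.item())  # both zero up to round-off

# If upstream arithmetic (here: the Legendre einsum) leaves round-off noise
# in those imaginary parts, irfft backends may handle it inconsistently.
# Zeroing them explicitly makes the inverse transform backend-independent:
X.imag[0] = 0.0
if n % 2 == 0:  # the Nyquist bin only exists for even n
    X.imag[n // 2] = 0.0
print(torch.allclose(torch.fft.irfft(X, n=n), x))  # True
```
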
@@ -50,8 +50,6 @@ from torch_harmonics import RealSHT
 # wandb logging
 import wandb
-wandb.login()
 def l2loss_sphere(solver, prd, tar, relative=False, squared=True):
     loss = solver.integrate_grid((prd - tar) ** 2, dimensionless=True).sum(dim=-1)
@@ -265,7 +263,7 @@ def log_weights_and_grads(model, iters=1):
     """
     Helper routine intended for debugging purposes
     """
-    root_path = os.path.join(os.path.dirname(__file__), "weights_and_grads")
+    root_path = os.path.join(os.getcwd(), "weights_and_grads")
     weights_and_grads_fname = os.path.join(root_path, f"weights_and_grads_step{iters:03d}.tar")
     print(weights_and_grads_fname)
@@ -381,6 +379,9 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
     torch.manual_seed(333)
     torch.cuda.manual_seed(333)
+    # login
+    wandb.login()
     # set parameters
     nfuture = 0
@@ -394,7 +395,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
     dt_solver = 150
     nsteps = dt // dt_solver
     grid = "legendre-gauss"
-    nlat, nlon = (181, 360)
+    nlat, nlon = (257, 512)
     dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True)
     dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid=grid).to(device=device)
     # There is still an issue with parallel dataloading. Do NOT use it at the moment
@@ -441,8 +442,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         pos_embed=False,
         use_mlp=True,
         normalization_layer="none",
-        kernel_shape=[2, 2],
-        encoder_kernel_shape=[2, 2],
+        kernel_shape=(2, 2),
+        encoder_kernel_shape=(2, 2),
         filter_basis_type="morlet",
         upsample_sht=True,
     )
@@ -459,14 +460,14 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         pos_embed=False,
         use_mlp=True,
         normalization_layer="none",
-        kernel_shape=[4],
-        encoder_kernel_shape=[4],
+        kernel_shape=(4,),
+        encoder_kernel_shape=(4,),
         filter_basis_type="zernike",
         upsample_sht=True,
     )
     # iterate over models and train each model
-    root_path = os.path.dirname(__file__)
+    root_path = os.getcwd()
     for model_name, model_handle in models.items():
         model = model_handle().to(device)
@@ -498,7 +499,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
         start_time = time.time()
         print(f"Training {model_name}, single step")
-        train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=1, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
+        train_model(model, dataloader, optimizer, gscaler, scheduler, nepochs=200, loss_fn="l2", enable_amp=enable_amp, log_grads=log_grads)
         if nfuture > 0:
             print(f'Training {model_name}, {nfuture} step')
......
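
A note on the `kernel_shape` values above (a general Python point, not specific to this repo): the trailing comma is what makes a one-element tuple; `(4)` without it is just the integer 4.

```python
kernel_shape = (4)   # parentheses alone do not create a tuple
print(type(kernel_shape))  # <class 'int'>

kernel_shape = (4,)  # the trailing comma makes a one-element tuple
print(type(kernel_shape))  # <class 'tuple'>

# Code that iterates over kernel_shape, e.g. len(kernel_shape) or
# `for k in kernel_shape`, fails on the bare int but works on the tuple.
```
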
@@ -248,9 +248,6 @@ class DistributedInverseRealSHT(nn.Module):
         # einsum
         xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
-        #rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype))
-        #im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype))
-        #xs = torch.stack((rl, im), -1).contiguous()
         # inverse FFT
         x = torch.view_as_complex(xs)
@@ -263,6 +260,11 @@ class DistributedInverseRealSHT(nn.Module):
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)
+        # set the imaginary parts of the DC and Nyquist frequencies to 0
+        x[..., 0].imag = 0.0
+        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
+            x[..., self.nlon // 2].imag = 0.0
         # apply the inverse (real) FFT
         x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
@@ -528,6 +530,11 @@ class DistributedInverseRealVectorSHT(nn.Module):
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)
+        # set the imaginary parts of the DC and Nyquist frequencies to zero
+        x[..., 0].imag = 0.0
+        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
+            x[..., self.nlon // 2].imag = 0.0
         # apply the inverse (real) FFT
         x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......
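
The guard used in both hunks above folds two checks into one: a Nyquist mode only exists when `nlon` is even, and the (possibly truncated) coefficient axis must actually reach index `nlon // 2`. A standalone sketch of that logic, with the helper name `zero_dc_and_nyquist` chosen here for illustration:

```python
import torch

def zero_dc_and_nyquist(x: torch.Tensor, nlon: int) -> torch.Tensor:
    """Zero the imaginary parts of the m = 0 (DC) and, if present,
    m = nlon // 2 (Nyquist) Fourier coefficients along the last dim.
    `x` is complex, holding modes m = 0 .. x.shape[-1] - 1."""
    x.imag[..., 0] = 0.0
    # Nyquist exists only for even nlon, and only if the coefficient
    # axis (which may be truncated to mmax modes) reaches it.
    if (nlon % 2 == 0) and (nlon // 2 < x.shape[-1]):
        x.imag[..., nlon // 2] = 0.0
    return x

coeffs = torch.randn(4, 10, 9, dtype=torch.complex64)  # e.g. mmax = 9, nlon = 16
out = zero_dc_and_nyquist(coeffs, nlon=16)
print(out.imag[..., 0].abs().max(), out.imag[..., 8].abs().max())  # both 0
```
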
@@ -33,7 +33,7 @@
 import torch
 import torch.nn as nn
 import torch_harmonics as harmonics
-from torch_harmonics.quadrature import *
+from torch_harmonics.quadrature import _precompute_longitudes
 import math
 import numpy as np
......
@@ -195,13 +195,18 @@ class InverseRealSHT(nn.Module):
         # Evaluate associated Legendre functions on the output nodes
         x = torch.view_as_real(x)
-        rl = torch.einsum("...lm, mlk->...km", x[..., 0], self.pct.to(x.dtype))
-        im = torch.einsum("...lm, mlk->...km", x[..., 1], self.pct.to(x.dtype))
-        xs = torch.stack((rl, im), -1)
+        xs = torch.einsum("...lmr, mlk->...kmr", x, self.pct.to(x.dtype)).contiguous()
         # apply the inverse (real) FFT
         x = torch.view_as_complex(xs)
+        # ensure that the imaginary parts of the DC and Nyquist components are zero;
+        # this is important because not all backend algorithms provided through
+        # the irfft interface guarantee this
+        x[..., 0].imag = 0.0
+        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
+            x[..., self.nlon // 2].imag = 0.0
         x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
         return x
@@ -395,6 +400,14 @@ class InverseRealVectorSHT(nn.Module):
         # apply the inverse (real) FFT
         x = torch.view_as_complex(xs)
+        # ensure that the imaginary parts of the DC and Nyquist components are zero;
+        # this is important because not all backend algorithms provided through
+        # the irfft interface guarantee this
+        x[..., 0].imag = 0.0
+        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
+            x[..., self.nlon // 2].imag = 0.0
         x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
         return x
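
To exercise the fixed single-device path end to end, a usage sketch along these lines should work (grid size and grid type chosen arbitrarily here):

```python
import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 64, 128
sht = RealSHT(nlat, nlon, grid="legendre-gauss")
isht = InverseRealSHT(nlat, nlon, grid="legendre-gauss")

signal = torch.randn(nlat, nlon)
coeffs = sht(signal)   # complex spherical harmonic coefficients
recon = isht(coeffs)   # real grid values; the DC and Nyquist imaginary
                       # parts are now zeroed before the inverse FFT
print(recon.shape)     # torch.Size([64, 128])
```
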