Unverified Commit 82881276 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

setting imaginary parts of DCT and nyquist frequency to zero in IRSHT (#70)

* setting imaginary parts of DCT and nyquist frequency to zero in IRSHT variants

* small fix

* making einsum result contiguous

* adding zero frequency to distributed sht
parent 39a0e375
...@@ -248,9 +248,6 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -248,9 +248,6 @@ class DistributedInverseRealSHT(nn.Module):
# einsum # einsum
xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous() xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
#rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
#im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
#xs = torch.stack((rl, im), -1).contiguous()
# inverse FFT # inverse FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
...@@ -263,6 +260,11 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -263,6 +260,11 @@ class DistributedInverseRealSHT(nn.Module):
if self.comm_size_azimuth > 1: if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes) x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)
# set DCT and nyquist frequencies to 0:
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
...@@ -528,6 +530,11 @@ class DistributedInverseRealVectorSHT(nn.Module): ...@@ -528,6 +530,11 @@ class DistributedInverseRealVectorSHT(nn.Module):
if self.comm_size_azimuth > 1: if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes) x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)
# set DCT and nyquist frequencies to zero
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......
...@@ -195,13 +195,18 @@ class InverseRealSHT(nn.Module): ...@@ -195,13 +195,18 @@ class InverseRealSHT(nn.Module):
# Evaluate associated Legendre functions on the output nodes # Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x) x = torch.view_as_real(x)
xs = torch.einsum("...lmr, mlk->...kmr", x, self.pct.to(x.dtype)).contiguous()
rl = torch.einsum("...lm, mlk->...km", x[..., 0], self.pct.to(x.dtype))
im = torch.einsum("...lm, mlk->...km", x[..., 1], self.pct.to(x.dtype))
xs = torch.stack((rl, im), -1)
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
# ensure that imaginary part of 0 and nyquist components are zero
# this is important because not all backend algorithms provided through the
# irfft interface ensure that
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
return x return x
...@@ -395,6 +400,14 @@ class InverseRealVectorSHT(nn.Module): ...@@ -395,6 +400,14 @@ class InverseRealVectorSHT(nn.Module):
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
# ensure that imaginary part of 0 and nyquist components are zero
# this is important because not all backend algorithms provided through the
# irfft interface ensure that
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
return x return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment