Unverified Commit dca116b5 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Revert "setting imaginary parts of DCT and nyquist frequency to zero in IRSHT…" (#71)

This reverts commit 82881276.
parent 82881276
......@@ -248,6 +248,9 @@ class DistributedInverseRealSHT(nn.Module):
# einsum
xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
#rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
#im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
#xs = torch.stack((rl, im), -1).contiguous()
# inverse FFT
x = torch.view_as_complex(xs)
......@@ -260,11 +263,6 @@ class DistributedInverseRealSHT(nn.Module):
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)
# set DCT and nyquist frequencies to 0:
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0
# apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......@@ -530,11 +528,6 @@ class DistributedInverseRealVectorSHT(nn.Module):
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)
# set DCT and nyquist frequencies to zero
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
x[..., self.nlon // 2].imag = 0.0
# apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......
......@@ -195,18 +195,13 @@ class InverseRealSHT(nn.Module):
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)
xs = torch.einsum("...lmr, mlk->...kmr", x, self.pct.to(x.dtype)).contiguous()
rl = torch.einsum("...lm, mlk->...km", x[..., 0], self.pct.to(x.dtype))
im = torch.einsum("...lm, mlk->...km", x[..., 1], self.pct.to(x.dtype))
xs = torch.stack((rl, im), -1)
# apply the inverse (real) FFT
x = torch.view_as_complex(xs)
# ensure that imaginary part of 0 and nyquist components are zero
# this is important because not all backend algorithms provided through the
# irfft interface ensure that
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
return x
......@@ -400,14 +395,6 @@ class InverseRealVectorSHT(nn.Module):
# apply the inverse (real) FFT
x = torch.view_as_complex(xs)
# ensure that imaginary part of 0 and nyquist components are zero
# this is important because not all backend algorithms provided through the
# irfft interface ensure that
x[..., 0].imag = 0.0
if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
x[..., self.nlon // 2].imag = 0.0
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment