Unverified Commit ad927429 authored by Boris Bonev, committed by GitHub

Bbonev/discrete continuous convolutions (#24)



* fixed banding issues with a corrected computation of the kernel support size
* registered the weight and bias tensors as trainable Parameters instead of buffers
* changed the normalization of isotropic kernels

---------
Co-authored-by: Thorsten Kurth <tkurth@nvidia.com>
parent eab72f04
@@ -57,7 +57,7 @@ def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kern
     ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
     itheta = ikernel * dtheta

-    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff))
+    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)

     # find the indices where the rotated position falls into the support of the kernel
     iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff))
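
For context (not part of the commit): the old factor 2 * pi * (1 - cos(theta_cutoff)) is the area of a spherical cap of radius theta_cutoff; the replacement modifies it by dtheta-dependent terms, presumably to account for the hat-shaped falloff of the isotropic basis functions near the cutoff. A minimal sketch comparing the two factors numerically; kernel_size, nlat and the dtheta spacing below are illustrative assumptions, not values taken from the library:

# Sketch only: compare the old and new isotropic normalization factors.
import math

kernel_size = 3
nlat = 361
theta_cutoff = (kernel_size + 1) * math.pi / float(nlat - 1)  # new default bandlimit
dtheta = theta_cutoff / kernel_size                           # assumed spacing of the kernel rings

old = 2 * math.pi * (1 - math.cos(theta_cutoff))
new = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta)
                     + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
print(old, new)
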
@@ -104,6 +104,7 @@ def _precompute_convolution_tensor(

         # compute latitude of the rotated position
         z = torch.cos(alpha) * torch.cos(gamma) - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma)
+        z = torch.clamp(z, min=-1.0, max=1.0)
         theta = torch.arccos(z)

         # compute cartesian coordinates of the rotated position
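
The added clamp guards against floating-point round-off: the rotation formula can push z marginally outside [-1, 1], for which arccos returns NaN. A standalone illustration of that failure mode:

# Standalone illustration (not library code) of the NaN the clamp prevents.
import torch

z = torch.tensor([1.0 + 1e-6, -1.0 - 1e-6])             # slightly out of range due to round-off
print(torch.arccos(z))                                   # tensor([nan, nan])
print(torch.arccos(torch.clamp(z, min=-1.0, max=1.0)))   # tensor([0.0000, 3.1416])
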
@@ -160,7 +161,7 @@ class DiscreteContinuousConvS2(nn.Module):

         # bandlimit
         if theta_cutoff is None:
-            theta_cutoff = kernel_shape[0] * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)

         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
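
The default cutoff is widened from kernel_shape[0] to kernel_shape[0] + 1 grid spacings, which appears to be the fix for the banding issue mentioned in the commit message. A quick numeric check with made-up sizes:

# Illustrative values only; kernel_shape[0] and nlat_in are made up.
import math

kernel_shape0, nlat_in = 3, 361
old_cutoff = kernel_shape0 * math.pi / float(nlat_in - 1)
new_cutoff = (kernel_shape0 + 1) * math.pi / float(nlat_in - 1)
print(old_cutoff, new_cutoff)  # the support widens by one latitude spacing, pi / (nlat_in - 1)
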
@@ -184,13 +185,13 @@ class DiscreteContinuousConvS2(nn.Module):
         # weight tensor
         if in_channels % self.groups != 0:
             raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
         if out_channels % self.groups != 0:
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
-        weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
-        self.register_buffer("weight", weight)
+        self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))

         if bias:
-            btens = nn.Parameter(torch.zeros(out_channels))
-            self.register_buffer("bias", btens)
+            self.bias = nn.Parameter(torch.zeros(out_channels))
         else:
             self.bias = None
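
A weight registered via register_buffer is not returned by Module.parameters(), so an optimizer never updates it; assigning an nn.Parameter directly makes it trainable. A minimal sketch of the difference with toy modules (not the library's classes):

# Toy modules only, showing buffer vs. parameter registration.
import torch
import torch.nn as nn

class AsBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("weight", torch.ones(4, 2, 3))  # excluded from parameters()

class AsParameter(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4, 2, 3))      # visible to the optimizer

print(len(list(AsBuffer().parameters())))     # 0
print(len(list(AsParameter().parameters())))  # 1
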
@@ -208,7 +209,8 @@ class DiscreteContinuousConvS2(nn.Module):
         x = x.reshape(B, self.groups, self.groupsize, K, H, W)

         # do weight multiplication
-        out = torch.einsum("bgckxy,fck->bfxy", x, self.weight)
+        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
+        out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])

         if self.bias is not None:
             out = out + self.bias.reshape(1, -1, 1, 1)
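
The rewritten contraction reshapes the weight to (groups, out_channels // groups, groupsize, k) and contracts each group with its own slice of the weight; the old subscripts ignored the group dimension. A shape check with made-up sizes (not the library's defaults):

# Shape check with made-up sizes.
import torch

B, groups, groupsize, K, H, W = 2, 2, 3, 4, 5, 6
out_channels = 8
x = torch.randn(B, groups, groupsize, K, H, W)
weight = torch.randn(out_channels, groupsize, K)

out = torch.einsum("bgckxy,gock->bgoxy", x, weight.reshape(groups, -1, groupsize, K))
out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])
print(out.shape)  # torch.Size([2, 8, 5, 6])
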
@@ -249,7 +251,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):

         # bandlimit
         if theta_cutoff is None:
-            theta_cutoff = kernel_shape[0] * torch.pi / float(self.nlat_in - 1)
+            theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)

         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")
@@ -274,13 +276,13 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
         # weight tensor
         if in_channels % self.groups != 0:
             raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
         if out_channels % self.groups != 0:
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
-        weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
-        self.register_buffer("weight", weight)
+        self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))

         if bias:
-            btens = nn.Parameter(torch.zeros(out_channels))
-            self.register_buffer("bias", btens)
+            self.bias = nn.Parameter(torch.zeros(out_channels))
         else:
             self.bias = None
@@ -290,7 +292,8 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
         x = x.reshape(B, self.groups, self.groupsize, H, W)

         # do weight multiplication
-        x = torch.einsum("bgfxy,cfk->bckxy", x, self.weight)
+        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
+        x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])

         # pre-multiply x with the quadrature weights
         x = self.quad_weights * x
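
The transpose path applies the same grouped reshape but keeps the kernel index k in the output, which is consumed by the subsequent quadrature-weighted step. Again a shape check with made-up sizes (not the library's defaults):

# Shape check with made-up sizes.
import torch

B, groups, groupsize, H, W = 2, 2, 3, 5, 6
out_channels, K = 8, 4
x = torch.randn(B, groups, groupsize, H, W)
weight = torch.randn(out_channels, groupsize, K)

x = torch.einsum("bgcxy,gock->bgokxy", x, weight.reshape(groups, -1, groupsize, K))
x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])
print(x.shape)  # torch.Size([2, 8, 4, 5, 6])
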