Unverified Commit ad927429 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Bbonev/discrete continuous convolutions (#24)



* fixing banding issues with corrected computation of kernel size
* fixing Parameters
* changed normalization of isotropic kernels

---------
Co-authored-by: default avatarThorsten Kurth <tkurth@nvidia.com>
parent eab72f04
...@@ -57,7 +57,7 @@ def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kern ...@@ -57,7 +57,7 @@ def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kern
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
itheta = ikernel * dtheta itheta = ikernel * dtheta
norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff)) norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
# find the indices where the rotated position falls into the support of the kernel # find the indices where the rotated position falls into the support of the kernel
iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)) iidx = torch.argwhere(((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff))
...@@ -104,6 +104,7 @@ def _precompute_convolution_tensor( ...@@ -104,6 +104,7 @@ def _precompute_convolution_tensor(
# compute latitude of the rotated position # compute latitude of the rotated position
z = torch.cos(alpha) * torch.cos(gamma) - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) z = torch.cos(alpha) * torch.cos(gamma) - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma)
z = torch.clamp(z, min=-1.0, max=1.0)
theta = torch.arccos(z) theta = torch.arccos(z)
# compute cartesian coordinates of the rotated position # compute cartesian coordinates of the rotated position
...@@ -160,7 +161,7 @@ class DiscreteContinuousConvS2(nn.Module): ...@@ -160,7 +161,7 @@ class DiscreteContinuousConvS2(nn.Module):
# bandlimit # bandlimit
if theta_cutoff is None: if theta_cutoff is None:
theta_cutoff = kernel_shape[0] * torch.pi / float(self.nlat_in - 1) theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
if theta_cutoff <= 0.0: if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.") raise ValueError("Error, theta_cutoff has to be positive.")
...@@ -184,13 +185,13 @@ class DiscreteContinuousConvS2(nn.Module): ...@@ -184,13 +185,13 @@ class DiscreteContinuousConvS2(nn.Module):
# weight tensor # weight tensor
if in_channels % self.groups != 0: if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size") raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups self.groupsize = in_channels // self.groups
weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0])) self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
self.register_buffer("weight", weight)
if bias: if bias:
btens = nn.Parameter(torch.zeros(out_channels)) self.bias = nn.Parameter(torch.zeros(out_channels))
self.register_buffer("bias", btens)
else: else:
self.bias = None self.bias = None
...@@ -208,7 +209,8 @@ class DiscreteContinuousConvS2(nn.Module): ...@@ -208,7 +209,8 @@ class DiscreteContinuousConvS2(nn.Module):
x = x.reshape(B, self.groups, self.groupsize, K, H, W) x = x.reshape(B, self.groups, self.groupsize, K, H, W)
# do weight multiplication # do weight multiplication
out = torch.einsum("bgckxy,fck->bfxy", x, self.weight) out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1])
if self.bias is not None: if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1) out = out + self.bias.reshape(1, -1, 1, 1)
...@@ -249,7 +251,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module): ...@@ -249,7 +251,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
# bandlimit # bandlimit
if theta_cutoff is None: if theta_cutoff is None:
theta_cutoff = kernel_shape[0] * torch.pi / float(self.nlat_in - 1) theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
if theta_cutoff <= 0.0: if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.") raise ValueError("Error, theta_cutoff has to be positive.")
...@@ -274,13 +276,13 @@ class DiscreteContinuousConvTransposeS2(nn.Module): ...@@ -274,13 +276,13 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
# weight tensor # weight tensor
if in_channels % self.groups != 0: if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size") raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups self.groupsize = in_channels // self.groups
weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0])) self.weight = nn.Parameter(torch.ones(out_channels, self.groupsize, kernel_shape[0]))
self.register_buffer("weight", weight)
if bias: if bias:
btens = nn.Parameter(torch.zeros(out_channels)) self.bias = nn.Parameter(torch.zeros(out_channels))
self.register_buffer("bias", btens)
else: else:
self.bias = None self.bias = None
...@@ -290,7 +292,8 @@ class DiscreteContinuousConvTransposeS2(nn.Module): ...@@ -290,7 +292,8 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
x = x.reshape(B, self.groups, self.groupsize, H, W) x = x.reshape(B, self.groups, self.groupsize, H, W)
# do weight multiplication # do weight multiplication
x = torch.einsum("bgfxy,cfk->bckxy", x, self.weight) x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
x = x.reshape(x.shape[0], -1, x.shape[-3], x.shape[-2], x.shape[-1])
# pre-multiply x with the quadrature weights # pre-multiply x with the quadrature weights
x = self.quad_weights * x x = self.quad_weights * x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment