Unverified commit 77a64b2c, authored by Boris Bonev and committed by GitHub

Bbonev/disco fused quadrature (#42)

* missed one instance of python3

* fused the multiplication by the quadrature weights into the multiplication with the Psi tensor (see the sketch below)

* Added quadrature change to changelog

* removing persistent quad weights tensor as it isn't needed anymore

* added pretty-printing for convolution modules

* adjusting default value in convolution test
parent ab0e66bc
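The gist of this change: the quadrature weights used for the discrete integration are folded into the precomputed Psi tensor during normalization, so the forward pass no longer multiplies the input by a separate quad_weights buffer (which is why the persistent buffer is removed below). A minimal sketch of the equivalence, using a dense Psi tensor and made-up sizes; the library itself stores Psi in a compressed sparse layout:

```python
import torch

# Hypothetical sizes, for illustration only.
K, nlat_out, nlon_out, nlat_in, nlon_in = 2, 6, 12, 6, 12

psi = torch.rand(K, nlat_out, nlon_out, nlat_in, nlon_in)  # dense Psi tensor
w = torch.rand(nlat_in, 1)                                 # quadrature weight per input latitude
x = torch.rand(3, 4, nlat_in, nlon_in)                     # batch x channels x grid

# Before: multiply the input by the quadrature weights, then contract with Psi.
y_old = torch.einsum("ftpqr,bcqr->bcftp", psi, w * x)

# After: bake the weights into Psi once at precompute time and contract the raw input.
psi_fused = psi * w.reshape(1, 1, 1, -1, 1)
y_new = torch.einsum("ftpqr,bcqr->bcftp", psi_fused, x)

print(torch.allclose(y_old, y_new, atol=1e-6))  # True
```

Since Psi is precomputed once, this removes an elementwise multiplication (and a persistent buffer) from every forward call.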
@@ -27,7 +27,7 @@ jobs:
         python3 -m pip install setuptools wheel
     - name: Build a binary wheel and a source tarball
       run: |
-        python setup.py sdist bdist_wheel
+        python3 setup.py sdist bdist_wheel
 # - name: Publish package to TestPyPI
 #   uses: pypa/gh-action-pypi-publish@master
 #   with:
...
@@ -7,6 +7,7 @@
 * CUDA-accelerated DISCO convolutions
 * Updated DISCO convolutions to support even number of collocation points across the diameter
 * Distributed DISCO convolutions
+* Fused quadrature into multiplication with the Psi tensor to lower memory footprint
 * Removed DISCO convolution in the plane to focus on the sphere
 * Updated unit tests which now include tests for the distributed convolutions
...
@@ -109,7 +109,7 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
     return vals

-def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, eps=1e-9):
+def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
     """
     Discretely normalizes the convolution tensor.
     """
@@ -121,13 +121,17 @@ def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalizati
         # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
         # look at the normalization code in the actual implementation
         psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:,:,:1], dim=(1, 4), keepdim=True) / scale_factor
+        if merge_quadrature:
+            psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi
     else:
         psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True)
+        if merge_quadrature:
+            psi = quad_weights.reshape(1, 1, 1, -1, 1) * psi

     return psi / (psi_norm + eps)

-def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False):
+def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False):
     """
     Helper routine to compute the convolution Tensor in a dense fashion
     """
@@ -187,7 +191,7 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad
             out[:, t, p, :, :] = kernel_handle(theta, phi)

     # take care of normalization
-    out = _normalize_convolution_tensor_dense(out, quad_weights=quad_weights, transpose_normalization=transpose_normalization)
+    out = _normalize_convolution_tensor_dense(out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature)

     return out
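As a quick sanity check of the non-transpose branch above (a hedged re-implementation, reusing `psi` and `w` from the earlier sketch): once the quadrature weights are merged, every (kernel, theta_out, phi_out) slice of the fused tensor sums to approximately one over the input grid, which is exactly what the normalization enforces.

```python
eps = 1e-9

# Non-transpose branch of _normalize_convolution_tensor_dense with merge_quadrature=True,
# restated standalone for illustration.
psi_norm = torch.sum(w.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True)
psi_fused = w.reshape(1, 1, 1, -1, 1) * psi / (psi_norm + eps)

# The discrete integral over the input grid is ~1 for every kernel basis function
# and every output location.
print(torch.allclose(psi_fused.sum(dim=(3, 4)), torch.ones(K, nlat_out, nlon_out), atol=1e-5))  # True
```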
@@ -263,13 +267,13 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in

         if transpose:
-            psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True).to(self.device)
+            psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True).to(self.device)

             psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()

             self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)))
         else:
-            psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False).to(self.device)
+            psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True).to(self.device)

             psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
@@ -296,9 +300,9 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         x_ref.requires_grad = True

         if transpose:
             y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
-            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
+            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref)
         else:
-            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
+            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref)
             y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)

         y_ref.backward(grad_input)
         x_ref_grad = x_ref.grad.clone()
...
@@ -140,7 +140,7 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in
     return iidx, vals

-def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=False, eps=1e-9):
+def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
     """
     Discretely normalizes the convolution tensor.
     """
@@ -167,6 +167,10 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
                 iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
                 # normalize, while summing also over the input longitude dimension here as this is not available for the output
                 vnorm = torch.sum(psi_vals[iidx] * q[iidx])
+                if merge_quadrature:
+                    # the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
+                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] * nlon_in / nlon_out / (vnorm + eps)
+                else:
                     psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
     else:
         # pre-compute the quadrature weights
@@ -179,13 +183,16 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
                 iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
                 # normalize
                 vnorm = torch.sum(psi_vals[iidx] * q[iidx])
+                if merge_quadrature:
+                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (vnorm + eps)
+                else:
                     psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)

     return psi_vals

 def _precompute_convolution_tensor_s2(
-    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
+    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
 ):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
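Written out (a restatement of the normalization code above, with the index bookkeeping suppressed), merging the quadrature means each stored nonzero entry $\psi_m$ with quadrature weight $q_m$ becomes

$$\tilde{\psi}_m = \frac{q_m \, \psi_m}{\sum_{m'} q_{m'} \, \psi_{m'} + \varepsilon}$$

in the forward normalization, and

$$\tilde{\psi}_m = \frac{n_{\mathrm{lon}}^{\mathrm{in}}}{n_{\mathrm{lon}}^{\mathrm{out}}} \cdot \frac{q_m \, \psi_m}{\sum_{m'} q_{m'} \, \psi_{m'} + \varepsilon}$$

in the transpose normalization, where the sums run over the nonzero entries sharing the same kernel basis function and latitude, and the extra factor accounts for the differing number of longitude points when the input is upscaled. With the weights absorbed this way, the sparse contraction in the forward pass needs no separate quadrature multiplication.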
@@ -269,7 +276,9 @@ def _precompute_convolution_tensor_s2(
         quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
     else:
         quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
-    out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization)
+    out_vals = _normalize_convolution_tensor_s2(
+        out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
+    )

     return out_idx, out_vals
@@ -359,13 +368,8 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")

-        # integration weights
-        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
-        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
-        self.register_buffer("quad_weights", quad_weights, persistent=False)
-
         idx, vals = _precompute_convolution_tensor_s2(
-            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
+            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
         )

         # sort the values
@@ -381,6 +385,12 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         self.register_buffer("psi_col_idx", col_idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)

+    def extra_repr(self):
+        r"""
+        Pretty print module
+        """
+        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
+
     @property
     def psi_idx(self):
         return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
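The new extra_repr hooks into PyTorch's standard nn.Module printing. As a rough usage example (the constructor arguments below are hypothetical and may not match the exact signature; the printed line follows the f-string added above):

```python
# Hypothetical construction, for illustration of the repr only.
conv = DiscreteContinuousConvS2(in_channels=4, out_channels=8, in_shape=(32, 64), out_shape=(32, 64), kernel_shape=[3, 2])
print(conv)
# DiscreteContinuousConvS2(in_shape=(32, 64), out_shape=(32, 64), in_chans=4, out_chans=8, kernel_shape=[3, 2], groups=1)
```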
@@ -390,8 +400,6 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         return psi

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # pre-multiply x with the quadrature weights
-        x = self.quad_weights * x

         if x.is_cuda and _cuda_extension_available:
             x = _disco_s2_contraction_cuda(
@@ -408,7 +416,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
         x = x.reshape(B, self.groups, self.groupsize, K, H, W)

         # do weight multiplication
-        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
+        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
         out = out.reshape(B, -1, H, W)

         if self.bias is not None:
@@ -449,14 +457,9 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")

-        # integration weights
-        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
-        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
-        self.register_buffer("quad_weights", quad_weights, persistent=False)
-
         # switch in_shape and out_shape since we want transpose conv
         idx, vals = _precompute_convolution_tensor_s2(
-            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
+            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
         )

         # sort the values
@@ -472,6 +475,12 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         self.register_buffer("psi_col_idx", col_idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)

+    def extra_repr(self):
+        r"""
+        Pretty print module
+        """
+        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
+
     @property
     def psi_idx(self):
         return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
@@ -497,12 +506,9 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         x = x.reshape(B, self.groups, self.groupsize, H, W)

         # do weight multiplication
-        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
+        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2]))
         x = x.reshape(B, -1, x.shape[-3], H, W)

-        # pre-multiply x with the quadrature weights
-        x = self.quad_weights * x
-
         if x.is_cuda and _cuda_extension_available:
             out = _disco_s2_transpose_contraction_cuda(
                 x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
...
@@ -71,7 +71,7 @@ except ImportError as err:
 def _precompute_distributed_convolution_tensor_s2(
-    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
+    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
 ):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -156,7 +156,7 @@ def _precompute_distributed_convolution_tensor_s2(
         quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
     else:
         quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
-    out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization)
+    out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature)

     # TODO: this part can be split off into it's own function
     # split the latitude indices:
@@ -224,10 +224,6 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")

-        # integration weights
-        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
-        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / float(self.nlon_in)
-
         # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution,
         # we will keep local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
         # of atomic reduction calls inside the actual kernel
@@ -236,13 +232,9 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
         self.nlat_out_local = self.nlat_out

         idx, vals = _precompute_distributed_convolution_tensor_s2(
-            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
+            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
         )

-        # split the weight tensor as well
-        quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
-        self.register_buffer("quad_weights", quad_weights, persistent=False)
-
         # sort the values
         ker_idx = idx[0, ...].contiguous()
         row_idx = idx[1, ...].contiguous()
@@ -256,6 +248,12 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         self.register_buffer("psi_col_idx", col_idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)

+    def extra_repr(self):
+        r"""
+        Pretty print module
+        """
+        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
+
     @property
     def psi_idx(self):
         return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
@@ -273,9 +271,6 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

-        # pre-multiply x with the quadrature weights
-        x = self.quad_weights * x
-
         if x.is_cuda and _cuda_extension_available:
             x = _disco_s2_contraction_cuda(
                 x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
@@ -355,10 +350,6 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         if theta_cutoff <= 0.0:
             raise ValueError("Error, theta_cutoff has to be positive.")

-        # integration weights
-        _, wgl = _precompute_latitudes(self.nlat_in, grid=grid_in)
-        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
-
         # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution,
         # we will keep local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
         # of atomic reduction calls inside the actual kernel
@@ -370,13 +361,9 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         # switch in_shape and out_shape since we want transpose conv
         # distributed mode here is swapped because of the transpose
         idx, vals = _precompute_distributed_convolution_tensor_s2(
-            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
+            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
         )

-        # split the weight tensor as well
-        quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
-        self.register_buffer("quad_weights", quad_weights, persistent=False)
-
         # sort the values
         ker_idx = idx[0, ...].contiguous()
         row_idx = idx[1, ...].contiguous()
@@ -390,6 +377,12 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         self.register_buffer("psi_col_idx", col_idx, persistent=False)
         self.register_buffer("psi_vals", vals, persistent=False)

+    def extra_repr(self):
+        r"""
+        Pretty print module
+        """
+        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"
+
     @property
     def psi_idx(self):
         return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
@@ -424,9 +417,6 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

-        # multiply weights
-        x = self.quad_weights * x
-
         # gather input tensor and set up backward reduction hooks
         x = gather_from_polar_region(x, -2, self.lat_in_shapes)
         x = copy_to_polar_region(x)
...