Unverified Commit 942aa4ea authored by Boris Bonev, committed by GitHub

Bbonev/disco refactor (#27)

* Moved convolutions and exposed them directly

* Added transposition to the unit test

* Minor bugfix in CPU version of DISCO transpose code

* Adding convolution tests to CI

* Added gradient check

* Checking the weight grad as well

* Added test for anisotropic kernels
parent 1e5f7a2f
@@ -23,4 +23,4 @@ jobs:
      - name: Test with pytest
        run: |
          python -m pip install pytest pytest-cov parameterized
          python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py
\ No newline at end of file
          python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py ./tests/test_convolution.py
\ No newline at end of file
@@ -2,10 +2,19 @@
## Versioning
### v0.6.5
* Discrete-continuous (DISCO) convolutions on the sphere
* Isotropic and anisotropic DISCO convolutions
* Accelerated DISCO convolutions on GPU via Triton implementation
* Unittests for DISCO convolutions
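
A minimal usage sketch of the newly exposed modules (illustrative only; the arguments mirror the unit tests added in this release, and `DiscreteContinuousConvTransposeS2` takes the same signature):

```python
import torch
from torch_harmonics import DiscreteContinuousConvS2

# isotropic DISCO convolution from a 16x32 to an 8x16 equiangular grid
conv = DiscreteContinuousConvS2(
    4,          # in_channels
    2,          # out_channels
    (16, 32),   # input grid (nlat, nlon)
    (8, 16),    # output grid (nlat, nlon)
    [3],        # kernel_shape: [ntheta] is isotropic, [ntheta, nphi] anisotropic
    groups=1,
    grid_in="equiangular",
    grid_out="equiangular",
    bias=False,
)

x = torch.randn(8, 4, 16, 32)  # (batch, channels, nlat, nlon)
y = conv(x)                    # -> (8, 2, 8, 16)
```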
### v0.6.4
* reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors
* distributed SHT tests are now using unittest. Test extended to vector SHT versions. Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
* base pytorch container version bumped up to 23.11 in Dockerfile
* Reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors
* Distributed SHT tests are now using unittest. Test extended to vector SHT versions
* Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
* Base pytorch container version bumped up to 23.11 in Dockerfile
### v0.6.3
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import unittest
from parameterized import parameterized
from functools import partial
import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
    """
    Helper routine to compute the values of the isotropic kernel densely.
    """
    # compute the support
    dtheta = (theta_cutoff - 0.0) / ntheta
    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
    itheta = ikernel * dtheta

    norm_factor = (
        2
        * math.pi
        * (
            1
            - math.cos(theta_cutoff - dtheta)
            + math.cos(theta_cutoff - dtheta)
            + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta
        )
    )

    vals = torch.where(
        ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff),
        (1 - (theta - itheta).abs() / dtheta) / norm_factor,
        0,
    )
    return vals
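
# (illustrative example, not part of the original test file) the isotropic kernel
# is a bank of normalized hat functions in theta: with ntheta = 2 and
# theta_cutoff = pi / 4, dtheta = pi / 8 and kernel function k peaks at
# theta = k * dtheta, decaying linearly to zero within +/- dtheta:
#
#   theta = torch.linspace(0, math.pi, 9).reshape(1, -1)
#   phi = torch.zeros_like(theta)
#   vals = _compute_vals_isotropic(theta, phi, ntheta=2, theta_cutoff=math.pi / 4)
#   vals.shape  # (2, 1, 9), one slice per kernel function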
def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
    """
    Helper routine to compute the values of the anisotropic kernel densely.
    """
    # compute the support
    dtheta = (theta_cutoff - 0.0) / ntheta
    dphi = 2.0 * math.pi / nphi
    kernel_size = (ntheta - 1) * nphi + 1
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    itheta = ((ikernel - 1) // nphi + 1) * dtheta
    iphi = ((ikernel - 1) % nphi) * dphi

    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)

    # find the indices where the rotated position falls into the support of the kernel
    cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
    cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
    theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0)
    phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
    vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals)
    return vals
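
# (illustrative note, not part of the original test file) the anisotropic kernel
# flattens (ntheta, nphi) into kernel_size = (ntheta - 1) * nphi + 1 basis
# functions. Index 0 is the single center function at theta = 0, isotropic in
# phi (see the torch.where(ikernel > 0, ...) above); the remaining indices form
# rings. For ntheta = 2, nphi = 3 (kernel_size = 4):
#   ikernel = 0 -> itheta = 0,      (center, phi ignored)
#   ikernel = 1 -> itheta = dtheta, iphi = 0
#   ikernel = 2 -> itheta = dtheta, iphi = dphi
#   ikernel = 3 -> itheta = dtheta, iphi = 2 * dphi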
def _precompute_convolution_tensor_dense(
    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
):
    """
    Helper routine to compute the convolution tensor in a dense fashion.
    """
    assert len(in_shape) == 2
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
        kernel_size = kernel_shape[0]
    elif len(kernel_shape) == 2:
        kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
        kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, _ = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, _ = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences. We need to make the linspace exclusive to not double the last point
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
    lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]

    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)

    for t in range(nlat_out):
        for p in range(nlon_out):
            alpha = -lats_out[t]
            beta = lons_in - lons_out[p]
            gamma = lats_in.reshape(-1, 1)

            # compute latitude of the rotated position
            z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)

            # compute cartesian coordinates of the rotated position
            x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
            y = torch.sin(beta) * torch.sin(gamma)

            # normalize instead of clipping to ensure correct range
            norm = torch.sqrt(x * x + y * y + z * z)
            x = x / norm
            y = y / norm
            z = z / norm

            # compute spherical coordinates
            theta = torch.arccos(z)
            phi = torch.arctan2(y, x) + torch.pi

            # find the indices where the rotated position falls into the support of the kernel
            out[:, t, p, :, :] = kernel_handle(theta, phi)

    return out
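
# (illustrative example, not part of the original test file) the dense tensor is
# indexed as (kernel, lat_out, lon_out, lat_in, lon_in); for instance:
#
#   psi = _precompute_convolution_tensor_dense((8, 16), (4, 8), [2], theta_cutoff=math.pi / 4)
#   psi.shape  # (2, 4, 8, 8, 16)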
class TestDiscreteContinuousConvolution(unittest.TestCase):
    def setUp(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # note: the device selected above is overridden here; the tests currently run on CPU
        self.device = torch.device("cpu")
    @parameterized.expand(
        [
            # regular convolution
            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 1e-5],
            # transpose convolution
            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 1e-5],
        ]
    )
    def test_disco_convolution(
        self,
        batch_size,
        in_channels,
        out_channels,
        in_shape,
        out_shape,
        kernel_shape,
        grid_in,
        grid_out,
        transpose,
        tol,
    ):
        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        conv = Conv(
            in_channels,
            out_channels,
            in_shape,
            out_shape,
            kernel_shape,
            groups=1,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=False,
        ).to(self.device)

        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

        if transpose:
            psi_dense = _precompute_convolution_tensor_dense(
                out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
            ).to(self.device)
        else:
            psi_dense = _precompute_convolution_tensor_dense(
                in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
            ).to(self.device)

        self.assertTrue(
            torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
        )

        # create a copy of the weight
        w_ref = conv.weight.detach().clone()
        w_ref.requires_grad_(True)

        # create an input signal
        torch.manual_seed(333)
        x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)

        # perform the reference computation
        x_ref = x.clone().detach()
        x_ref.requires_grad_(True)
        if transpose:
            y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
        else:
            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
            y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
        # use the convolution module
        y = conv(x)

        # compare results
        self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))

        # compute gradients and compare results
        grad_input = torch.randn_like(y)
        y_ref.backward(grad_input)
        y.backward(grad_input)

        # compare
        self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol))
        self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))
if __name__ == "__main__":
    unittest.main()
@@ -149,9 +149,9 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
        coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
        signal = isht(coeffs)

        input = torch.randn_like(signal, requires_grad=True)
        grad_input = torch.randn_like(signal, requires_grad=True)
        err_handle = lambda x: torch.mean(torch.norm(isht(sht(x)) - signal, p='fro', dim=(-1, -2)) / torch.norm(signal, p='fro', dim=(-1, -2)))
        test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol)
        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)
@@ -32,8 +32,7 @@
__version__ = '0.6.4'
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from . import quadrature
from . import s2_convolutions
from . import disco_convolutions
from . import random_fields
from . import examples
@@ -377,7 +377,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
    assert nlon_in >= nlat_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
@@ -414,7 +414,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
    assert psi.shape[-2] == nlat_in
    assert n_out % nlon_out == 0
    nlat_out = n_out // nlon_out
    assert nlon_out >= nlat_in
    pscale = nlon_out // nlon_in

    # we do a semi-transposition to facilitate the computation
@@ -429,7 +429,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
    x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
    x_ext[:, :, (pscale-1)::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
    x_ext[:, :, ::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
    # we need to go backwards through the vector, so we flip the axis
    x_ext = x_ext.contiguous()
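    # (illustrative note, not part of the original source) with pscale = 2 the
    # interleaving above upsamples along longitude: each input column is placed
    # at every pscale-th output column and the remaining columns stay zero, e.g.
    # for one longitude row
    #   x     = [a, b, c]
    #   x_ext = [a, 0, b, 0, c, 0]
    # which lets the transpose contraction account for fractional longitude offsets.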
@@ -39,7 +39,7 @@ import torch.nn as nn
from functools import partial
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.disco_convolutions import (
from torch_harmonics._disco_convolution import (
    _disco_s2_contraction_torch,
    _disco_s2_transpose_contraction_torch,
    _disco_s2_contraction_triton,
@@ -47,14 +47,14 @@ from torch_harmonics.disco_convolutions import (
)
def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kernel_size: int, theta_cutoff: float):
def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
    """
    Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
    """
    # compute the support
    dtheta = (theta_cutoff - 0.0) / kernel_size
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    dtheta = (theta_cutoff - 0.0) / ntheta
    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
    itheta = ikernel * dtheta

    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
@@ -64,6 +64,29 @@ def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kern
    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
    return iidx, vals
def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
    """
    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
    """
    # compute the support
    dtheta = (theta_cutoff - 0.0) / ntheta
    dphi = 2.0 * math.pi / nphi
    kernel_size = (ntheta - 1) * nphi + 1
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    itheta = ((ikernel - 1) // nphi + 1) * dtheta
    iphi = ((ikernel - 1) % nphi) * dphi

    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)

    # find the indices where the rotated position falls into the support of the kernel
    cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
    cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
    iidx = torch.argwhere(cond_theta & cond_phi)
    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
    vals *= torch.where(
        iidx[:, 0] > 0,
        (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi),
        1.0,
    )
    return iidx, vals
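
# (illustrative sketch, not part of the original source) the (iidx, vals) pairs
# returned above are assembled into a sparse convolution tensor; assuming the
# index columns are (kernel index, output latitude, flattened input grid), a
# minimal construction would look like:
#
#   psi = torch.sparse_coo_tensor(iidx.T, vals, size=(kernel_size, nlat_out, nlat_in * nlon_in)).coalesce()
#   psi.to_dense()  # this is the layout the unit tests compare against the dense reference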
def _precompute_convolution_tensor(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
@@ -88,7 +111,9 @@ def _precompute_convolution_tensor(
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_support_vals_isotropic, kernel_size=kernel_shape[0], theta_cutoff=theta_cutoff)
        kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
    elif len(kernel_shape) == 2:
        kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")
@@ -128,9 +153,9 @@ def _precompute_convolution_tensor(
        y = y / norm
        z = z / norm

        # compute spherical coordinates
        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)
@@ -146,8 +171,7 @@
# TODO:
# - parameter initialization
# - add anisotropy
# - derive conv and conv transpose from single module
class DiscreteContinuousConvS2(nn.Module):
"""
Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
@@ -175,9 +199,13 @@ class DiscreteContinuousConvS2(nn.Module):
        if isinstance(kernel_shape, int):
            kernel_shape = [kernel_shape]

        self.kernel_size = 1
        for kdim in kernel_shape:
            self.kernel_size *= kdim
        if len(kernel_shape) == 1:
            self.kernel_size = kernel_shape[0]
        elif len(kernel_shape) == 2:
            self.kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")
        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
@@ -209,7 +237,7 @@ class DiscreteContinuousConvS2(nn.Module):
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
@@ -266,9 +294,12 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
        if isinstance(kernel_shape, int):
            kernel_shape = [kernel_shape]

        self.kernel_size = 1
        for kdim in kernel_shape:
            self.kernel_size *= kdim
        if len(kernel_shape) == 1:
            self.kernel_size = kernel_shape[0]
        elif len(kernel_shape) == 2:
            self.kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
        else:
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # bandlimit
        if theta_cutoff is None:
@@ -301,7 +332,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
@@ -310,7 +341,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
    def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
        # extract shape
        B, F, H, W = x.shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication