Unverified Commit 942aa4ea authored by Boris Bonev, committed by GitHub

Bbonev/disco refactor (#27)

* Moved convolutions and exposed them directly

* Added transposition to the unit test

* Minor bugfix in CPU version of DISCO transpose code

* Adding convolution tests to CI

* Added gradient check

* Checking the weight grad as well

* Added test for anisotropic kernels
parent 1e5f7a2f
@@ -23,4 +23,4 @@ jobs:
       - name: Test with pytest
         run: |
           python -m pip install pytest pytest-cov parameterized
-          python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py
+          python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py ./tests/test_convolution.py
\ No newline at end of file
@@ -2,10 +2,19 @@
 ## Versioning

+### v0.6.5
+
+* Discrete-continuous (DISCO) convolutions on the sphere
+* Isotropic and anisotropic DISCO convolutions
+* Accelerated DISCO convolutions on GPU via Triton implementation
+* Unittests for DISCO convolutions
+
 ### v0.6.4
-* reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors
-* distributed SHT tests are now using unittest. Test extended to vector SHT versions. Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
-* base pytorch container version bumped up to 23.11 in Dockerfile
+
+* Reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors
+* Distributed SHT tests are now using unittest. Test extended to vector SHT versions
+* Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
+* Base pytorch container version bumped up to 23.11 in Dockerfile

 ### v0.6.3
...
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import unittest
from parameterized import parameterized
from functools import partial
import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
"""
helper routine to compute the values of the isotropic kernel densely
"""
# compute the support
dtheta = (theta_cutoff - 0.0) / ntheta
ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
itheta = ikernel * dtheta
norm_factor = (
2
* math.pi
* (
1
- math.cos(theta_cutoff - dtheta)
+ math.cos(theta_cutoff - dtheta)
+ (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta
)
)
vals = torch.where(
((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff),
(1 - (theta - itheta).abs() / dtheta) / norm_factor,
0,
)
return vals
def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
"""
helper routine to compute the values of the anisotropic kernel densely
"""
# compute the support
dtheta = (theta_cutoff - 0.0) / ntheta
dphi = 2.0 * math.pi / nphi
kernel_size = (ntheta-1)*nphi + 1
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
itheta = ((ikernel - 1) // nphi + 1) * dtheta
iphi = ((ikernel - 1) % nphi) * dphi
norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
# find the indices where the rotated position falls into the support of the kernel
cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi)
theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2*math.pi - (phi - iphi).abs()) ) / dphi ), 0.0)
vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals)
return vals
def _precompute_convolution_tensor_dense(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
):
"""
Helper routine to compute the convolution Tensor in a dense fashion
"""
assert len(in_shape) == 2
assert len(out_shape) == 2
if len(kernel_shape) == 1:
kernel_handle = partial(_compute_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
kernel_size = kernel_shape[0]
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
lats_in, _ = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, _ = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()
# compute the phi differences. We need to make the linspace exclusive to not double the last point
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
for t in range(nlat_out):
for p in range(nlon_out):
alpha = -lats_out[t]
beta = lons_in - lons_out[p]
gamma = lats_in.reshape(-1, 1)
# compute latitude of the rotated position
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
# compute cartesian coordinates of the rotated position
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma)
# normalize instead of clipping to ensure correct range
norm = torch.sqrt(x * x + y * y + z * z)
x = x / norm
y = y / norm
z = z / norm
# compute spherical coordinates
theta = torch.arccos(z)
phi = torch.arctan2(y, x) + torch.pi
# find the indices where the rotated position falls into the support of the kernel
out[:, t, p, :, :] = kernel_handle(theta, phi)
return out
class TestDiscreteContinuousConvolution(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.device = torch.device("cpu")
@parameterized.expand(
[
# regular convolution
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 1e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 1e-5],
# transpose convolution
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 1e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 1e-5],
]
)
def test_disco_convolution(
self,
batch_size,
in_channels,
out_channels,
in_shape,
out_shape,
kernel_shape,
grid_in,
grid_out,
transpose,
tol,
):
Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv(
in_channels,
out_channels,
in_shape,
out_shape,
kernel_shape,
groups=1,
grid_in=grid_in,
grid_out=grid_out,
bias=False,
).to(self.device)
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
if transpose:
psi_dense = _precompute_convolution_tensor_dense(
out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
).to(self.device)
else:
psi_dense = _precompute_convolution_tensor_dense(
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
).to(self.device)
self.assertTrue(
torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
)
# create a copy of the weight
w_ref = conv.weight.detach().clone()
w_ref.requires_grad_(True)
# create an input signal
torch.manual_seed(333)
x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)
# perform the reference computation
x_ref = x.clone().detach()
x_ref.requires_grad_(True)
if transpose:
y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
else:
y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
# use the convolution module
y = conv(x)
# compare results
self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))
# compute gradients and compare results
grad_input = torch.randn_like(y)
y_ref.backward(grad_input)
y.backward(grad_input)
# compare
self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol))
self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))
if __name__ == "__main__":
unittest.main()
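
For reference, here is a small standalone sketch of how the anisotropic kernel indices used by the dense helpers above map to (theta, phi) collocation points; the kernel_shape of [2, 3] and the cutoff are arbitrary illustration values and not part of the commit:

# Sketch only: reproduces the index layout encoded in _compute_vals_anisotropic above.
import math
import torch

ntheta, nphi = 2, 3                            # hypothetical kernel_shape = [2, 3]
theta_cutoff = 0.1 * math.pi                   # hypothetical cutoff
dtheta = theta_cutoff / ntheta
dphi = 2.0 * math.pi / nphi

kernel_size = (ntheta - 1) * nphi + 1          # = 4: one center function + (ntheta-1) rings of nphi sectors
ikernel = torch.arange(kernel_size)
itheta = ((ikernel - 1) // nphi + 1) * dtheta  # ring colatitude of each basis function
iphi = ((ikernel - 1) % nphi) * dphi           # sector angle (ignored for the center, index 0)

for k in range(kernel_size):
    print(k, float(itheta[k]), float(iphi[k]))
# index 0 sits at theta = 0 (the isotropic center); indices 1..3 sit on the ring at dtheta,
# at phi = 0, 2*pi/3 and 4*pi/3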
@@ -149,9 +149,9 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
         coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
         signal = isht(coeffs)

-        input = torch.randn_like(signal, requires_grad=True)
+        grad_input = torch.randn_like(signal, requires_grad=True)
         err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
-        test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol)
+        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
         self.assertTrue(test_result)
...
@@ -32,8 +32,7 @@
 __version__ = '0.6.4'

 from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
+from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
 from . import quadrature
-from . import s2_convolutions
-from . import disco_convolutions
 from . import random_fields
 from . import examples
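
With the convolution modules now re-exported at the package root (see the __init__.py hunk above), a minimal usage sketch could look as follows; the constructor arguments and tensor sizes simply mirror those exercised in the new unit tests and are not an official example from this commit:

# Usage sketch, assuming this commit is installed; argument order follows tests/test_convolution.py.
import torch
from torch_harmonics import DiscreteContinuousConvS2

# 4 input channels, 2 output channels, equiangular (16, 32) -> (8, 16) grids,
# anisotropic kernel with 2 theta rings and 3 phi sectors
conv = DiscreteContinuousConvS2(
    4, 2, (16, 32), (8, 16), [2, 3],
    groups=1, grid_in="equiangular", grid_out="equiangular", bias=False,
)

x = torch.randn(8, 4, 16, 32)   # (batch, channels, nlat, nlon)
y = conv(x)                     # expected shape: (8, 2, 8, 16)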
@@ -377,7 +377,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
     assert psi.shape[-1] == nlat_in * nlon_in
     assert nlon_in % nlon_out == 0
+    assert nlon_in >= nlat_out
     pscale = nlon_in // nlon_out

     # add a dummy dimension for nkernel and move the batch and channel dims to the end
@@ -414,7 +414,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
     assert psi.shape[-2] == nlat_in
     assert n_out % nlon_out == 0
     nlat_out = n_out // nlon_out
+    assert nlon_out >= nlat_in
     pscale = nlon_out // nlon_in

     # we do a semi-transposition to facilitate the computation
@@ -429,7 +429,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
     # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
     x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
-    x_ext[:, :, (pscale-1)::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
+    x_ext[:, :, ::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)

     # we need to go backwards through the vector, so we flip the axis
     x_ext = x_ext.contiguous()
...
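
To see what the CPU bugfix in the last hunk changes: the transpose contraction upsamples along longitude by interleaving zeros, and the coarse samples are now written at longitudes 0, pscale, 2*pscale, ... instead of pscale-1, 2*pscale-1, .... A toy sketch with made-up sizes:

# Sketch of the interleaving fix; nlon values are illustrative only.
import torch

nlon_in, nlon_out = 4, 8
pscale = nlon_out // nlon_in           # = 2
x = torch.arange(1.0, nlon_in + 1.0)   # coarse longitude samples [1, 2, 3, 4]

x_old = torch.zeros(nlon_out)
x_old[(pscale - 1)::pscale] = x        # before the fix: [0, 1, 0, 2, 0, 3, 0, 4]

x_new = torch.zeros(nlon_out)
x_new[::pscale] = x                    # after the fix:  [1, 0, 2, 0, 3, 0, 4, 0]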
@@ -39,7 +39,7 @@ import torch.nn as nn
 from functools import partial

 from torch_harmonics.quadrature import _precompute_latitudes
-from torch_harmonics.disco_convolutions import (
+from torch_harmonics._disco_convolution import (
     _disco_s2_contraction_torch,
     _disco_s2_transpose_contraction_torch,
     _disco_s2_contraction_triton,
@@ -47,14 +47,14 @@ from torch_harmonics.disco_convolutions import (
 )


-def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kernel_size: int, theta_cutoff: float):
+def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
     """
     Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
     """

     # compute the support
-    dtheta = (theta_cutoff - 0.0) / kernel_size
-    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
+    dtheta = (theta_cutoff - 0.0) / ntheta
+    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
     itheta = ikernel * dtheta

     norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
@@ -64,6 +64,29 @@ def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kern
     vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
     return iidx, vals


+def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
+    """
+    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
+    """
+
+    # compute the support
+    dtheta = (theta_cutoff - 0.0) / ntheta
+    dphi = 2.0 * math.pi / nphi
+    kernel_size = (ntheta-1)*nphi + 1
+    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
+    itheta = ((ikernel - 1) // nphi + 1) * dtheta
+    iphi = ((ikernel - 1) % nphi) * dphi
+
+    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
+
+    # find the indices where the rotated position falls into the support of the kernel
+    cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
+    cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi)
+    iidx = torch.argwhere(cond_theta & cond_phi)
+    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
+    vals *= torch.where(iidx[:, 0] > 0, (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2*math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()) ) / dphi ), 1.0)
+    return iidx, vals
+
+
 def _precompute_convolution_tensor(
     in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
@@ -88,7 +111,9 @@ def _precompute_convolution_tensor(
     assert len(out_shape) == 2

     if len(kernel_shape) == 1:
-        kernel_handle = partial(_compute_support_vals_isotropic, kernel_size=kernel_shape[0], theta_cutoff=theta_cutoff)
+        kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
+    elif len(kernel_shape) == 2:
+        kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
     else:
         raise ValueError("kernel_shape should be either one- or two-dimensional.")
@@ -128,9 +153,9 @@ def _precompute_convolution_tensor(
     y = y / norm
     z = z / norm

-    # compute spherical coordinates
+    # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
     theta = torch.arccos(z)
-    phi = torch.arctan2(y, x)
+    phi = torch.arctan2(y, x) + torch.pi

     # find the indices where the rotated position falls into the support of the kernel
     iidx, vals = kernel_handle(theta, phi)
@@ -146,8 +171,7 @@ def _precompute_convolution_tensor(
 # TODO:
-# - parameter initialization
-# - add anisotropy
+# - derive conv and conv transpose from single module
 class DiscreteContinuousConvS2(nn.Module):
     """
     Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
@@ -175,9 +199,13 @@ class DiscreteContinuousConvS2(nn.Module):
         if isinstance(kernel_shape, int):
             kernel_shape = [kernel_shape]

-        self.kernel_size = 1
-        for kdim in kernel_shape:
-            self.kernel_size *= kdim
+        if len(kernel_shape) == 1:
+            self.kernel_size = kernel_shape[0]
+        elif len(kernel_shape) == 2:
+            self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
+        else:
+            raise ValueError("kernel_shape should be either one- or two-dimensional.")

         # compute theta cutoff based on the bandlimit of the input field
         if theta_cutoff is None:
@@ -209,7 +237,7 @@ class DiscreteContinuousConvS2(nn.Module):
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
         scale = math.sqrt(1.0 / self.groupsize)
-        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

         if bias:
             self.bias = nn.Parameter(torch.zeros(out_channels))
@@ -266,9 +294,12 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
         if isinstance(kernel_shape, int):
             kernel_shape = [kernel_shape]

-        self.kernel_size = 1
-        for kdim in kernel_shape:
-            self.kernel_size *= kdim
+        if len(kernel_shape) == 1:
+            self.kernel_size = kernel_shape[0]
+        elif len(kernel_shape) == 2:
+            self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
+        else:
+            raise ValueError("kernel_shape should be either one- or two-dimensional.")

         # bandlimit
         if theta_cutoff is None:
@@ -301,7 +332,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
         scale = math.sqrt(1.0 / self.groupsize)
-        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

         if bias:
             self.bias = nn.Parameter(torch.zeros(out_channels))
@@ -310,7 +341,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
     def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
         # extract shape
-        B, F, H, W = x.shape
+        B, C, H, W = x.shape
         x = x.reshape(B, self.groups, self.groupsize, H, W)

         # do weight multiplication
...
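
A note on the phi shift introduced in the precompute routine above: torch.arctan2 returns angles in (-pi, pi], so the added +pi moves the azimuth onto the branch expected by the phi-support conditions, which also include a 2*pi wrap-around term. A quick sanity sketch with illustrative inputs:

# Sketch: the +pi shift moves arctan2 output from (-pi, pi] onto the 0..2*pi branch.
import torch

y = torch.tensor([0.0, 1.0, 0.0, -1.0])
x = torch.tensor([1.0, 0.0, -1.0, 0.0])
phi = torch.arctan2(y, x) + torch.pi   # -> [pi, 3*pi/2, 2*pi, pi/2]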