Unverified commit 29e7fb68, authored by Boris Bonev and committed by GitHub

Tkurth/cuda disco (#38)



* adding cuda kernels for disco conv

* making psi_idx an attribute

* adding license headers

* adding author files

* reorganizing files

* draft implementation

* added conditional installation to setup.py

* formatting changes

* removing triton kernel in DISCO convolution

* updated github actions

* updated Readme and changelog

* adding another guard for the cuda installation

* renaming the cuda extension

* simplifying setup.py

* minor bugfix

* Bbonev/cuda disco cleanup (#32)

* cleanup of disco convolutions based on CUDA extension

* fixing unittest

* changing version to experimental 0.7.0a

* initial rewrite of the distributed convolution with CUDA

* fixing streams

* need to fix install options

* fixing streams

* undid setup.py changes

* reset setup.py

* including CUDAStream

* adjusted the precomputation of theta_cutoff. If you rely on this, your models will not be backwards-compatible.

* adjusting theta_cutoff in the unittest

* adding newly refactored kernels for faster compile

* Tkurth/cuda disco distributed fix (#34)

* attempt to make disco distributed

* working distributed convolutions

* fixing distributed conv

* working distributed disco

* removing irrelevant extra argument

* using stream functions from at instead of c10

* using stream functions from at instead of c10, small fix

* Bbonev/disc even filters (#35)

* initial working commit with new convention of counting collocation points across the diameter instead of across the radius

* fixed a bug in the computation of the even kernels

* changing heuristic for computing theta_cutoff

* Fixing unittest

* Readability improvements

* reworked normalization of filter basis functions

* implemented discrete normalization of disco filters

* relaxing tolerances in convolution unit test

* bugfix to correctly support unequal scale factors in latitudes and longitudes

* hotfix to a bug in the imports

* Bbonev/distributed disco refactor (#37)

* cleaned up normalization code in convolution

* formatting changes in distributed convolution

* Fixing default theta_cutoff to be the same in distributed and local case

* fixed distributed convolution to support the same normalization as non-distributed one

* readability improvements

* fixed initial scale of convolution parameter weights and fixed naming of the normalization routine

* Updated Readme.md

* added comment in Dockerfile regarding older architectures

---------
Co-authored-by: Thorsten Kurth <tkurth@nvidia.com>
Co-authored-by: Boris Bonev <bbonev@nvidia.com>
parent 214fa40a
...@@ -11,11 +11,11 @@ jobs: ...@@ -11,11 +11,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python 3.9 - name: Set up Python 3.9
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: 3.9 python-version: 'pypy3.9'
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
......
...@@ -8,11 +8,11 @@ jobs: ...@@ -8,11 +8,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python 3.9 - name: Set up Python 3.9
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: 3.9 python-version: '3.10'
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip setuptools wheel python -m pip install --upgrade pip setuptools wheel
......
...@@ -2,6 +2,8 @@ The code was authored by the following people: ...@@ -2,6 +2,8 @@ The code was authored by the following people:
Boris Bonev - NVIDIA Corporation Boris Bonev - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation Thorsten Kurth - NVIDIA Corporation
Mauro Bisson - NVIDIA Corporation
Massimiliano Fatica - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation Christian Hundt - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
Jean Kossaifi - NVIDIA Corporation Jean Kossaifi - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
\ No newline at end of file
...@@ -2,6 +2,14 @@ ...@@ -2,6 +2,14 @@
## Versioning ## Versioning
### v0.7.0
* CUDA-accelerated DISCO convolutions
* Updated DISCO convolutions to support even number of collocation points across the diameter
* Distributed DISCO convolutions
* Removed DISCO convolution in the plane to focus on the sphere
* Updated unit tests which now include tests for the distributed convolutions
### v0.6.5 ### v0.6.5
* Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions * Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions
......
...@@ -34,6 +34,10 @@ FROM nvcr.io/nvidia/pytorch:23.11-py3 ...@@ -34,6 +34,10 @@ FROM nvcr.io/nvidia/pytorch:23.11-py3
COPY . /workspace/torch_harmonics COPY . /workspace/torch_harmonics
# we need this for tests
RUN pip install parameterized RUN pip install parameterized
RUN pip install /workspace/torch_harmonics
# The custom CUDA extension does not support architectures < 7.0
ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
RUN pip install --global-option --cuda_ext /workspace/torch_harmonics
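For reference, a container built from this Dockerfile might be produced as follows; the image tag is illustrative, and the arch list and `--cuda_ext` install are handled inside the Dockerfile itself:

```bash
# build from the repository root, where the Dockerfile above lives
docker build -t torch_harmonics .
```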
...@@ -50,7 +50,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ...@@ -50,7 +50,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
## Overview ## Overview
torch-harmonics is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It was originally implemented to enable Spherical Fourier Neural Operators (SFNO). It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes. torch-harmonics implements differentiable signal processing on the sphere. This includes differentiable implementations of the spherical harmonic transforms, vector spherical harmonic transforms and discrete-continuous convolutions on the sphere. The package was originally implemented to enable Spherical Fourier Neural Operators (SFNO) [1].
The SHT algorithm uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes [2].
torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed. torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed.
...@@ -73,13 +75,18 @@ torch-harmonics has been used to implement a variety of differentiable PDE solve ...@@ -73,13 +75,18 @@ torch-harmonics has been used to implement a variety of differentiable PDE solve
## Installation ## Installation
Download directyly from PyPI: Download directly from PyPI:
```bash ```bash
pip install torch-harmonics pip install torch-harmonics
``` ```
If you would like to have accelerated CUDA extensions for the discrete-continuous convolutions, please use the '--cuda_ext' flag:
```bash
pip install --global-option --cuda_ext torch-harmonics
```
:warning: Please note that the custom CUDA extensions currently only support CUDA architectures >= 7.0.
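A quick way to check whether the local GPU satisfies this requirement is to query its compute capability; the following sketch uses only standard PyTorch calls:

```python
import torch

# the custom extension targets compute capability >= 7.0 (Volta and newer)
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability {major}.{minor}:", "supported" if major >= 7 else "not supported")
else:
    print("no CUDA device visible")
```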
Build in your environment using the Python package: If you want to actively develop torch-harmonics, we recommend building it in your environment from GitHub:
```bash ```bash
git clone git@github.com:NVIDIA/torch-harmonics.git git clone git@github.com:NVIDIA/torch-harmonics.git
...@@ -160,6 +167,10 @@ $$ ...@@ -160,6 +167,10 @@ $$
Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$. Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.
### Discrete-continuous convolutions
torch-harmonics now provides local discrete-continuous (DISCO) convolutions as outlined in [4] on the sphere.
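A minimal usage sketch of the spherical DISCO convolution module is given below. The argument names mirror those used in the unit tests further down in this diff; the concrete channel counts and grid sizes are purely illustrative:

```python
import torch
from torch_harmonics import DiscreteContinuousConvS2

nlat_in, nlon_in = 16, 32    # input grid (illustrative)
nlat_out, nlon_out = 8, 16   # output grid (illustrative)

conv = DiscreteContinuousConvS2(
    in_channels=8,
    out_channels=4,
    in_shape=(nlat_in, nlon_in),
    out_shape=(nlat_out, nlon_out),
    kernel_shape=[3, 3],     # collocation points across the diameter and in the angular direction
    grid_in="equiangular",
    grid_out="equiangular",
    bias=False,
)

x = torch.randn(2, 8, nlat_in, nlon_in)
y = conv(x)                  # expected shape: (2, 4, nlat_out, nlon_out)
```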
## Getting started ## Getting started
The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Modules` for composability. A minimum example is given by: The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Modules` for composability. A minimum example is given by:
...@@ -223,7 +234,7 @@ Depending on the problem, it might be beneficial to upcast data to `float64` ins ...@@ -223,7 +234,7 @@ Depending on the problem, it might be beneficial to upcast data to `float64` ins
## Contributors ## Contributors
[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Christian Hundt](https://github.com/gravitino) (chundt@nvidia.com), [Nikola Kovachki](https://kovachki.github.io) (nkovachki@nvidia.com), [Jean Kossaifi](http://jeankossaifi.com) (jkossaifi@nvidia.com) [Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ) , [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino)
## Cite us ## Cite us
...@@ -256,3 +267,6 @@ G3: Geochemistry, Geophysics, Geosystems, 2013. ...@@ -256,3 +267,6 @@ G3: Geochemistry, Geophysics, Geosystems, 2013.
Wang B., Wang L., Xie Z.; Wang B., Wang L., Xie Z.;
Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids;
Adv Comput Math, 2018. Adv Comput Math, 2018.
<a id="4">[4]</a>
Ocampo J., Price M.A., McEwen J.D.; Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions; ICLR (2023), arXiv:2209.13603.
...@@ -28,6 +28,9 @@ ...@@ -28,6 +28,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# #
import sys
try: try:
from setuptools import setup, find_packages from setuptools import setup, find_packages
except ImportError: except ImportError:
...@@ -36,6 +39,8 @@ except ImportError: ...@@ -36,6 +39,8 @@ except ImportError:
import re import re
from pathlib import Path from pathlib import Path
import torch
from torch.utils import cpp_extension
def version(root_path): def version(root_path):
"""Returns the version taken from __init__.py """Returns the version taken from __init__.py
...@@ -49,11 +54,10 @@ def version(root_path): ...@@ -49,11 +54,10 @@ def version(root_path):
--------- ---------
https://packaging.python.org/guides/single-sourcing-package-version/ https://packaging.python.org/guides/single-sourcing-package-version/
""" """
version_path = root_path.joinpath('torch_harmonics', '__init__.py') version_path = root_path.joinpath("torch_harmonics", "__init__.py")
with version_path.open() as f: with version_path.open() as f:
version_file = f.read() version_file = f.read()
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
version_file, re.M)
if version_match: if version_match:
return version_match.group(1) return version_match.group(1)
raise RuntimeError("Unable to find version string.") raise RuntimeError("Unable to find version string.")
...@@ -67,37 +71,64 @@ def readme(root_path): ...@@ -67,37 +71,64 @@ def readme(root_path):
root_path : pathlib.Path root_path : pathlib.Path
path to the root of the package path to the root of the package
""" """
with root_path.joinpath('README.md').open(encoding='UTF-8') as f: with root_path.joinpath("README.md").open(encoding="UTF-8") as f:
return f.read() return f.read()
def get_ext_modules(argv):
compile_cuda_extension = False
if "--cuda_ext" in sys.argv:
sys.argv.remove("--cuda_ext")
compile_cuda_extension = True
ext_modules = [
cpp_extension.CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]),
]
if torch.cuda.is_available() or compile_cuda_extension:
ext_modules.append(
cpp_extension.CUDAExtension(
"disco_cuda_extension",
[
"torch_harmonics/csrc/disco/disco_interface.cu",
"torch_harmonics/csrc/disco/disco_cuda_fwd.cu",
"torch_harmonics/csrc/disco/disco_cuda_bwd.cu",
],
)
)
return ext_modules
root_path = Path(__file__).parent root_path = Path(__file__).parent
README = readme(root_path) README = readme(root_path)
VERSION = version(root_path) VERSION = version(root_path)
# external modules
ext_modules = get_ext_modules(sys.argv)
config = { config = {
'name': 'torch_harmonics', "name": "torch_harmonics",
'packages': find_packages(), "packages": find_packages(),
'description': 'A differentiable spherical harmonic transform for PyTorch.', "description": "A differentiable spherical harmonic transform for PyTorch.",
'long_description': README, "long_description": README,
'long_description_content_type' : 'text/markdown', "long_description_content_type": "text/markdown",
'url' : 'https://github.com/NVIDIA/torch-harmonics', "url": "https://github.com/NVIDIA/torch-harmonics",
'author': 'Boris Bonev', "author": "Boris Bonev",
'author_email': 'bbonev@nvidia.com', "author_email": "bbonev@nvidia.com",
'version': VERSION, "version": VERSION,
'install_requires': ['torch', 'numpy', 'triton'], "install_requires": ["torch", "numpy"],
'extras_require': { "extras_require": {
'sfno': ['tensorly', 'tensorly-torch'], "sfno": ["tensorly", "tensorly-torch"],
}, },
'license': 'Modified BSD', "license": "Modified BSD",
'scripts': [], "scripts": [],
'include_package_data': True, "include_package_data": True,
'classifiers': [ "classifiers": ["Topic :: Scientific/Engineering", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3"],
'Topic :: Scientific/Engineering', "ext_modules": ext_modules,
'License :: OSI Approved :: BSD License', "cmdclass": {"build_ext": cpp_extension.BuildExtension} if ext_modules else {},
'Programming Language :: Python :: 3'
],
} }
setup(**config) setup(**config)
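With the conditional extension logic above, a source build can be done with or without the custom CUDA kernels. A sketch of the two variants, run from the repository root and using the same flag documented in the README:

```bash
# default build: the C++ helper extension, plus the CUDA extension if torch.cuda.is_available()
pip install .

# force compilation of the CUDA extension, e.g. when building on a CPU-only node for later GPU use
pip install --global-option --cuda_ext .
```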
...@@ -38,61 +38,96 @@ import torch ...@@ -38,61 +38,96 @@ import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_harmonics import * from torch_harmonics import *
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float): def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
""" """
helper routine to compute the values of the isotropic kernel densely helper routine to compute the values of the isotropic kernel densely
""" """
kernel_size = (nr // 2) + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)
# compute the support # compute the support
dtheta = (theta_cutoff - 0.0) / ntheta if nr % 2 == 1:
ikernel = torch.arange(ntheta).reshape(-1, 1, 1) ir = ikernel * dr
itheta = ikernel * dtheta else:
ir = (ikernel + 0.5) * dr
norm_factor = (
2
* math.pi
* (
1
- math.cos(theta_cutoff - dtheta)
+ math.cos(theta_cutoff - dtheta)
+ (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta
)
)
vals = torch.where( vals = torch.where(
((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff), ((r - ir).abs() <= dr) & (r <= r_cutoff),
(1 - (theta - itheta).abs() / dtheta) / norm_factor, (1 - (r - ir).abs() / dr),
0, 0,
) )
return vals return vals
def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
""" """
helper routine to compute the values of the anisotropic kernel densely helper routine to compute the values of the anisotropic kernel densely
""" """
# compute the support kernel_size = (nr // 2) * nphi + nr % 2
dtheta = (theta_cutoff - 0.0) / ntheta
dphi = 2.0 * math.pi / nphi
kernel_size = (ntheta-1)*nphi + 1
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
itheta = ((ikernel - 1) // nphi + 1) * dtheta dr = 2 * r_cutoff / (nr + 1)
iphi = ((ikernel - 1) % nphi) * dphi dphi = 2.0 * math.pi / nphi
norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) # disambiguate even and odd cases and compute the support
if nr % 2 == 1:
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
else:
ir = (ikernel // nphi + 0.5) * dr
iphi = (ikernel % nphi) * dphi
# compute the value of the filter
if nr % 2 == 1:
# find the indices where the rotated position falls into the support of the kernel # find the indices where the rotated position falls into the support of the kernel
cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff) cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi) cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0) r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) , 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2*math.pi - (phi - iphi).abs()) ) / dphi ), 0.0) phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals) vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
else:
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = r_vals * phi_vals
# in the even case, the inner basis functions overlap into the region of negative radius
rn = - r
phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
vals += rn_vals * phin_vals
return vals return vals
def _precompute_convolution_tensor_dense( def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, eps=1e-9):
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi """
): Discretely normalizes the convolution tensor.
"""
kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
scale_factor = float(nlon_in // nlon_out)
if transpose_normalization:
# the normalization is not quite symmetric due to the compressed way psi is stored in the main code
# look at the normalization code in the actual implementation
psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:,:,:1], dim=(1, 4), keepdim=True) / scale_factor
else:
psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True)
return psi / (psi_norm + eps)
def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False):
""" """
Helper routine to compute the convolution Tensor in a dense fashion Helper routine to compute the convolution Tensor in a dense fashion
""" """
...@@ -100,12 +135,14 @@ def _precompute_convolution_tensor_dense( ...@@ -100,12 +135,14 @@ def _precompute_convolution_tensor_dense(
assert len(in_shape) == 2 assert len(in_shape) == 2
assert len(out_shape) == 2 assert len(out_shape) == 2
quad_weights = quad_weights.reshape(-1, 1)
if len(kernel_shape) == 1: if len(kernel_shape) == 1:
kernel_handle = partial(_compute_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff) kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
kernel_size = kernel_shape[0] kernel_size = math.ceil( kernel_shape[0] / 2)
elif len(kernel_shape) == 2: elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff) kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1 kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
else: else:
raise ValueError("kernel_shape should be either one- or two-dimensional.") raise ValueError("kernel_shape should be either one- or two-dimensional.")
...@@ -149,6 +186,9 @@ def _precompute_convolution_tensor_dense( ...@@ -149,6 +186,9 @@ def _precompute_convolution_tensor_dense(
# find the indices where the rotated position falls into the support of the kernel # find the indices where the rotated position falls into the support of the kernel
out[:, t, p, :, :] = kernel_handle(theta, phi) out[:, t, p, :, :] = kernel_handle(theta, phi)
# take care of normalization
out = _normalize_convolution_tensor_dense(out, quad_weights=quad_weights, transpose_normalization=transpose_normalization)
return out return out
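For orientation, the kernel-size bookkeeping used in the helpers above (odd versus even numbers of collocation points across the diameter) works out as in the following small check; it mirrors the formulas in the test code and is not part of the library itself:

```python
def disco_kernel_size(kernel_shape):
    # number of filter basis functions implied by kernel_shape,
    # matching the expressions used in the dense test helpers above
    if len(kernel_shape) == 1:
        nr = kernel_shape[0]
        return (nr // 2) + nr % 2
    nr, nphi = kernel_shape
    return (nr // 2) * nphi + nr % 2

assert disco_kernel_size([3]) == 2     # odd: basis functions at r = 0 and r = dr
assert disco_kernel_size([4]) == 2     # even: rings offset to r = 0.5*dr and r = 1.5*dr
assert disco_kernel_size([3, 3]) == 4  # isotropic center plus one ring with 3 angular lobes
assert disco_kernel_size([4, 3]) == 6  # two offset rings with 3 angular lobes each
```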
...@@ -160,27 +200,31 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -160,27 +200,31 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
torch.cuda.manual_seed(333) torch.cuda.manual_seed(333)
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
torch.manual_seed(333)
torch.manual_seed(333)
@parameterized.expand( @parameterized.expand(
[ [
# regular convolution # regular convolution
[8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), ( 8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (18, 36), ( 6, 12), [4 ], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "legendre-gauss", False, 5e-5], [8, 4, 2, (16, 24), (8, 8), [3], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "equiangular", False, 5e-5], [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "legendre-gauss", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 1e-4],
# transpose convolution # transpose convolution
[8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, ( 8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, ( 6, 12), (18, 36), [4 ], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "legendre-gauss", True, 5e-5], [8, 4, 2, (8, 8), (16, 24), [3], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "equiangular", True, 5e-5], [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "legendre-gauss", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 1e-4],
] ]
) )
def test_disco_convolution( def test_disco_convolution(
...@@ -196,6 +240,11 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -196,6 +240,11 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
transpose, transpose,
tol, tol,
): ):
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
theta_cutoff = (kernel_shape[0] + 1) / 2 * torch.pi / float(nlat_out - 1)
Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv( conv = Conv(
in_channels, in_channels,
...@@ -207,27 +256,24 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -207,27 +256,24 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
grid_in=grid_in, grid_in=grid_in,
grid_out=grid_out, grid_out=grid_out,
bias=False, bias=False,
theta_cutoff=theta_cutoff
).to(self.device) ).to(self.device)
nlat_in, nlon_in = in_shape _, wgl = _precompute_latitudes(nlat_in, grid=grid_in)
nlat_out, nlon_out = out_shape quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
if transpose: if transpose:
psi_dense = _precompute_convolution_tensor_dense( psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True).to(self.device)
out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
).to(self.device) psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()
self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)))
else: else:
psi_dense = _precompute_convolution_tensor_dense( psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False).to(self.device)
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
).to(self.device)
psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense() psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
self.assertTrue( self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in)))
torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
)
# create a copy of the weight # create a copy of the weight
w_ref = torch.empty_like(conv.weight) w_ref = torch.empty_like(conv.weight)
...@@ -264,5 +310,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -264,5 +310,6 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol)) self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol)) self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -46,11 +46,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -46,11 +46,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
# set up distributed # set up distributed
cls.world_rank = int(os.getenv('WORLD_RANK', 0)) cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv('GRID_H', 1)) cls.grid_size_h = int(os.getenv("GRID_H", 1))
cls.grid_size_w = int(os.getenv('GRID_W', 1)) cls.grid_size_w = int(os.getenv("GRID_W", 1))
port = int(os.getenv('MASTER_PORT', '29501')) port = int(os.getenv("MASTER_PORT", "29501"))
master_address = os.getenv('MASTER_ADDR', 'localhost') master_address = os.getenv("MASTER_ADDR", "localhost")
cls.world_size = cls.grid_size_h * cls.grid_size_w cls.world_size = cls.grid_size_h * cls.grid_size_w
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -60,24 +60,21 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -60,24 +60,21 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
cls.device = torch.device(f"cuda:{local_rank}") cls.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
torch.cuda.manual_seed(333) torch.cuda.manual_seed(333)
proc_backend = 'nccl' proc_backend = "nccl"
else: else:
if cls.world_rank == 0: if cls.world_rank == 0:
print("Running test on CPU") print("Running test on CPU")
cls.device = torch.device('cpu') cls.device = torch.device("cpu")
proc_backend = 'gloo' proc_backend = "gloo"
torch.manual_seed(333) torch.manual_seed(333)
dist.init_process_group(backend = proc_backend, dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)
init_method = f"tcp://{master_address}:{port}",
rank = cls.world_rank,
world_size = cls.world_size)
cls.wrank = cls.world_rank % cls.grid_size_w cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w cls.hrank = cls.world_rank // cls.grid_size_w
# now set up the comm groups: # now set up the comm groups:
#set default # set default
cls.w_group = None cls.w_group = None
cls.h_group = None cls.h_group = None
...@@ -109,14 +106,12 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -109,14 +106,12 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
if cls.world_rank in grp: if cls.world_rank in grp:
cls.h_group = tmp_group cls.h_group = tmp_group
if cls.world_rank == 0: if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}") print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")
# initializing sht # initializing sht
thd.init(cls.h_group, cls.w_group) thd.init(cls.h_group, cls.w_group)
def _split_helper(self, tensor): def _split_helper(self, tensor):
with torch.no_grad(): with torch.no_grad():
# split in W # split in W
...@@ -129,14 +124,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -129,14 +124,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_local return tensor_local
def _gather_helper_fwd(self, tensor, B, C, convolution_dist): def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_out_shapes lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes lon_shapes = convolution_dist.lon_out_shapes
#print("tensor before gather shape", tensor.shape)
# gather in W # gather in W
if self.grid_size_w > 1: if self.grid_size_w > 1:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes] gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
...@@ -147,8 +139,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -147,8 +139,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
else: else:
tensor_gather = tensor tensor_gather = tensor
#print("tensor_gather shape", tensor_gather.shape)
# gather in H # gather in H
if self.grid_size_h > 1: if self.grid_size_h > 1:
gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes] gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
...@@ -159,7 +149,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -159,7 +149,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_gather return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, convolution_dist): def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
# we need the shapes # we need the shapes
lat_shapes = convolution_dist.lat_in_shapes lat_shapes = convolution_dist.lat_in_shapes
...@@ -185,31 +174,37 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -185,31 +174,37 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
return tensor_gather return tensor_gather
@parameterized.expand(
@parameterized.expand([ [
[128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6], [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6], [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6], [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6], [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", False, 1e-6], [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", False, 1e-6], [128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6], [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6], [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6], [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[ 64, 128, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6], [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", True, 1e-6], [128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", True, 1e-6], ]
]) )
def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):
kernel_shape, groups, grid_in, grid_out, transpose, tol):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
disco_args = dict(in_channels=C, out_channels=C, disco_args = dict(
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out), in_channels=C,
kernel_shape=kernel_shape, groups=groups, out_channels=C,
grid_in=grid_in, grid_out=grid_out, bias=True) in_shape=(nlat_in, nlon_in),
out_shape=(nlat_out, nlon_out),
kernel_shape=kernel_shape,
groups=groups,
grid_in=grid_in,
grid_out=grid_out,
bias=True,
)
# set up handles # set up handles
if transpose: if transpose:
...@@ -222,6 +217,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -222,6 +217,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# copy the weights from the local conv into the dist conv # copy the weights from the local conv into the dist conv
with torch.no_grad(): with torch.no_grad():
conv_dist.weight.copy_(conv_local.weight) conv_dist.weight.copy_(conv_local.weight)
if disco_args["bias"]:
conv_dist.bias.copy_(conv_local.bias) conv_dist.bias.copy_(conv_local.bias)
# create tensors # create tensors
...@@ -232,7 +228,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -232,7 +228,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
############################################################# #############################################################
# FWD pass # FWD pass
inp_full.requires_grad = True inp_full.requires_grad = True
out_full = conv_local(inp_full, use_triton_kernel=True) out_full = conv_local(inp_full)
# create grad for backward # create grad for backward
with torch.no_grad(): with torch.no_grad():
...@@ -249,11 +245,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -249,11 +245,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
# FWD pass # FWD pass
inp_local = self._split_helper(inp_full) inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True inp_local.requires_grad = True
out_local = conv_dist(inp_local, use_triton_kernel=True) out_local = conv_dist(inp_local)
# BWD pass # BWD pass
ograd_local = self._split_helper(ograd_full) ograd_local = self._split_helper(ograd_full)
out_local = conv_dist(inp_local, use_triton_kernel=True) out_local = conv_dist(inp_local)
out_local.backward(ograd_local) out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone() igrad_local = inp_local.grad.clone()
...@@ -262,7 +258,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -262,7 +258,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
############################################################# #############################################################
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist) out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) ) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0: if self.world_rank == 0:
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
...@@ -272,11 +268,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -272,11 +268,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
############################################################# #############################################################
with torch.no_grad(): with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) ) err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0: if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}") print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
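The distributed test reads its process-grid configuration from environment variables (WORLD_RANK, GRID_H, GRID_W, MASTER_ADDR, MASTER_PORT, as set up in setUpClass above). A single-rank sanity run on CPU could therefore look like the sketch below; the test file name is a placeholder, and a multi-rank run needs one process per rank with the matching WORLD_RANK:

```bash
# single-rank run on CPU (gloo backend); the file name is a placeholder
WORLD_RANK=0 GRID_H=1 GRID_W=1 MASTER_ADDR=localhost MASTER_PORT=29501 \
    python tests/test_distributed_convolution.py
```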
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# #
__version__ = '0.6.5' __version__ = "0.7.0a"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
...@@ -33,340 +33,74 @@ import math ...@@ -33,340 +33,74 @@ import math
import torch import torch
import triton try:
import triton.language as tl import disco_cuda_extension
_cuda_extension_available = True
BLOCK_SIZE_BATCH = 4 except ImportError as err:
BLOCK_SIZE_NZ = 8 disco_cuda_extension = None
BLOCK_SIZE_POUT = 8 _cuda_extension_available = False
@triton.jit
def _disco_s2_contraction_kernel(
inz_ptr,
vnz_ptr,
nnz,
inz_stride_ii,
inz_stride_nz,
vnz_stride,
x_ptr,
batch_size,
nlat_in,
nlon_in,
x_stride_b,
x_stride_t,
x_stride_p,
y_ptr,
kernel_size,
nlat_out,
nlon_out,
y_stride_b,
y_stride_f,
y_stride_t,
y_stride_p,
pscale,
backward: tl.constexpr,
BLOCK_SIZE_BATCH: tl.constexpr,
BLOCK_SIZE_NZ: tl.constexpr,
BLOCK_SIZE_POUT: tl.constexpr,
):
"""
Kernel for the sparse-dense contraction for the S2 DISCO convolution.
"""
pid_batch = tl.program_id(0)
pid_pout = tl.program_id(2)
# pid_nz should always be 0 as we do not account for larger grids in this dimension
pid_nz = tl.program_id(1) # should be always 0
tl.device_assert(pid_nz == 0)
# create the pointer block for pout
pout = pid_pout * BLOCK_SIZE_POUT + tl.arange(0, BLOCK_SIZE_POUT)
b = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
# create pointer blocks for the psi datastructure
iinz = tl.arange(0, BLOCK_SIZE_NZ)
# get the initial pointers
fout_ptrs = inz_ptr + iinz * inz_stride_nz
tout_ptrs = inz_ptr + iinz * inz_stride_nz + inz_stride_ii
tpnz_ptrs = inz_ptr + iinz * inz_stride_nz + 2 * inz_stride_ii
vals_ptrs = vnz_ptr + iinz * vnz_stride
# iterate in a blocked fashion over the non-zero entries
for offs_nz in range(0, nnz, BLOCK_SIZE_NZ):
# load input output latitude coordinate pairs
fout = tl.load(fout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1)
tout = tl.load(tout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1)
tpnz = tl.load(tpnz_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1)
# load corresponding values
vals = tl.load(vals_ptrs + offs_nz * vnz_stride, mask=(offs_nz + iinz < nnz), other=0.0)
# compute the shifted longitude coordinates p+p' to read in a coalesced fashion
tnz = tpnz // nlon_in
pnz = tpnz % nlon_in
# make sure the value is not out of bounds
tl.device_assert(fout < kernel_size)
tl.device_assert(tout < nlat_out)
tl.device_assert(tnz < nlat_in)
tl.device_assert(pnz < nlon_in)
# load corresponding portion of the input array
x_ptrs = (
x_ptr
+ tnz[None, :, None] * x_stride_t
+ ((pnz[None, :, None] + pout[None, None, :] * pscale) % nlon_in) * x_stride_p
+ b[:, None, None] * x_stride_b
)
y_ptrs = (
y_ptr
+ fout[None, :, None] * y_stride_f
+ tout[None, :, None] * y_stride_t
+ (pout[None, None, :] % nlon_out) * y_stride_p
+ b[:, None, None] * y_stride_b
)
# precompute the mask
mask = ((b[:, None, None] < batch_size) and (offs_nz + iinz[None, :, None] < nnz)) and (
pout[None, None, :] < nlon_out
)
# do the actual computation. Backward is essentially just the same operation with swapped tensors.
if not backward:
x = tl.load(x_ptrs, mask=mask, other=0.0)
y = vals[None, :, None] * x
# store it to the output array
tl.atomic_add(y_ptrs, y, mask=mask)
else:
y = tl.load(y_ptrs, mask=mask, other=0.0)
x = vals[None, :, None] * y
# store it to the output array
tl.atomic_add(x_ptrs, x, mask=mask)
def _disco_s2_contraction_fwd(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
"""
Wrapper function for the triton implementation of the efficient DISCO convolution on the sphere.
Parameters
----------
x: torch.Tensor
Input signal on the sphere. Expects a tensor of shape batch_size x channels x nlat_in x nlon_in).
psi : torch.Tensor
Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in).
nlon_out: int
Number of longitude points the output should have.
"""
# check the shapes of all input tensors
assert len(psi.shape) == 3
assert len(x.shape) == 4
assert psi.is_sparse, "Psi must be a sparse COO tensor"
# TODO: check that Psi is also coalesced
# get the dimensions of the problem
kernel_size, nlat_out, n_in = psi.shape
nnz = psi.indices().shape[-1]
batch_size, n_chans, nlat_in, nlon_in = x.shape
assert nlat_in * nlon_in == n_in
# TODO: check that Psi index vector is of type long
# make sure that the grid-points of the output grid fall onto the grid points of the input grid
assert nlon_in % nlon_out == 0
pscale = nlon_in // nlon_out
# to simplify things, we merge batch and channel dimensions
x = x.reshape(batch_size * n_chans, nlat_in, nlon_in)
# prepare the output tensor
y = torch.zeros(batch_size * n_chans, kernel_size, nlat_out, nlon_out, device=x.device, dtype=x.dtype)
# determine the grid for the computation
grid = (
triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
1,
triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
)
# launch the kernel
_disco_s2_contraction_kernel[grid](
psi.indices(),
psi.values(),
nnz,
psi.indices().stride(-2),
psi.indices().stride(-1),
psi.values().stride(-1),
x,
batch_size * n_chans,
nlat_in,
nlon_in,
x.stride(0),
x.stride(-2),
x.stride(-1),
y,
kernel_size,
nlat_out,
nlon_out,
y.stride(0),
y.stride(1),
y.stride(-2),
y.stride(-1),
pscale,
False,
BLOCK_SIZE_BATCH,
BLOCK_SIZE_NZ,
BLOCK_SIZE_POUT,
)
# reshape y back to expose the correct dimensions
y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out)
return y
def _disco_s2_contraction_bwd(grad_y: torch.Tensor, psi: torch.Tensor, nlon_in: int):
"""
Backward pass for the triton implementation of the efficient DISCO convolution on the sphere.
Parameters
----------
grad_y: torch.Tensor
Input gradient on the sphere. Expects a tensor of shape batch_size x channels x kernel_size x nlat_out x nlon_out.
psi : torch.Tensor
Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in).
nlon_in: int
Number of longitude points the input used. Is required to infer the correct dimensions
"""
# check the shapes of all input tensors
assert len(psi.shape) == 3
assert len(grad_y.shape) == 5
assert psi.is_sparse, "psi must be a sparse COO tensor"
# TODO: check that Psi is also coalesced
# get the dimensions of the problem
kernel_size, nlat_out, n_in = psi.shape
nnz = psi.indices().shape[-1]
assert grad_y.shape[-2] == nlat_out
assert grad_y.shape[-3] == kernel_size
assert n_in % nlon_in == 0
nlat_in = n_in // nlon_in
batch_size, n_chans, _, _, nlon_out = grad_y.shape
# make sure that the grid-points of the output grid fall onto the grid points of the input grid
assert nlon_in % nlon_out == 0
pscale = nlon_in // nlon_out
# to simplify things, we merge batch and channel dimensions
grad_y = grad_y.reshape(batch_size * n_chans, kernel_size, nlat_out, nlon_out)
# prepare the output tensor
grad_x = torch.zeros(batch_size * n_chans, nlat_in, nlon_in, device=grad_y.device, dtype=grad_y.dtype)
# determine the grid for the computation
grid = (
triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH),
1,
triton.cdiv(nlon_out, BLOCK_SIZE_POUT),
)
# launch the kernel
_disco_s2_contraction_kernel[grid](
psi.indices(),
psi.values(),
nnz,
psi.indices().stride(-2),
psi.indices().stride(-1),
psi.values().stride(-1),
grad_x,
batch_size * n_chans,
nlat_in,
nlon_in,
grad_x.stride(0),
grad_x.stride(-2),
grad_x.stride(-1),
grad_y,
kernel_size,
nlat_out,
nlon_out,
grad_y.stride(0),
grad_y.stride(1),
grad_y.stride(-2),
grad_y.stride(-1),
pscale,
True,
BLOCK_SIZE_BATCH,
BLOCK_SIZE_NZ,
BLOCK_SIZE_POUT,
)
# reshape y back to expose the correct dimensions
grad_x = grad_x.reshape(batch_size, n_chans, nlat_in, nlon_in)
return grad_x
class _DiscoS2ContractionTriton(torch.autograd.Function):
"""
Helper function to make the triton implementation work with PyTorch autograd functionality
"""
class _DiscoS2ContractionCuda(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
ctx.save_for_backward(psi) row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
ctx.nlon_in = x.shape[-1] ctx.nlon_in = x.shape[-1]
return _disco_s2_contraction_fwd(x, psi, nlon_out) return disco_cuda_extension.forward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
(psi,) = ctx.saved_tensors roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
grad_input = _disco_s2_contraction_bwd(grad_output, psi, ctx.nlon_in) grad_input = disco_cuda_extension.backward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
grad_x = grad_psi = None ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
return grad_input, None, None return grad_input, None, None, None, None, None, None, None, None
class _DiscoS2TransposeContractionTriton(torch.autograd.Function):
"""
Helper function to make the triton implementation work with PyTorch autograd functionality
"""
class _DiscoS2TransposeContractionCuda(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
ctx.save_for_backward(psi) row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int):
ctx.save_for_backward(roff_idx, ker_idx, row_idx, col_idx, vals)
ctx.kernel_size = kernel_size
ctx.nlat_in = x.shape[-2]
ctx.nlon_in = x.shape[-1] ctx.nlon_in = x.shape[-1]
return _disco_s2_contraction_bwd(x, psi, nlon_out) return disco_cuda_extension.backward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
(psi,) = ctx.saved_tensors roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors
grad_input = _disco_s2_contraction_fwd(grad_output, psi, ctx.nlon_in) grad_input = disco_cuda_extension.forward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals,
grad_x = grad_psi = None ctx.kernel_size, ctx.nlat_in, ctx.nlon_in)
return grad_input, None, None
return grad_input, None, None, None, None, None, None, None, None
def _disco_s2_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): # CUDA
return _DiscoS2ContractionTriton.apply(x, psi, nlon_out) def _disco_s2_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _DiscoS2ContractionCuda.apply(x, roff_idx, ker_idx, row_idx, col_idx, vals,
kernel_size, nlat_out, nlon_out)
def _disco_s2_transpose_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def _disco_s2_transpose_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor,
return _DiscoS2TransposeContractionTriton.apply(x, psi, nlon_out) row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor,
kernel_size: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
return _DiscoS2TransposeContractionCuda.apply(x, roff_idx, ker_idx, row_idx, col_idx, vals,
kernel_size, nlat_out, nlon_out)
def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: int):
""" """
Reference implementation of the custom contraction as described in [1]. This requires repeated Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in Triton. on GPU, make sure to use the custom kernel written in CUDA.
""" """
assert len(psi.shape) == 3 assert len(psi.shape) == 3
assert len(x.shape) == 4 assert len(x.shape) == 4
...@@ -402,7 +136,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl ...@@ -402,7 +136,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
""" """
Reference implementation of the custom contraction as described in [1]. This requires repeated Reference implementation of the custom contraction as described in [1]. This requires repeated
shifting of the input tensor, which can potentially be costly. For an efficient implementation shifting of the input tensor, which can potentially be costly. For an efficient implementation
on GPU, make sure to use the custom kernel written in Triton. on GPU, make sure to use the custom kernel written in CUDA.
""" """
assert len(psi.shape) == 3 assert len(psi.shape) == 3
assert len(x.shape) == 5 assert len(x.shape) == 5
...@@ -412,7 +146,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl ...@@ -412,7 +146,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
kernel_size, nlat_out, n_out = psi.shape kernel_size, nlat_out, n_out = psi.shape
assert n_out % nlon_out == 0 assert n_out % nlon_out == 0
assert nlon_out >= nlat_in assert nlon_out >= nlon_in
pscale = nlon_out // nlon_in pscale = nlon_out // nlon_in
# interleave zeros along the longitude dimension to allow for fractional offsets to be considered # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
......
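The dispatch between the CUDA extension and the pure-PyTorch reference path is not part of this hunk; conceptually it follows the pattern sketched below, assuming the calling module consults `_cuda_extension_available` and has both the sparse tensor `psi` and the index/value buffers at hand:

```python
# illustrative dispatch only -- the actual logic lives in the convolution module, not shown in full here
if _cuda_extension_available and x.is_cuda:
    y = _disco_s2_contraction_cuda(
        x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out
    )
else:
    # reference implementation operating on the sparse COO tensor psi
    y = _disco_s2_contraction_torch(x, psi, nlon_out)
```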
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <stdio.h>
#include <stdlib.h>
#include <torch/extension.h>
#include <cassert>
#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x)
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "disco.h"
#include <cuda_runtime.h>
#include <c10/cuda/CUDAStream.h>
#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA_INPUT_TENSOR(x) CHECK_CUDA_TENSOR(x); CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
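// helper macros: DIV_UP(a,b) is integer division rounded up; ELXTH denotes the
// number of row elements each thread processes and is capped at ELXTH_MAX by the
// kernel launchers, which grow ELXTH until BDIM_X*ELXTH covers the row width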
// forward kernel
torch::Tensor disco_cuda_fwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo);
// backward kernel
torch::Tensor disco_cuda_bwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo);
\ No newline at end of file
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "disco.h"
#include "disco_cuda.cuh"
template<int BDIM_X,
int ELXTH,
typename REAL_T>
__device__ void disco_bwd_d(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
const int tid = threadIdx.x;
const int64_t bidx = blockIdx.x; // global row
const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx+1];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
inp += bidy*K*Hi*Wi + ker*Hi*Wi + row*Wi;
out += bidy*Ho*Wo;
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*(BDIM_X*ELXTH)*pscale]
REAL_T (*__sh)[BDIM_X*ELXTH*2] = reinterpret_cast<REAL_T (*)[BDIM_X*ELXTH*2]>(__sh_ptr);
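// shared-memory layout: __sh[i][.] for i in [0, pscale) accumulates contributions
// to output longitudes congruent to i modulo pscale; each row is 2*BDIM_X*ELXTH
// entries wide so writes can run past Wi without a modulo, and the two halves
// are folded together when the row is flushed to global memory below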
// copy current inp row in regs
REAL_T __reg[ELXTH];
#pragma unroll
for(int i = 0; i < ELXTH; i++) {
__reg[i] = (i*BDIM_X+tid < Wi) ? inp[i*BDIM_X +tid] : REAL_T(0);
}
// reset the shared rows; only the first 2*Wi entries of each of the
// pscale rows are ever flushed to global memory below, the remaining
// locations are written to but never copied out
for(int i = 0; i < pscale; i++) {
#pragma unroll
for(int j = 0; j < 2*BDIM_X*ELXTH; j += BDIM_X) {
__sh[i][j+tid] = 0;
}
}
__syncthreads();
int col_prev = cols[soff];
int h_prev = col_prev / Wo;
int w_prev = col_prev % Wo;
// loop along the columns of the CTA's row
for(int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz];
const REAL_T val = vals[nz];
// if we are processing a nz with a col value
// that points to a new output row, flush the
// accumulated shared-memory rows to the output
// first; a col points to a new output row if
// (col / Wo) > (col_prev / Wo)
if (col >= col_prev-w_prev+Wo) {
__syncthreads();
for(int i = 0; i < pscale; i++) {
for(int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev*Wo + j*pscale + i], v);
__sh[i][ j] = 0;
__sh[i][Wi + j] = 0;
}
}
__syncthreads();
col_prev = col;
h_prev = col / Wo;
w_prev = col % Wo;
}
const int w = w_prev + (col-col_prev);
const int w_mod_ps = w % pscale;
const int w_div_ps = w / pscale;
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
const int pp = i*BDIM_X + tid;
__sh[w_mod_ps][w_div_ps + pp] += val*__reg[i];
}
// to avoid race conditions on __sh[]
// among consecutive iterations along nz
__syncthreads();
}
__syncthreads();
// write last row
for(int i = 0; i < pscale; i++) {
for(int j = tid; j < Wi; j += BDIM_X) {
const REAL_T v = __sh[i][j] + __sh[i][Wi + j];
atomicAdd(&out[h_prev*Wo + j*pscale + i], v);
}
}
return;
}
template<int BDIM_X,
int ELXTH,
int PSCALE,
typename REAL_T>
__global__ __launch_bounds__(BDIM_X)
void disco_bwd_blk_k(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
if constexpr(PSCALE != 0) { disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, PSCALE, roff, kers, rows, cols, vals, inp, out); }
else { disco_bwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out); }
return;
}
template<int NTH,
int ELXTH,
typename REAL_T>
static void launch_kernel(int BC,
int Hi,
int Wi,
int K,
int Ho,
int Wo,
int64_t nrows,
int64_t *roff_d,
int64_t *ker_d,
int64_t *row_d,
int64_t *col_d,
REAL_T *val_d,
REAL_T *inp_d,
REAL_T *out_d,
cudaStream_t stream) {
static_assert(sizeof(REAL_T) == 2 ||
sizeof(REAL_T) == 4 ||
sizeof(REAL_T) == 8);
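// compile-time recursion: if BDIM_X*ELXTH does not yet cover the input row
// width Wi, re-instantiate with ELXTH+1 (bounded by ELXTH_MAX)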
if constexpr(ELXTH <= ELXTH_MAX) {
if (NTH*ELXTH >= Wi) {
dim3 grid(nrows, BC);
const int pscale = Wo/Wi;
size_t shmem = sizeof(*out_d)*(2 * (NTH*ELXTH)*pscale);
switch(pscale) {
case 1:
disco_bwd_blk_k<NTH, ELXTH, 1><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
break;
case 2:
disco_bwd_blk_k<NTH, ELXTH, 2><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
break;
case 3:
disco_bwd_blk_k<NTH, ELXTH, 3><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
break;
default:
disco_bwd_blk_k<NTH, ELXTH, 0><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
}
} else {
launch_kernel<NTH, ELXTH+1>(BC,
Hi, Wi,
K, Ho, Wo,
nrows,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d,
stream);
}
}
return;
}
torch::Tensor disco_cuda_bwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo) {
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes
int64_t B = inp.size(0);
int64_t C = inp.size(1);
int64_t BC = B * C;
int64_t Hi = inp.size(3);
int64_t Wi = inp.size(4);
int64_t nrows = roff_idx.size(0) - 1;
// allocate output
int64_t out_dims[] = {B, C, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX%2));
if (Wo <= 64*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 128*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 256*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 512*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 1024*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_backward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else {
fprintf(stderr,
"%s:%d: error, unsupported Wo value (%ld), max supported is %d\n",
__FILE__, __LINE__, Wo, 1024*ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
}
//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
//}
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "disco.h"
#include "disco_cuda.cuh"
template<int BDIM_X,
int ELXTH,
typename REAL_T>
__device__ void disco_fwd_d(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
const int tid = threadIdx.x;
const int64_t bidx = blockIdx.x; // global row
const int64_t bidy = blockIdx.y; // bc
int64_t soff = roff[bidx];
int64_t eoff = roff[bidx+1];
const int64_t ker = kers[soff];
const int64_t row = rows[soff];
inp += bidy*Hi*Wi;
out += bidy*K*Ho*Wo + ker*Ho*Wo + row*Wo;
REAL_T __reg[ELXTH] = {0};
// align to larger supported fp type
extern __shared__ __align__(sizeof(double)) unsigned char __sh_ptr[]; // REAL_T __sh[2*Wi + pscale*(BDIM_X*ELXTH - Wo)]
REAL_T *__sh = reinterpret_cast<REAL_T *>(__sh_ptr);
int col_prev = cols[soff];
int h_prev = col_prev / Wi;
int w_prev = col_prev % Wi;
// copy current inp row in shmem
for(int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev*Wi + i];
__sh[ i] = v;
__sh[Wi + i] = v;
}
// locations __sh[2*Wi : 2*Wi + pscale*(BDIM_X*ELXTH-Wo)] are never written; they are only read for pp >= Wo, whose results are discarded below
__syncthreads();
// loop along the columns of the CTA's row
for(int64_t nz = soff; nz < eoff; nz++) {
const int col = cols[nz];
const REAL_T val = vals[nz];
// if we are processing a nz with a col value
// leading to a new row of inp then copy it
// to shmem;
// checks whether (h_prev < h) with:
// (col >= col_prev - (col_prev % Wi) + Wi)
if (col >= col_prev-w_prev+Wi) {
col_prev = col;
h_prev = col / Wi;
w_prev = col % Wi;
__syncthreads();
for(int i = tid; i < Wi; i += BDIM_X) {
const REAL_T v = inp[h_prev*Wi + i];
__sh[ i] = v;
__sh[Wi + i] = v;
}
__syncthreads();
}
const int w = w_prev + (col-col_prev);
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
const int pp = i*BDIM_X + tid;
// original lines:
//
// if (pp >= Wo) break;
// const int wpp = (w + pscale*pp) % Wi;
//
// value of (w + pscale*pp) < (Wi + (Wi/Wo)*Wo) = 2*Wi
// so we can allocate twice the amount of shmem,
// replicate the current inp row and avoid the costly mod
//
// also, to avoid the conditional, sh can be extended to
// cover the maximum location accessed during this loop
//
// REAL_T __sh[2*Wi + pscale*NUM_REM]
//
// Wi + (Wi/Wo)*BDIM_X*ELXTH = (since BDIM_X*ELXTH >= Wo) =
// = Wi + (Wi/Wo)*(Wo + (BDIM_X*ELXTH - Wo)) =
// = 2*Wi + pscale*NUM_REM
//
// with NUM_REM = BDIM_X*ELXTH - Wo
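// worked example (illustrative numbers): Wi = 200, Wo = 100, pscale = 2,
// BDIM_X*ELXTH = 128 >= Wo, NUM_REM = 28; the buffer then holds
// 2*200 + 2*28 = 456 entries, while the largest index touched is
// (Wi-1) + pscale*(BDIM_X*ELXTH-1) = 199 + 2*127 = 453, so accesses stay in bounds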
const int wpp = w + pscale*pp;
__reg[i] += val*__sh[wpp];
}
}
#pragma unroll
for (int i = 0; i < ELXTH; i++) {
const int pp = i*BDIM_X + tid;
if (pp >= Wo) break;
out[pp] = __reg[i];
}
return;
}
template<int BDIM_X,
int ELXTH,
typename REAL_T>
__global__ __launch_bounds__(BDIM_X)
void disco_fwd_blk_k(const int Hi,
const int Wi,
const int K,
const int Ho,
const int Wo,
const int pscale,
const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers,
const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols,
const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp,
REAL_T *__restrict__ out) {
disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
return;
}
template<int NTH,
int ELXTH,
typename REAL_T>
static void launch_kernel(int BC,
int Hi,
int Wi,
int K,
int Ho,
int Wo,
int64_t nrows,
int64_t *roff_d,
int64_t *ker_d,
int64_t *row_d,
int64_t *col_d,
REAL_T *val_d,
REAL_T *inp_d,
REAL_T *out_d,
cudaStream_t stream) {
static_assert(sizeof(REAL_T) == 2 ||
sizeof(REAL_T) == 4 ||
sizeof(REAL_T) == 8);
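// compile-time recursion: if BDIM_X*ELXTH does not yet cover the output row
// width Wo, re-instantiate with ELXTH+1 (bounded by ELXTH_MAX)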
if constexpr(ELXTH <= ELXTH_MAX) {
if (NTH*ELXTH >= Wo) {
dim3 grid(nrows, BC);
const int pscale = Wi/Wo;
size_t shmem = sizeof(*out_d)*(Wi*2 + pscale*(NTH*ELXTH-Wo));
disco_fwd_blk_k<NTH, ELXTH><<<grid, NTH, shmem, stream>>>(Hi, Wi,
K, Ho, Wo, pscale,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d);
} else {
launch_kernel<NTH, ELXTH+1>(BC,
Hi, Wi,
K, Ho, Wo,
nrows,
roff_d,
ker_d, row_d, col_d, val_d,
inp_d, out_d,
stream);
}
}
return;
}
torch::Tensor disco_cuda_fwd(torch::Tensor inp,
torch::Tensor roff_idx,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val,
int64_t K,
int64_t Ho,
int64_t Wo) {
// some sanity checks
CHECK_CUDA_INPUT_TENSOR(inp);
CHECK_CUDA_INPUT_TENSOR(roff_idx);
CHECK_CUDA_INPUT_TENSOR(ker_idx);
CHECK_CUDA_INPUT_TENSOR(row_idx);
CHECK_CUDA_INPUT_TENSOR(col_idx);
CHECK_CUDA_INPUT_TENSOR(val);
// extract some shapes
int64_t B = inp.size(0);
int64_t C = inp.size(1);
int64_t BC = B * C;
int64_t Hi = inp.size(2);
int64_t Wi = inp.size(3);
int64_t nrows = roff_idx.size(0) - 1;
// allocate output
int64_t out_dims[] = {B, C, K, Ho, Wo};
auto options = torch::TensorOptions().device(inp.device()).dtype(inp.dtype());
torch::Tensor out = torch::zeros(out_dims, options);
// get stream
auto stream = at::cuda::getCurrentCUDAStream().stream();
// assert
static_assert(0 == (ELXTH_MAX%2));
// pick the correct launch config
if (Wo <= 64*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<64, 1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 128*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<128, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 256*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<256, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 512*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<512, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else if (Wo <= 1024*ELXTH_MAX) {
AT_DISPATCH_FLOATING_TYPES(inp.scalar_type(), "disco_forward_cuda", ([&] {
launch_kernel<1024, (ELXTH_MAX/2)+1, scalar_t>(BC, Hi, Wi, K, Ho, Wo, nrows,
roff_idx.data_ptr<int64_t>(),
ker_idx.data_ptr<int64_t>(),
row_idx.data_ptr<int64_t>(),
col_idx.data_ptr<int64_t>(),
val.data_ptr<scalar_t>(),
inp.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(),
stream);
}));
}
else {
fprintf(stderr,
"%s:%d: error, unsupported Wo value (%ld), max supported is %d\n",
__FILE__, __LINE__, Wo, 1024*ELXTH_MAX);
exit(EXIT_FAILURE);
}
return out;
}
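For orientation, here is a minimal pure-PyTorch sketch of the contraction that disco_cuda_fwd performs on the COO representation of psi. This is an illustration written for this description with assumed names; it is not a function shipped with the library:

import torch

def disco_fwd_reference(inp, ker_idx, row_idx, col_idx, val, K, Ho, Wo):
    # inp: (B, C, Hi, Wi); (ker_idx, row_idx, col_idx, val) is the COO form of psi
    B, C, Hi, Wi = inp.shape
    pscale = Wi // Wo
    out = torch.zeros(B, C, K, Ho, Wo, dtype=inp.dtype, device=inp.device)
    wo = torch.arange(Wo, device=inp.device)
    for k, r, c, v in zip(ker_idx.tolist(), row_idx.tolist(), col_idx.tolist(), val.tolist()):
        h, w = divmod(c, Wi)
        # each nonzero multiplies a longitudinally rolled copy of input row h
        out[:, :, k, r, :] += v * inp[:, :, h, (w + pscale * wo) % Wi]
    return out

The CUDA kernel above computes the same sums, but avoids the modulo by replicating the input row in shared memory.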
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "disco.h"
template<typename REAL_T>
void preprocess_psi_kernel(int64_t nnz,
int64_t K,
int64_t Ho,
int64_t *ker_h,
int64_t *row_h,
int64_t *col_h,
int64_t *roff_h,
REAL_T *val_h,
int64_t& nrows) {
int64_t *Koff = new int64_t[K];
for(int i = 0; i < K; i++) {
Koff[i] = 0;
}
for(int64_t i = 0; i < nnz; i++) {
Koff[ker_h[i]]++;
}
int64_t prev = Koff[0];
Koff[0] = 0;
for(int i = 1; i < K; i++) {
int64_t save = Koff[i];
Koff[i] = prev + Koff[i-1];
prev = save;
}
int64_t *ker_sort = new int64_t[nnz];
int64_t *row_sort = new int64_t[nnz];
int64_t *col_sort = new int64_t[nnz];
REAL_T *val_sort = new REAL_T[nnz]; // use REAL_T so double-precision values are not truncated
for(int64_t i = 0; i < nnz; i++) {
const int64_t ker = ker_h[i];
const int64_t off = Koff[ker]++;
ker_sort[off] = ker;
row_sort[off] = row_h[i];
col_sort[off] = col_h[i];
val_sort[off] = val_h[i];
}
for(int64_t i = 0; i < nnz; i++) {
ker_h[i] = ker_sort[i];
row_h[i] = row_sort[i];
col_h[i] = col_sort[i];
val_h[i] = val_sort[i];
}
delete [] Koff;
delete [] ker_sort;
delete [] row_sort;
delete [] col_sort;
delete [] val_sort;
// compute rows offsets
nrows = 1;
roff_h[0] = 0;
for(int64_t i = 1; i < nnz; i++) {
if (row_h[i-1] == row_h[i]) continue;
roff_h[nrows++] = i;
if (nrows > Ho*K) {
fprintf(stderr,
"%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n",
__FILE__, __LINE__, int64_t(Ho)*K);
exit(EXIT_FAILURE);
}
}
roff_h[nrows] = nnz;
return;
}
torch::Tensor preprocess_psi(const int64_t K,
const int64_t Ho,
torch::Tensor ker_idx,
torch::Tensor row_idx,
torch::Tensor col_idx,
torch::Tensor val) {
CHECK_INPUT_TENSOR(ker_idx);
CHECK_INPUT_TENSOR(row_idx);
CHECK_INPUT_TENSOR(col_idx);
CHECK_INPUT_TENSOR(val);
int64_t nnz = val.size(0);
int64_t *ker_h = ker_idx.data_ptr<int64_t>();
int64_t *row_h = row_idx.data_ptr<int64_t>();
int64_t *col_h = col_idx.data_ptr<int64_t>();
int64_t *roff_h = new int64_t[Ho*K+1];
int64_t nrows;
//float *val_h = val.data_ptr<float>();
AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&]{
preprocess_psi_kernel<scalar_t>(nnz, K, Ho,
ker_h,
row_h,
col_h,
roff_h,
val.data_ptr<scalar_t>(),
nrows);
}));
// create output tensor
auto options = torch::TensorOptions().dtype(row_idx.dtype());
auto roff_idx = torch::empty({nrows+1}, options);
int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
for(int64_t i = 0; i < (nrows+1); i++) {
roff_out_h[i] = roff_h[i];
}
delete [] roff_h;
return roff_idx;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}
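For readers who prefer Python, a rough equivalent of this preprocessing step is sketched below. It assumes, as the C++ code above does, that the COO entries are already ordered by output row within each kernel id; it is not the bound API:

import torch

def preprocess_psi_reference(ker, row, col, val):
    # stable sort by kernel id, mirroring the counting sort above
    _, order = torch.sort(ker, stable=True)
    ker, row, col, val = ker[order], row[order], col[order], val[order]
    # CSR-style offsets: a new "row" starts wherever the output-row index changes
    change = torch.nonzero(row[1:] != row[:-1]).flatten() + 1
    roff = torch.cat([torch.zeros(1, dtype=torch.int64), change,
                      torch.tensor([ker.numel()], dtype=torch.int64)])
    return ker, row, col, val, roff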
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "disco.h"
#include "disco_cuda.cuh"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &disco_cuda_fwd, "DISCO forward (CUDA)");
m.def("backward", &disco_cuda_bwd, "DISCO backward (CUDA)");
}
@@ -33,7 +33,14 @@
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import compute_split_shapes, split_tensor_along_dim
- from .primitives import distributed_transpose_azimuth, distributed_transpose_polar, reduce_from_polar_region, scatter_to_polar_region
+ from .primitives import (
+     distributed_transpose_azimuth,
+     distributed_transpose_polar,
+     reduce_from_polar_region,
+     scatter_to_polar_region,
+     gather_from_polar_region,
+     copy_to_polar_region
+ )
# import the sht
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
...