Unverified Commit 663bea1f authored by Boris Bonev, committed by GitHub

safeguarding all custom CUDA and C++ routines via the _cuda_extension… (#54)

* safeguarding all custom CUDA and C++ routines via the _cuda_extension_available flag

* bumping up version number
parent 4fea88bf
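In short, the compiled helpers are now imported inside a try/except block that sets an `_cuda_extension_available` flag, and every call into `preprocess_psi` or the custom CUDA kernels is guarded by that flag. A minimal sketch of the pattern follows; module and function names are taken from the diff below, while the dispatch helper itself is illustrative and not the library's exact code:

```python
# Sketch of the guard pattern introduced by this commit (illustrative only).
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_contraction_cuda

try:
    # the C++ preprocessing helper and the CUDA kernels are optional builds
    from disco_helpers import preprocess_psi
    import disco_cuda_extension
    _cuda_extension_available = True
except ImportError:
    disco_cuda_extension = None
    _cuda_extension_available = False  # assumed fallback branch; the diff truncates before this line


def contract(x, *psi_args):
    # hypothetical dispatcher: prefer the custom CUDA kernel when it was built and the
    # input lives on the GPU, otherwise fall back to the pure PyTorch implementation
    if _cuda_extension_available and x.is_cuda:
        return _disco_s2_contraction_cuda(x, *psi_args)
    return _disco_s2_contraction_torch(x, *psi_args)
```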
@@ -7,6 +7,7 @@
* Added resampling modules for convenience
* Changing behavior of distributed SHT to use `dim=-3` as channel dimension
* Fixing SHT unittests to test SHT and ISHT individually, rather than the roundtrip
+* Changing the way custom CUDA extensions are handled
### v0.7.1
...
@@ -75,16 +75,16 @@ torch-harmonics has been used to implement a variety of differentiable PDE solve
## Installation
-Download directly from PyPI:
+A simple installation can be directly done from PyPI:
```bash
pip install torch-harmonics
```
-If you would like to enforce the compilation of CUDA extensions for the discrete-continuous convolutions, you can do so by setting the `FORCE_CUDA_EXTENSION` flag. You may also want to set appropriate architectures with the `TORCH_CUDA_ARCH_LIST` flag.
+If you are planning to use spherical convolutions, we recommend building the corresponding custom CUDA kernels. To enforce this, you can set the `FORCE_CUDA_EXTENSION` flag. You may also want to set appropriate architectures with the `TORCH_CUDA_ARCH_LIST` flag. Finally, make sure to disable build isolation via the `--no-build-isolation` flag to ensure that the custom kernels are built with the existing torch installation.
```bash
export FORCE_CUDA_EXTENSION=1
export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
-pip install torch-harmonics
+pip install --no-build-isolation torch-harmonics
```
:warning: Please note that the custom CUDA extensions currently only support CUDA architectures >= 7.0.
...
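After installing, you can quickly check whether the custom kernels were actually built. This is an illustrative check, not part of the documented API; it relies on the internal module name `disco_cuda_extension` that appears in the diff further down:

```python
# Illustrative check: torch-harmonics loads its custom kernels from the
# `disco_cuda_extension` module, so importing it directly reveals whether
# the compiled extension is available in the current environment.
try:
    import disco_cuda_extension  # noqa: F401
    print("custom CUDA extension available")
except ImportError:
    print("custom CUDA extension not built; the pure PyTorch fallback will be used")
```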
[build-system]
-requires = [ "setuptools", "setuptools-scm", "torch>=2.4.0"]
+requires = [ "setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
...
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
-__version__ = "0.7.1"
+__version__ = "0.7.2"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
...
@@ -44,12 +44,10 @@ from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
-# import custom C++/CUDA extensions
-from disco_helpers import preprocess_psi
+# import custom C++/CUDA extensions if available
try:
+    from disco_helpers import preprocess_psi
    import disco_cuda_extension
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
@@ -377,10 +375,13 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
-        # preprocessed data-structure for GPU kernel
-        self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        if _cuda_extension_available:
+            # preprocessed data-structure for GPU kernel
+            roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
+            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
@@ -468,10 +469,13 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
-        # preprocessed data-structure for GPU kernel
-        self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        if _cuda_extension_available:
+            # preprocessed data-structure for GPU kernel
+            roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
+            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
...
@@ -58,12 +58,10 @@ from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_pol
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
-# import custom C++/CUDA extensions
-from disco_helpers import preprocess_psi
+# import custom C++/CUDA extensions if available
try:
+    from disco_helpers import preprocess_psi
    import disco_cuda_extension
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
@@ -240,10 +238,12 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous()
-        # preprocessed data-structure for GPU kernel
-        self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        if _cuda_extension_available:
+            # preprocessed data-structure for GPU kernel
+            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous()
+            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
@@ -370,10 +370,12 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()
-        roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous()
-        # preprocessed data-structure for GPU kernel
-        self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
+        if _cuda_extension_available:
+            # preprocessed data-structure for GPU kernel
+            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous()
+            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
...