Unverified Commit dec9c0eb authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Make CUDA extension TorchScript compatible (#527)

* Make CUDA extension TorchScript compatible

* save

* save

* fix

* save

* save

* save

* save

* Update install_dependencies.sh

* trigger ci

* save

* fix

* save

* save

* save

* save

* try

* fix

* revert

* save

* save

* mypy
parent 0b354c1d
*.prof
__pycache__
/data
*.cpp
a.out
/test.py
/.vscode
......
import os
import glob
import subprocess
from setuptools import setup, find_packages
from distutils import log
......@@ -56,9 +57,9 @@ def cuda_extension():
if cuda_version >= 11.1:
nvcc_args.append("-gencode=arch=compute_86,code=sm_86")
return CUDAExtension(
name='_real_cuaev',
pkg='torchani.cuaev._real_cuaev',
sources=['torchani/cuaev/aev.cu'],
name='torchani.cuaev',
pkg='torchani.cuaev',
sources=glob.glob('torchani/cuaev/*'),
include_dirs=maybe_download_cub(),
extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args},
optional=True)
......
......@@ -6,20 +6,18 @@ skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(),
'There is no device to run this test')
@unittest.skipIf(torchani.cuaev.is_installed, "only valid when cuaev not installed")
class TestCUAEVNotInstalled(unittest.TestCase):
def testCuComputeAEV(self):
self.assertRaisesRegex(RuntimeError, "cuaev is not installed", lambda: torchani.cuaev.cuComputeAEV())
@unittest.skipIf(not torchani.cuaev.is_installed, "only valid when cuaev is installed")
@unittest.skipIf(not torchani.has_cuaev, "only valid when cuaev is installed")
class TestCUAEV(unittest.TestCase):
def testJIT(self):
def f(coordinates, species, Rcr: float, Rca: float, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species: int):
return torch.ops.cuaev.cuComputeAEV(coordinates, species, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
s = torch.jit.script(f)
self.assertIn("cuaev::cuComputeAEV", str(s.graph))
@skipIfNoGPU
def testHello(self):
# TODO: this should be removed when a real cuaev is merged
self.assertEqual("Hello World!!!", torchani.cuaev.cuComputeAEV())
pass
if __name__ == '__main__':
......
......@@ -36,8 +36,17 @@ from . import utils
from . import neurochem
from . import models
from . import units
from . import cuaev
from pkg_resources import get_distribution, DistributionNotFound
import warnings
import importlib_metadata
has_cuaev = 'torchani.cuaev' in importlib_metadata.metadata(__package__).get_all('Provides')
if has_cuaev:
# We need to import torchani.cuaev to tell PyTorch to initialize torch.ops.cuaev
from . import cuaev # type: ignore # noqa: F401
else:
warnings.warn("cuaev not installed")
try:
__version__ = get_distribution(__name__).version
......@@ -46,7 +55,7 @@ except DistributionNotFound:
pass
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'SpeciesConverter',
'utils', 'neurochem', 'models', 'units', 'cuaev']
'utils', 'neurochem', 'models', 'units', 'has_cuaev']
try:
from . import ase # noqa: F401
......
import warnings
import importlib_metadata
is_installed = 'torchani.cuaev' in importlib_metadata.metadata('torchani').get_all('Provides')
if is_installed:
import _real_cuaev
cuComputeAEV = _real_cuaev.cuComputeAEV
else:
warnings.warn("cuaev not installed")
def cuComputeAEV(*args, **kwargs):
raise RuntimeError("cuaev is not installed")
#include <torch/extension.h>
template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
double Rcr_, double Rca_, torch::Tensor EtaR_t,
torch::Tensor ShfR_t, torch::Tensor EtaA_t,
torch::Tensor Zeta_t, torch::Tensor ShfA_t,
torch::Tensor ShfZ_t, int64_t num_species_) {
ScalarRealT Rcr = Rcr_;
ScalarRealT Rca = Rca_;
int num_species = num_species_;
}
TORCH_LIBRARY(cuaev, m) { m.def("cuComputeAEV", &cuComputeAEV<float>); }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
#include <cub/cub.cuh>
#include <string>
#include <torch/extension.h>
__global__ void kernel() { printf("Hello World!"); }
std::string say_hello() {
kernel<<<1, 1>>>();
return "Hello World!!!";
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cuComputeAEV", &say_hello, "Hello World");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment