Unverified Commit 12422dd1 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

More about cuaev (#547)

* More about cuaev

* Update setup.py

* save

* clang-format
parent e67a2ad5
...@@ -61,8 +61,7 @@ def cuda_extension(): ...@@ -61,8 +61,7 @@ def cuda_extension():
pkg='torchani.cuaev', pkg='torchani.cuaev',
sources=glob.glob('torchani/cuaev/*'), sources=glob.glob('torchani/cuaev/*'),
include_dirs=maybe_download_cub(), include_dirs=maybe_download_cub(),
extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args}, extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args})
optional=True)
def cuaev_kwargs(): def cuaev_kwargs():
......
import torchani import torchani
import unittest import unittest
import torch import torch
import os
from torchani.testing import TestCase, make_tensor
skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(), skipIfNoGPU = unittest.skipIf(not torch.cuda.is_available(),
'There is no device to run this test') 'There is no device to run this test')
@unittest.skipIf(not torchani.has_cuaev, "only valid when cuaev is installed") @unittest.skipIf(not torchani.aev.has_cuaev, "only valid when cuaev is installed")
class TestCUAEV(torchani.testing.TestCase): class TestCUAEVNoGPU(TestCase):
def testJIT(self): def testSimple(self):
def f(coordinates, species, Rcr: float, Rca: float, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species: int): def f(coordinates, species, Rcr: float, Rca: float, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species: int):
return torch.ops.cuaev.cuComputeAEV(coordinates, species, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species) return torch.ops.cuaev.cuComputeAEV(coordinates, species, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
s = torch.jit.script(f) s = torch.jit.script(f)
self.assertIn("cuaev::cuComputeAEV", str(s.graph)) self.assertIn("cuaev::cuComputeAEV", str(s.graph))
    def testAEVComputer(self):
        """Script an AEVComputer built with use_cuda_extension=True and check
        that its TorchScript graph dispatches to the cuaev custom op.

        Runs without a GPU: with zero atoms the cuaev extension returns early
        before touching CUDA, so graph_for can be traced on CPU tensors.
        """
        path = os.path.dirname(os.path.realpath(__file__))
        # Constants for the ANI-1x network family, shipped with the package.
        const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params')  # noqa: E501
        consts = torchani.neurochem.Constants(const_file)
        aev_computer = torchani.AEVComputer(**consts, use_cuda_extension=True)
        s = torch.jit.script(aev_computer)
        # Computation of AEV using cuaev when there is no atoms does not require CUDA, and can be run without GPU
        species = make_tensor((8, 0), 'cpu', torch.int64, low=-1, high=4)
        coordinates = make_tensor((8, 0, 3), 'cpu', torch.float32, low=-5, high=5)
        self.assertIn("cuaev::cuComputeAEV", str(s.graph_for((species, coordinates))))
@unittest.skipIf(not torchani.aev.has_cuaev, "only valid when cuaev is installed")
@skipIfNoGPU
class TestCUAEV(TestCase):
def testHello(self): def testHello(self):
pass pass
......
...@@ -38,17 +38,8 @@ from . import models ...@@ -38,17 +38,8 @@ from . import models
from . import units from . import units
from pkg_resources import get_distribution, DistributionNotFound from pkg_resources import get_distribution, DistributionNotFound
import warnings import warnings
import importlib_metadata
from . import testing from . import testing
has_cuaev = 'torchani.cuaev' in importlib_metadata.metadata(__package__).get_all('Provides')
if has_cuaev:
# We need to import torchani.cuaev to tell PyTorch to initialize torch.ops.cuaev
from . import cuaev # type: ignore # noqa: F401
else:
warnings.warn("cuaev not installed")
try: try:
__version__ = get_distribution(__name__).version __version__ = get_distribution(__name__).version
except DistributionNotFound: except DistributionNotFound:
...@@ -56,16 +47,16 @@ except DistributionNotFound: ...@@ -56,16 +47,16 @@ except DistributionNotFound:
pass pass
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'SpeciesConverter', __all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'SpeciesConverter',
'utils', 'neurochem', 'models', 'units', 'has_cuaev', 'testing'] 'utils', 'neurochem', 'models', 'units', 'testing']
try: try:
from . import ase # noqa: F401 from . import ase # noqa: F401
__all__.append('ase') __all__.append('ase')
except ImportError: except ImportError:
pass warnings.warn("Dependency not satisfied, torchani.ase will not be available")
try: try:
from . import data # noqa: F401 from . import data # noqa: F401
__all__.append('data') __all__.append('data')
except ImportError: except ImportError:
pass warnings.warn("Dependency not satisfied, torchani.data will not be available")
...@@ -4,6 +4,16 @@ from torch import Tensor ...@@ -4,6 +4,16 @@ from torch import Tensor
import math import math
from typing import Tuple, Optional, NamedTuple from typing import Tuple, Optional, NamedTuple
import sys import sys
import warnings
import importlib_metadata
# The optional CUDA extension advertises itself through the package's
# "Provides" metadata (written by setup.py when the extension was built),
# so its presence can be detected without attempting the import.
has_cuaev = 'torchani.cuaev' in importlib_metadata.metadata(__package__).get_all('Provides')
if has_cuaev:
    # We need to import torchani.cuaev to tell PyTorch to initialize torch.ops.cuaev
    from . import cuaev  # type: ignore # noqa: F401
else:
    warnings.warn("cuaev not installed")
if sys.version_info[:2] < (3, 7): if sys.version_info[:2] < (3, 7):
class FakeFinal: class FakeFinal:
...@@ -314,6 +324,20 @@ def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor, ...@@ -314,6 +324,20 @@ def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
return torch.cat([radial_aev, angular_aev], dim=-1) return torch.cat([radial_aev, angular_aev], dim=-1)
def compute_cuaev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
                  constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
                  num_species: int, cell_shifts: Optional[Tuple[Tensor, Tensor]]) -> Tensor:
    """Compute AEVs via the cuaev CUDA extension's custom op.

    Signature mirrors ``compute_aev`` (sans ``sizes``) so AEVComputer can
    dispatch to either implementation.  ``cell_shifts`` must be None —
    periodic boundary conditions are not supported by the extension.
    """
    Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
    assert cell_shifts is None, "Current implementation of cuaev does not support pbc."
    # The CUDA kernel takes int32 species indices, not the int64 used elsewhere.
    species_int = species.to(torch.int32)
    # Parameter tensors are flattened; their broadcast shapes are only needed
    # by the pure-Python implementation.
    return torch.ops.cuaev.cuComputeAEV(coordinates, species_int, Rcr, Rca, EtaR.flatten(), ShfR.flatten(), EtaA.flatten(), Zeta.flatten(), ShfA.flatten(), num_species)


# When the extension is absent, mark the function as unused for TorchScript
# so scripting AEVComputer does not fail on the missing custom op.
if not has_cuaev:
    compute_cuaev = torch.jit.unused(compute_cuaev)
class AEVComputer(torch.nn.Module): class AEVComputer(torch.nn.Module):
r"""The AEV computer that takes coordinates as input and outputs aevs. r"""The AEV computer that takes coordinates as input and outputs aevs.
...@@ -350,14 +374,20 @@ class AEVComputer(torch.nn.Module): ...@@ -350,14 +374,20 @@ class AEVComputer(torch.nn.Module):
aev_length: Final[int] aev_length: Final[int]
sizes: Final[Tuple[int, int, int, int, int]] sizes: Final[Tuple[int, int, int, int, int]]
triu_index: Tensor triu_index: Tensor
use_cuda_extension: Final[bool]
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species): def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species, use_cuda_extension=False):
super().__init__() super().__init__()
self.Rcr = Rcr self.Rcr = Rcr
self.Rca = Rca self.Rca = Rca
assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr" assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr"
self.num_species = num_species self.num_species = num_species
# cuda aev
if use_cuda_extension:
assert has_cuaev, "AEV cuda extension is not installed"
self.use_cuda_extension = use_cuda_extension
# convert constant tensors to a ready-to-broadcast shape # convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR) # shape convension (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1)) self.register_buffer('EtaR', EtaR.view(-1, 1))
...@@ -474,6 +504,11 @@ class AEVComputer(torch.nn.Module): ...@@ -474,6 +504,11 @@ class AEVComputer(torch.nn.Module):
assert species.shape == coordinates.shape[:-1] assert species.shape == coordinates.shape[:-1]
assert coordinates.shape[-1] == 3 assert coordinates.shape[-1] == 3
if self.use_cuda_extension:
assert (cell is None and pbc is None), "cuaev does not support PBC"
aev = compute_cuaev(species, coordinates, self.triu_index, self.constants(), self.num_species, None)
return SpeciesAEV(species, aev)
if cell is None and pbc is None: if cell is None and pbc is None:
aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None) aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None)
else: else:
......
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/Context.h>
#include <c10/cuda/CUDACachingAllocator.h>
// Placeholder smoke-test kernel: prints from the device when launched,
// proving the extension's CUDA path is wired up end to end.
__global__ void run() { printf("Hello World"); }
template <typename ScalarRealT = float> template <typename ScalarRealT = float>
torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t, torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
double Rcr_, double Rca_, torch::Tensor EtaR_t, double Rcr_, double Rca_, torch::Tensor EtaR_t,
...@@ -9,6 +14,13 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t, ...@@ -9,6 +14,13 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
ScalarRealT Rcr = Rcr_; ScalarRealT Rcr = Rcr_;
ScalarRealT Rca = Rca_; ScalarRealT Rca = Rca_;
int num_species = num_species_; int num_species = num_species_;
if (species_t.numel() == 0) {
return coordinates_t;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
run<<<1, 1, 0, stream>>>();
return coordinates_t;
} }
TORCH_LIBRARY(cuaev, m) { m.def("cuComputeAEV", &cuComputeAEV<float>); } TORCH_LIBRARY(cuaev, m) { m.def("cuComputeAEV", &cuComputeAEV<float>); }
......
from torch.testing._internal.common_utils import TestCase # noqa: F401 from torch.testing._internal.common_utils import TestCase, make_tensor # noqa: F401
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment