Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
12422dd1
Unverified
Commit
12422dd1
authored
Nov 14, 2020
by
Gao, Xiang
Committed by
GitHub
Nov 14, 2020
Browse files
More about cuaev (#547)
* More about cuaev * Update setup.py * save * clang-format
parent
e67a2ad5
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
20 deletions
+73
-20
setup.py
setup.py
+1
-2
tests/test_cuaev.py
tests/test_cuaev.py
+20
-4
torchani/__init__.py
torchani/__init__.py
+3
-12
torchani/aev.py
torchani/aev.py
+36
-1
torchani/cuaev/aev.cu
torchani/cuaev/aev.cu
+12
-0
torchani/testing.py
torchani/testing.py
+1
-1
No files found.
setup.py
View file @
12422dd1
...
...
@@ -61,8 +61,7 @@ def cuda_extension():
pkg
=
'torchani.cuaev'
,
sources
=
glob
.
glob
(
'torchani/cuaev/*'
),
include_dirs
=
maybe_download_cub
(),
extra_compile_args
=
{
'cxx'
:
[
'-std=c++14'
],
'nvcc'
:
nvcc_args
},
optional
=
True
)
extra_compile_args
=
{
'cxx'
:
[
'-std=c++14'
],
'nvcc'
:
nvcc_args
})
def
cuaev_kwargs
():
...
...
tests/test_cuaev.py
View file @
12422dd1
import
torchani
import
unittest
import
torch
import
os
from
torchani.testing
import
TestCase
,
make_tensor
skipIfNoGPU
=
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
'There is no device to run this test'
)
@
unittest
.
skipIf
(
not
torchani
.
has_cuaev
,
"only valid when cuaev is installed"
)
class
TestCUAEV
(
torchani
.
testing
.
TestCase
):
@
unittest
.
skipIf
(
not
torchani
.
aev
.
has_cuaev
,
"only valid when cuaev is installed"
)
class
TestCUAEV
NoGPU
(
TestCase
):
def
test
JIT
(
self
):
def
test
Simple
(
self
):
def
f
(
coordinates
,
species
,
Rcr
:
float
,
Rca
:
float
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
:
int
):
return
torch
.
ops
.
cuaev
.
cuComputeAEV
(
coordinates
,
species
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
)
s
=
torch
.
jit
.
script
(
f
)
self
.
assertIn
(
"cuaev::cuComputeAEV"
,
str
(
s
.
graph
))
@
skipIfNoGPU
def
testAEVComputer
(
self
):
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
const_file
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params'
)
# noqa: E501
consts
=
torchani
.
neurochem
.
Constants
(
const_file
)
aev_computer
=
torchani
.
AEVComputer
(
**
consts
,
use_cuda_extension
=
True
)
s
=
torch
.
jit
.
script
(
aev_computer
)
# Computation of AEV using cuaev when there is no atoms does not require CUDA, and can be run without GPU
species
=
make_tensor
((
8
,
0
),
'cpu'
,
torch
.
int64
,
low
=-
1
,
high
=
4
)
coordinates
=
make_tensor
((
8
,
0
,
3
),
'cpu'
,
torch
.
float32
,
low
=-
5
,
high
=
5
)
self
.
assertIn
(
"cuaev::cuComputeAEV"
,
str
(
s
.
graph_for
((
species
,
coordinates
))))
@
unittest
.
skipIf
(
not
torchani
.
aev
.
has_cuaev
,
"only valid when cuaev is installed"
)
@
skipIfNoGPU
class
TestCUAEV
(
TestCase
):
def
testHello
(
self
):
pass
...
...
torchani/__init__.py
View file @
12422dd1
...
...
@@ -38,17 +38,8 @@ from . import models
from
.
import
units
from
pkg_resources
import
get_distribution
,
DistributionNotFound
import
warnings
import
importlib_metadata
from
.
import
testing
has_cuaev
=
'torchani.cuaev'
in
importlib_metadata
.
metadata
(
__package__
).
get_all
(
'Provides'
)
if
has_cuaev
:
# We need to import torchani.cuaev to tell PyTorch to initialize torch.ops.cuaev
from
.
import
cuaev
# type: ignore # noqa: F401
else
:
warnings
.
warn
(
"cuaev not installed"
)
try
:
__version__
=
get_distribution
(
__name__
).
version
except
DistributionNotFound
:
...
...
@@ -56,16 +47,16 @@ except DistributionNotFound:
pass
__all__
=
[
'AEVComputer'
,
'EnergyShifter'
,
'ANIModel'
,
'Ensemble'
,
'SpeciesConverter'
,
'utils'
,
'neurochem'
,
'models'
,
'units'
,
'has_cuaev'
,
'testing'
]
'utils'
,
'neurochem'
,
'models'
,
'units'
,
'testing'
]
try
:
from
.
import
ase
# noqa: F401
__all__
.
append
(
'ase'
)
except
ImportError
:
pass
warnings
.
warn
(
"Dependency not satisfied, torchani.ase will not be available"
)
try
:
from
.
import
data
# noqa: F401
__all__
.
append
(
'data'
)
except
ImportError
:
pass
warnings
.
warn
(
"Dependency not satisfied, torchani.data will not be available"
)
torchani/aev.py
View file @
12422dd1
...
...
@@ -4,6 +4,16 @@ from torch import Tensor
import
math
from
typing
import
Tuple
,
Optional
,
NamedTuple
import
sys
import
warnings
import
importlib_metadata
has_cuaev
=
'torchani.cuaev'
in
importlib_metadata
.
metadata
(
__package__
).
get_all
(
'Provides'
)
if
has_cuaev
:
# We need to import torchani.cuaev to tell PyTorch to initialize torch.ops.cuaev
from
.
import
cuaev
# type: ignore # noqa: F401
else
:
warnings
.
warn
(
"cuaev not installed"
)
if
sys
.
version_info
[:
2
]
<
(
3
,
7
):
class
FakeFinal
:
...
...
@@ -314,6 +324,20 @@ def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
return
torch
.
cat
([
radial_aev
,
angular_aev
],
dim
=-
1
)
def
compute_cuaev
(
species
:
Tensor
,
coordinates
:
Tensor
,
triu_index
:
Tensor
,
constants
:
Tuple
[
float
,
Tensor
,
Tensor
,
float
,
Tensor
,
Tensor
,
Tensor
,
Tensor
],
num_species
:
int
,
cell_shifts
:
Optional
[
Tuple
[
Tensor
,
Tensor
]])
->
Tensor
:
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
=
constants
assert
cell_shifts
is
None
,
"Current implementation of cuaev does not support pbc."
species_int
=
species
.
to
(
torch
.
int32
)
return
torch
.
ops
.
cuaev
.
cuComputeAEV
(
coordinates
,
species_int
,
Rcr
,
Rca
,
EtaR
.
flatten
(),
ShfR
.
flatten
(),
EtaA
.
flatten
(),
Zeta
.
flatten
(),
ShfA
.
flatten
(),
ShfZ
.
flatten
(),
num_species
)
if
not
has_cuaev
:
compute_cuaev
=
torch
.
jit
.
unused
(
compute_cuaev
)
class
AEVComputer
(
torch
.
nn
.
Module
):
r
"""The AEV computer that takes coordinates as input and outputs aevs.
...
...
@@ -350,14 +374,20 @@ class AEVComputer(torch.nn.Module):
aev_length
:
Final
[
int
]
sizes
:
Final
[
Tuple
[
int
,
int
,
int
,
int
,
int
]]
triu_index
:
Tensor
use_cuda_extension
:
Final
[
bool
]
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
):
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
,
use_cuda_extension
=
False
):
super
().
__init__
()
self
.
Rcr
=
Rcr
self
.
Rca
=
Rca
assert
Rca
<=
Rcr
,
"Current implementation of AEVComputer assumes Rca <= Rcr"
self
.
num_species
=
num_species
# cuda aev
if
use_cuda_extension
:
assert
has_cuaev
,
"AEV cuda extension is not installed"
self
.
use_cuda_extension
=
use_cuda_extension
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self
.
register_buffer
(
'EtaR'
,
EtaR
.
view
(
-
1
,
1
))
...
...
@@ -474,6 +504,11 @@ class AEVComputer(torch.nn.Module):
assert
species
.
shape
==
coordinates
.
shape
[:
-
1
]
assert
coordinates
.
shape
[
-
1
]
==
3
if
self
.
use_cuda_extension
:
assert
(
cell
is
None
and
pbc
is
None
),
"cuaev does not support PBC"
aev
=
compute_cuaev
(
species
,
coordinates
,
self
.
triu_index
,
self
.
constants
(),
self
.
num_species
,
None
)
return
SpeciesAEV
(
species
,
aev
)
if
cell
is
None
and
pbc
is
None
:
aev
=
compute_aev
(
species
,
coordinates
,
self
.
triu_index
,
self
.
constants
(),
self
.
sizes
,
None
)
else
:
...
...
torchani/cuaev/aev.c
pp
→
torchani/cuaev/aev.c
u
View file @
12422dd1
#include <torch/extension.h>
#include <ATen/Context.h>
#include <c10/cuda/CUDACachingAllocator.h>
__global__
void
run
()
{
printf
(
"Hello World"
);
}
template
<
typename
ScalarRealT
=
float
>
torch
::
Tensor
cuComputeAEV
(
torch
::
Tensor
coordinates_t
,
torch
::
Tensor
species_t
,
double
Rcr_
,
double
Rca_
,
torch
::
Tensor
EtaR_t
,
...
...
@@ -9,6 +14,13 @@ torch::Tensor cuComputeAEV(torch::Tensor coordinates_t, torch::Tensor species_t,
ScalarRealT
Rcr
=
Rcr_
;
ScalarRealT
Rca
=
Rca_
;
int
num_species
=
num_species_
;
if
(
species_t
.
numel
()
==
0
)
{
return
coordinates_t
;
}
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
run
<<<
1
,
1
,
0
,
stream
>>>
();
return
coordinates_t
;
}
TORCH_LIBRARY
(
cuaev
,
m
)
{
m
.
def
(
"cuComputeAEV"
,
&
cuComputeAEV
<
float
>
);
}
...
...
torchani/testing.py
View file @
12422dd1
from
torch.testing._internal.common_utils
import
TestCase
# noqa: F401
from
torch.testing._internal.common_utils
import
TestCase
,
make_tensor
# noqa: F401
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment