Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
5bac472d
"test/test_transforms_v2_functional.py" did not exist on "c0911e31d367eca9600541bd634564c73753abbf"
Unverified
Commit
5bac472d
authored
Mar 23, 2019
by
Gao, Xiang
Committed by
GitHub
Mar 23, 2019
Browse files
Change default dtype of ase to float64 (#191)
parent
e3096aa1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
14 deletions
+33
-14
examples/ase_interface.py
examples/ase_interface.py
+15
-0
tests/test_ase.py
tests/test_ase.py
+3
-3
torchani/aev.py
torchani/aev.py
+1
-1
torchani/ase.py
torchani/ase.py
+8
-6
torchani/models.py
torchani/models.py
+6
-4
No files found.
examples/ase_interface.py
View file @
5bac472d
...
...
@@ -32,6 +32,21 @@ print(len(atoms), "atoms in the cell")
###############################################################################
# Now let's create a calculator from builtin models:
calculator
=
torchani
.
models
.
ANI1ccx
().
ase
()
###############################################################################
# .. note::
# Regardless of the dtype you use in your model, when converting it to ASE
# calculator, it always automatically the dtype to ``torch.float64``. The
# reason for this behavior is, at many cases, the rounding error is too
# large for structure minimization. If you insist on using
# ``torch.float32``, do the following instead:
#
# .. code-block:: python
#
# calculator = torchani.models.ANI1ccx().ase(dtype=torch.float32)
###############################################################################
# Now let's set the calculator for ``atoms``:
atoms
.
set_calculator
(
calculator
)
###############################################################################
...
...
tests/test_ase.py
View file @
5bac472d
...
...
@@ -17,7 +17,7 @@ tol = 5e-5
def
get_numeric_force
(
atoms
,
eps
):
fn
=
torch
.
zeros
((
len
(
atoms
),
3
))
fn
=
torch
.
zeros
((
len
(
atoms
),
3
)
,
dtype
=
torch
.
double
)
for
i
in
range
(
len
(
atoms
)):
for
j
in
range
(
3
):
fn
[
i
,
j
]
=
numeric_force
(
atoms
,
i
,
j
,
eps
)
...
...
@@ -55,7 +55,7 @@ class TestASE(unittest.TestCase):
builtin
.
models
,
builtin
.
energy_shifter
)
default_neighborlist_calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
,
True
)
builtin
.
models
,
builtin
.
energy_shifter
,
_default_neighborlist
=
True
)
nnp
=
torch
.
nn
.
Sequential
(
builtin
.
aev_computer
,
builtin
.
models
,
...
...
@@ -104,7 +104,7 @@ class TestASE(unittest.TestCase):
builtin
.
models
,
builtin
.
energy_shifter
)
default_neighborlist_calculator
=
torchani
.
ase
.
Calculator
(
builtin
.
species
,
builtin
.
aev_computer
,
builtin
.
models
,
builtin
.
energy_shifter
,
True
)
builtin
.
models
,
builtin
.
energy_shifter
,
_default_neighborlist
=
True
)
atoms
.
set_calculator
(
calculator
)
dyn
=
Langevin
(
atoms
,
5
*
units
.
fs
,
50
*
units
.
kB
,
0.002
)
...
...
torchani/aev.py
View file @
5bac472d
...
...
@@ -119,7 +119,7 @@ def _terms_and_indices(Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA,
return
radial_terms
,
angular_terms
@
torch
.
jit
.
script
#
@torch.jit.script
def
default_neighborlist
(
species
,
coordinates
,
cutoff
):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
"""Default neighborlist computer"""
...
...
torchani/ase.py
View file @
5bac472d
...
...
@@ -107,6 +107,8 @@ class Calculator(ase.calculators.calculator.Calculator):
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``.
_default_neighborlist (bool): Whether to ignore pbc setting and always
use default neighborlist computer. This is for internal use only.
"""
...
...
@@ -114,24 +116,24 @@ class Calculator(ase.calculators.calculator.Calculator):
implemented_properties
=
[
'energy'
,
'forces'
]
def
__init__
(
self
,
species
,
aev_computer
,
model
,
energy_shifter
,
_default_neighborlist
=
False
):
dtype
=
torch
.
float64
,
_default_neighborlist
=
False
):
super
(
Calculator
,
self
).
__init__
()
self
.
_default_neighborlist
=
_default_neighborlist
self
.
species_to_tensor
=
utils
.
ChemicalSymbolsToInts
(
species
)
# aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
self
.
aev_computer
=
copy
.
copy
(
aev_computer
)
self
.
model
=
model
self
.
energy_shifter
=
energy_shifter
self
.
aev_computer
=
copy
.
deep
copy
(
aev_computer
)
self
.
model
=
copy
.
deepcopy
(
model
)
self
.
energy_shifter
=
copy
.
deepcopy
(
energy_shifter
)
self
.
device
=
self
.
aev_computer
.
EtaR
.
device
self
.
dtype
=
self
.
aev_computer
.
EtaR
.
dtype
self
.
dtype
=
dtype
self
.
whole
=
torch
.
nn
.
Sequential
(
self
.
aev_computer
,
self
.
model
,
self
.
energy_shifter
)
)
.
to
(
dtype
)
def
calculate
(
self
,
atoms
=
None
,
properties
=
[
'energy'
],
system_changes
=
ase
.
calculators
.
calculator
.
all_changes
):
...
...
torchani/models.py
View file @
5bac472d
...
...
@@ -51,12 +51,13 @@ class BuiltinModels(torch.nn.Module):
self
.
energy_shifter
)
def
ase
():
def
ase
(
**
kwargs
):
from
.
import
ase
return
ase
.
Calculator
(
self
.
builtins
.
species
,
self
.
aev_computer
,
self
.
neural_networks
[
index
],
self
.
energy_shifter
)
self
.
energy_shifter
,
**
kwargs
)
ret
.
ase
=
ase
ret
.
species_to_tensor
=
self
.
builtins
.
consts
.
species_to_tensor
...
...
@@ -65,11 +66,12 @@ class BuiltinModels(torch.nn.Module):
def
__len__
(
self
):
return
len
(
self
.
neural_networks
)
def
ase
(
self
):
def
ase
(
self
,
**
kwargs
):
"""Get an ASE Calculator using this model"""
from
.
import
ase
return
ase
.
Calculator
(
self
.
builtins
.
species
,
self
.
aev_computer
,
self
.
neural_networks
,
self
.
energy_shifter
)
self
.
neural_networks
,
self
.
energy_shifter
,
**
kwargs
)
class
ANI1x
(
BuiltinModels
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment