Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
2ec2fb6d
Unverified
Commit
2ec2fb6d
authored
May 24, 2019
by
Gao, Xiang
Committed by
GitHub
May 24, 2019
Browse files
Modify dataset API to allow atomic properties (#231)
parent
4f63c32d
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
210 additions
and
171 deletions
+210
-171
docs/api.rst
docs/api.rst
+1
-1
tests/test_aev.py
tests/test_aev.py
+3
-3
tests/test_data.py
tests/test_data.py
+20
-9
tests/test_energies.py
tests/test_energies.py
+3
-3
tests/test_forces.py
tests/test_forces.py
+3
-4
tests/test_padding.py
tests/test_padding.py
+54
-46
torchani/data/__init__.py
torchani/data/__init__.py
+90
-62
torchani/utils.py
torchani/utils.py
+36
-43
No files found.
docs/api.rst
View file @
2ec2fb6d
...
...
@@ -34,7 +34,7 @@ Utilities
.. automodule:: torchani.utils
.. autofunction:: torchani.utils.pad
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.pad_atomic_properties
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
.. autofunction:: torchani.utils.map2central
...
...
tests/test_aev.py
View file @
2ec2fb6d
...
...
@@ -113,11 +113,11 @@ class TestAEV(unittest.TestCase):
species
=
self
.
transform
(
species
)
radial
=
self
.
transform
(
radial
)
angular
=
self
.
transform
(
angular
)
species_coordinates
.
append
(
(
species
,
coordinates
)
)
species_coordinates
.
append
(
{
'species'
:
species
,
'
coordinates
'
:
coordinates
}
)
radial_angular
.
append
((
radial
,
angular
))
species
,
coordinates
=
torchani
.
utils
.
pad_
coordinat
es
(
species
_
coordinates
=
torchani
.
utils
.
pad_
atomic_properti
es
(
species_coordinates
)
_
,
aev
=
self
.
aev_computer
((
species
,
coordinates
))
_
,
aev
=
self
.
aev_computer
((
species
_coordinates
[
'species'
],
species_coordinates
[
'coordinates'
]
))
start
=
0
for
expected_radial
,
expected_angular
in
radial_angular
:
conformations
=
expected_radial
.
shape
[
0
]
...
...
tests/test_data.py
View file @
2ec2fb6d
...
...
@@ -30,16 +30,20 @@ class TestData(unittest.TestCase):
coordinates2
=
torch
.
randn
(
2
,
8
,
3
)
species3
=
torch
.
randint
(
4
,
(
10
,
20
),
dtype
=
torch
.
long
)
coordinates3
=
torch
.
randn
(
10
,
20
,
3
)
species
,
coordinates
=
torchani
.
utils
.
pad_
coordinat
es
([
(
species1
,
coordinates1
)
,
(
species2
,
coordinates2
)
,
(
species3
,
coordinates3
)
,
species
_
coordinates
=
torchani
.
utils
.
pad_
atomic_properti
es
([
{
'species'
:
species1
,
'coordinates'
:
coordinates1
}
,
{
'species'
:
species2
,
'coordinates'
:
coordinates2
}
,
{
'species'
:
species3
,
'coordinates'
:
coordinates3
}
,
])
species
=
species_coordinates
[
'species'
]
coordinates
=
species_coordinates
[
'coordinates'
]
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
chunks
=
torchani
.
data
.
split_batch
(
natoms
,
species
,
coordinates
)
chunks
=
torchani
.
data
.
split_batch
(
natoms
,
species
_
coordinates
)
start
=
0
last
=
None
for
s
,
c
in
chunks
:
for
chunk
in
chunks
:
s
=
chunk
[
'species'
]
c
=
chunk
[
'coordinates'
]
n
=
(
s
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
if
last
is
not
None
:
self
.
assertNotEqual
(
last
[
-
1
],
n
[
0
])
...
...
@@ -47,19 +51,26 @@ class TestData(unittest.TestCase):
self
.
assertGreater
(
conformations
,
0
)
s_
=
species
[
start
:(
start
+
conformations
),
...]
c_
=
coordinates
[
start
:(
start
+
conformations
),
...]
s_
,
c_
=
torchani
.
utils
.
strip_redundant_padding
(
s_
,
c_
)
sc
=
torchani
.
utils
.
strip_redundant_padding
({
'species'
:
s_
,
'coordinates'
:
c_
})
s_
=
sc
[
'species'
]
c_
=
sc
[
'coordinates'
]
self
.
_assertTensorEqual
(
s
,
s_
)
self
.
_assertTensorEqual
(
c
,
c_
)
start
+=
conformations
s
,
c
=
torchani
.
utils
.
pad_coordinates
(
chunks
)
sc
=
torchani
.
utils
.
pad_atomic_properties
(
chunks
)
s
=
sc
[
'species'
]
c
=
sc
[
'coordinates'
]
self
.
_assertTensorEqual
(
s
,
species
)
self
.
_assertTensorEqual
(
c
,
coordinates
)
def
testTensorShape
(
self
):
for
i
in
self
.
ds
:
input_
,
output
=
i
species
,
coordinates
=
torchani
.
utils
.
pad_coordinates
(
input_
)
input_
=
[{
'species'
:
x
[
0
],
'coordinates'
:
x
[
1
]}
for
x
in
input_
]
species_coordinates
=
torchani
.
utils
.
pad_atomic_properties
(
input_
)
species
=
species_coordinates
[
'species'
]
coordinates
=
species_coordinates
[
'coordinates'
]
energies
=
output
[
'energies'
]
self
.
assertEqual
(
len
(
species
.
shape
),
2
)
self
.
assertLessEqual
(
species
.
shape
[
0
],
batch_size
)
...
...
tests/test_energies.py
View file @
2ec2fb6d
...
...
@@ -89,12 +89,12 @@ class TestEnergies(unittest.TestCase):
coordinates
=
self
.
transform
(
coordinates
)
species
=
self
.
transform
(
species
)
e
=
self
.
transform
(
e
)
species_coordinates
.
append
(
(
species
,
coordinates
)
)
species_coordinates
.
append
(
{
'species'
:
species
,
'
coordinates
'
:
coordinates
}
)
energies
.
append
(
e
)
species
,
coordinates
=
torchani
.
utils
.
pad_
coordinat
es
(
species
_
coordinates
=
torchani
.
utils
.
pad_
atomic_properti
es
(
species_coordinates
)
energies
=
torch
.
cat
(
energies
)
_
,
energies_
=
self
.
model
((
species
,
coordinates
))
_
,
energies_
=
self
.
model
((
species
_coordinates
[
'species'
],
species_coordinates
[
'coordinates'
]
))
max_diff
=
(
energies
-
energies_
).
abs
().
max
().
item
()
self
.
assertLess
(
max_diff
,
self
.
tolerance
)
...
...
tests/test_forces.py
View file @
2ec2fb6d
...
...
@@ -55,11 +55,10 @@ class TestForce(unittest.TestCase):
species
=
self
.
transform
(
species
)
forces
=
self
.
transform
(
forces
)
coordinates
.
requires_grad_
(
True
)
species_coordinates
.
append
((
species
,
coordinates
))
coordinates_forces
.
append
((
coordinates
,
forces
))
species
,
coordinates
=
torchani
.
utils
.
pad_coordinates
(
species_coordinates
.
append
({
'species'
:
species
,
'coordinates'
:
coordinates
})
species_coordinates
=
torchani
.
utils
.
pad_atomic_properties
(
species_coordinates
)
_
,
energies
=
self
.
model
((
species
,
coordinates
))
_
,
energies
=
self
.
model
((
species
_coordinates
[
'species'
],
species_coordinates
[
'coordinates'
]
))
energies
=
energies
.
sum
()
for
coordinates
,
forces
in
coordinates_forces
:
derivative
=
torch
.
autograd
.
grad
(
energies
,
coordinates
,
...
...
tests/test_padding.py
View file @
2ec2fb6d
...
...
@@ -6,17 +6,17 @@ import torchani
class
TestPaddings
(
unittest
.
TestCase
):
def
testVectorSpecies
(
self
):
species1
=
torch
.
LongT
ensor
([
0
,
2
,
3
,
1
])
species1
=
torch
.
t
ensor
([
[
0
,
2
,
3
,
1
]
]
)
coordinates1
=
torch
.
zeros
(
5
,
4
,
3
)
species2
=
torch
.
LongT
ensor
([
3
,
2
,
0
,
1
,
0
])
species2
=
torch
.
t
ensor
([
[
3
,
2
,
0
,
1
,
0
]
]
)
coordinates2
=
torch
.
zeros
(
2
,
5
,
3
)
species
,
coordinat
es
=
torchani
.
utils
.
pad_
coordinat
es
([
(
species1
,
coordinates1
)
,
(
species2
,
coordinates2
)
,
atomic_properti
es
=
torchani
.
utils
.
pad_
atomic_properti
es
([
{
'species'
:
species1
,
'coordinates'
:
coordinates1
}
,
{
'species'
:
species2
,
'coordinates'
:
coordinates2
}
,
])
self
.
assertEqual
(
species
.
shape
[
0
],
7
)
self
.
assertEqual
(
species
.
shape
[
1
],
5
)
expected_species
=
torch
.
LongT
ensor
([
self
.
assertEqual
(
atomic_properties
[
'
species
'
]
.
shape
[
0
],
7
)
self
.
assertEqual
(
atomic_properties
[
'
species
'
]
.
shape
[
1
],
5
)
expected_species
=
torch
.
t
ensor
([
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
...
...
@@ -25,21 +25,21 @@ class TestPaddings(unittest.TestCase):
[
3
,
2
,
0
,
1
,
0
],
[
3
,
2
,
0
,
1
,
0
],
])
self
.
assertEqual
((
species
-
expected_species
).
abs
().
max
().
item
(),
0
)
self
.
assertEqual
(
coordinates
.
abs
().
max
().
item
(),
0
)
self
.
assertEqual
((
atomic_properties
[
'
species
'
]
-
expected_species
).
abs
().
max
().
item
(),
0
)
self
.
assertEqual
(
atomic_properties
[
'
coordinates
'
]
.
abs
().
max
().
item
(),
0
)
def
testTensorShape1NSpecies
(
self
):
species1
=
torch
.
LongT
ensor
([[
0
,
2
,
3
,
1
]])
species1
=
torch
.
t
ensor
([[
0
,
2
,
3
,
1
]])
coordinates1
=
torch
.
zeros
(
5
,
4
,
3
)
species2
=
torch
.
LongT
ensor
([
3
,
2
,
0
,
1
,
0
])
species2
=
torch
.
t
ensor
([
[
3
,
2
,
0
,
1
,
0
]
]
)
coordinates2
=
torch
.
zeros
(
2
,
5
,
3
)
species
,
coordinat
es
=
torchani
.
utils
.
pad_
coordinat
es
([
(
species1
,
coordinates1
)
,
(
species2
,
coordinates2
)
,
atomic_properti
es
=
torchani
.
utils
.
pad_
atomic_properti
es
([
{
'species'
:
species1
,
'coordinates'
:
coordinates1
}
,
{
'species'
:
species2
,
'coordinates'
:
coordinates2
}
,
])
self
.
assertEqual
(
species
.
shape
[
0
],
7
)
self
.
assertEqual
(
species
.
shape
[
1
],
5
)
expected_species
=
torch
.
LongT
ensor
([
self
.
assertEqual
(
atomic_properties
[
'
species
'
]
.
shape
[
0
],
7
)
self
.
assertEqual
(
atomic_properties
[
'
species
'
]
.
shape
[
1
],
5
)
expected_species
=
torch
.
t
ensor
([
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
...
...
@@ -48,11 +48,11 @@ class TestPaddings(unittest.TestCase):
[
3
,
2
,
0
,
1
,
0
],
[
3
,
2
,
0
,
1
,
0
],
])
self
.
assertEqual
((
species
-
expected_species
).
abs
().
max
().
item
(),
0
)
self
.
assertEqual
(
coordinates
.
abs
().
max
().
item
(),
0
)
self
.
assertEqual
((
atomic_properties
[
'
species
'
]
-
expected_species
).
abs
().
max
().
item
(),
0
)
self
.
assertEqual
(
atomic_properties
[
'
coordinates
'
]
.
abs
().
max
().
item
(),
0
)
def
testTensorSpecies
(
self
):
species1
=
torch
.
LongT
ensor
([
species1
=
torch
.
t
ensor
([
[
0
,
2
,
3
,
1
],
[
0
,
2
,
3
,
1
],
[
0
,
2
,
3
,
1
],
...
...
@@ -60,15 +60,15 @@ class TestPaddings(unittest.TestCase):
[
0
,
2
,
3
,
1
],
])
coordinates1
=
torch
.
zeros
(
5
,
4
,
3
)
species2
=
torch
.
LongT
ensor
([
3
,
2
,
0
,
1
,
0
])
species2
=
torch
.
t
ensor
([
[
3
,
2
,
0
,
1
,
0
]
]
)
coordinates2
=
torch
.
zeros
(
2
,
5
,
3
)
species
,
coordinat
es
=
torchani
.
utils
.
pad_
coordinat
es
([
(
species1
,
coordinates1
)
,
(
species2
,
coordinates2
)
,
atomic_properti
es
=
torchani
.
utils
.
pad_
atomic_properti
es
([
{
'species'
:
species1
,
'coordinates'
:
coordinates1
}
,
{
'species'
:
species2
,
'coordinates'
:
coordinates2
}
,
])
self
.
assertEqual
(
species
.
shape
[
0
],
7
)
self
.
assertEqual
(
species
.
shape
[
1
],
5
)
expected_species
=
torch
.
LongT
ensor
([
self
.
assertEqual
(
atomic_properties
[
'
species
'
]
.
shape
[
0
],
7
)
self
.
assertEqual
(
atomic_properties
[
'
species
'
]
.
shape
[
1
],
5
)
expected_species
=
torch
.
t
ensor
([
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
...
...
@@ -77,22 +77,22 @@ class TestPaddings(unittest.TestCase):
[
3
,
2
,
0
,
1
,
0
],
[
3
,
2
,
0
,
1
,
0
],
])
self
.
assertEqual
((
species
-
expected_species
).
abs
().
max
().
item
(),
0
)
self
.
assertEqual
(
coordinates
.
abs
().
max
().
item
(),
0
)
self
.
assertEqual
((
atomic_properties
[
'
species
'
]
-
expected_species
).
abs
().
max
().
item
(),
0
)
self
.
assertEqual
(
atomic_properties
[
'
coordinates
'
]
.
abs
().
max
().
item
(),
0
)
def
testPadSpecies
(
self
):
species1
=
torch
.
LongT
ensor
([
species1
=
torch
.
t
ensor
([
[
0
,
2
,
3
,
1
],
[
0
,
2
,
3
,
1
],
[
0
,
2
,
3
,
1
],
[
0
,
2
,
3
,
1
],
[
0
,
2
,
3
,
1
],
])
species2
=
torch
.
LongT
ensor
([
3
,
2
,
0
,
1
,
0
]).
expand
(
2
,
5
)
species2
=
torch
.
t
ensor
([
[
3
,
2
,
0
,
1
,
0
]
]
).
expand
(
2
,
5
)
species
=
torchani
.
utils
.
pad
([
species1
,
species2
])
self
.
assertEqual
(
species
.
shape
[
0
],
7
)
self
.
assertEqual
(
species
.
shape
[
1
],
5
)
expected_species
=
torch
.
LongT
ensor
([
expected_species
=
torch
.
t
ensor
([
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
[
0
,
2
,
3
,
1
,
-
1
],
...
...
@@ -104,9 +104,9 @@ class TestPaddings(unittest.TestCase):
self
.
assertEqual
((
species
-
expected_species
).
abs
().
max
().
item
(),
0
)
def
testPresentSpecies
(
self
):
species
=
torch
.
LongT
ensor
([
0
,
1
,
1
,
0
,
3
,
7
,
-
1
,
-
1
])
species
=
torch
.
t
ensor
([
0
,
1
,
1
,
0
,
3
,
7
,
-
1
,
-
1
])
present_species
=
torchani
.
utils
.
present_species
(
species
)
expected
=
torch
.
LongT
ensor
([
0
,
1
,
3
,
7
])
expected
=
torch
.
t
ensor
([
0
,
1
,
3
,
7
])
self
.
assertEqual
((
expected
-
present_species
).
abs
().
max
().
item
(),
0
)
...
...
@@ -120,23 +120,31 @@ class TestStripRedundantPadding(unittest.TestCase):
coordinates1
=
torch
.
randn
(
5
,
4
,
3
)
species2
=
torch
.
randint
(
4
,
(
2
,
5
),
dtype
=
torch
.
long
)
coordinates2
=
torch
.
randn
(
2
,
5
,
3
)
species12
,
coordinat
es12
=
torchani
.
utils
.
pad_
coordinat
es
([
(
species1
,
coordinates1
)
,
(
species2
,
coordinates2
)
,
atomic_properti
es12
=
torchani
.
utils
.
pad_
atomic_properti
es
([
{
'species'
:
species1
,
'coordinates'
:
coordinates1
}
,
{
'species'
:
species2
,
'coordinates'
:
coordinates2
}
,
])
species12
=
atomic_properties12
[
'species'
]
coordinates12
=
atomic_properties12
[
'coordinates'
]
species3
=
torch
.
randint
(
4
,
(
2
,
10
),
dtype
=
torch
.
long
)
coordinates3
=
torch
.
randn
(
2
,
10
,
3
)
species123
,
coordinat
es123
=
torchani
.
utils
.
pad_
coordinat
es
([
(
species1
,
coordinates1
)
,
(
species2
,
coordinates2
)
,
(
species3
,
coordinates3
)
,
atomic_properti
es123
=
torchani
.
utils
.
pad_
atomic_properti
es
([
{
'species'
:
species1
,
'coordinates'
:
coordinates1
}
,
{
'species'
:
species2
,
'coordinates'
:
coordinates2
}
,
{
'species'
:
species3
,
'coordinates'
:
coordinates3
}
,
])
species1_
,
coordinates1_
=
torchani
.
utils
.
strip_redundant_padding
(
species123
[:
5
,
...],
coordinates123
[:
5
,
...])
species123
=
atomic_properties123
[
'species'
]
coordinates123
=
atomic_properties123
[
'coordinates'
]
species_coordinates1_
=
torchani
.
utils
.
strip_redundant_padding
(
{
'species'
:
species123
[:
5
,
...],
'coordinates'
:
coordinates123
[:
5
,
...]})
species1_
=
species_coordinates1_
[
'species'
]
coordinates1_
=
species_coordinates1_
[
'coordinates'
]
self
.
_assertTensorEqual
(
species1_
,
species1
)
self
.
_assertTensorEqual
(
coordinates1_
,
coordinates1
)
species12_
,
coordinates12_
=
torchani
.
utils
.
strip_redundant_padding
(
species123
[:
7
,
...],
coordinates123
[:
7
,
...])
species_coordinates12_
=
torchani
.
utils
.
strip_redundant_padding
(
{
'species'
:
species123
[:
7
,
...],
'coordinates'
:
coordinates123
[:
7
,
...]})
species12_
=
species_coordinates12_
[
'species'
]
coordinates12_
=
species_coordinates12_
[
'coordinates'
]
self
.
_assertTensorEqual
(
species12_
,
species12
)
self
.
_assertTensorEqual
(
coordinates12_
,
coordinates12
)
...
...
torchani/data/__init__.py
View file @
2ec2fb6d
...
...
@@ -21,21 +21,22 @@ def chunk_counts(counts, split):
for
i
in
split
:
count_chunks
.
append
(
counts
[
start
:
i
])
start
=
i
chunk_
conformation
s
=
[
sum
([
y
[
1
]
for
y
in
x
])
for
x
in
count_chunks
]
chunk_
molecule
s
=
[
sum
([
y
[
1
]
for
y
in
x
])
for
x
in
count_chunks
]
chunk_maxatoms
=
[
x
[
-
1
][
0
]
for
x
in
count_chunks
]
return
chunk_
conformation
s
,
chunk_maxatoms
return
chunk_
molecule
s
,
chunk_maxatoms
def
split_cost
(
counts
,
split
):
split_min_cost
=
40000
cost
=
0
chunk_
conformation
s
,
chunk_maxatoms
=
chunk_counts
(
counts
,
split
)
for
conformation
s
,
maxatoms
in
zip
(
chunk_
conformation
s
,
chunk_maxatoms
):
cost
+=
max
(
conformation
s
*
maxatoms
**
2
,
split_min_cost
)
chunk_
molecule
s
,
chunk_maxatoms
=
chunk_counts
(
counts
,
split
)
for
molecule
s
,
maxatoms
in
zip
(
chunk_
molecule
s
,
chunk_maxatoms
):
cost
+=
max
(
molecule
s
*
maxatoms
**
2
,
split_min_cost
)
return
cost
def
split_batch
(
natoms
,
species
,
coordinates
):
def
split_batch
(
natoms
,
atomic_properties
):
# count number of conformations by natoms
natoms
=
natoms
.
tolist
()
counts
=
[]
...
...
@@ -47,6 +48,7 @@ def split_batch(natoms, species, coordinates):
counts
[
-
1
][
1
]
+=
1
else
:
counts
.
append
([
i
,
1
])
# find best split using greedy strategy
split
=
[]
cost
=
split_cost
(
counts
,
split
)
...
...
@@ -66,19 +68,21 @@ def split_batch(natoms, species, coordinates):
if
improved
:
split
=
cycle_split
cost
=
cycle_cost
# do split
start
=
0
species_coordinates
=
[]
chunk_conformations
,
_
=
chunk_counts
(
counts
,
split
)
for
i
in
chunk_conformations
:
s
=
species
end
=
start
+
i
s
=
species
[
start
:
end
,
...]
c
=
coordinates
[
start
:
end
,
...]
s
,
c
=
utils
.
strip_redundant_padding
(
s
,
c
)
species_coordinates
.
append
((
s
,
c
))
start
=
end
return
species_coordinates
chunk_molecules
,
_
=
chunk_counts
(
counts
,
split
)
num_chunks
=
None
for
k
in
atomic_properties
:
atomic_properties
[
k
]
=
atomic_properties
[
k
].
split
(
chunk_molecules
)
if
num_chunks
is
None
:
num_chunks
=
len
(
atomic_properties
[
k
])
else
:
assert
num_chunks
==
len
(
atomic_properties
[
k
])
chunks
=
[]
for
i
in
range
(
num_chunks
):
chunk
=
{
k
:
atomic_properties
[
k
][
i
]
for
k
in
atomic_properties
}
chunks
.
append
(
utils
.
strip_redundant_padding
(
chunk
))
return
chunks
class
BatchedANIDataset
(
Dataset
):
...
...
@@ -118,13 +122,24 @@ class BatchedANIDataset(Dataset):
batch_size (int): Number of different 3D structures in a single
minibatch.
shuffle (bool): Whether to shuffle the whole dataset.
properties (list): List of keys in the dataset to be loaded.
``'species'`` and ``'coordinates'`` are always loaded and need not
to be specified here.
properties (list): List of keys of `molecular` properties in the
dataset to be loaded. Here `molecular` means that the property always
has a fixed size regardless of the number of atoms, i.e. the tensor
shape of molecular properties should be (molecule, ...). An example
of a molecular property is the molecular energy. ``'species'`` and
``'coordinates'`` are always loaded and need not be specified
anywhere.
atomic_properties (list): List of keys of `atomic` properties in the
dataset to be loaded. Here `atomic` means that the size of the property
is proportional to the number of atoms in the molecule, i.e. the
tensor shape of atomic properties should be (molecule, atoms, ...).
An example of an atomic property is the forces. ``'species'`` and
``'coordinates'`` are always loaded and need not be specified
anywhere.
transform (list): List of :class:`collections.abc.Callable` that
transform the data. Callables must take species, coordinates,
and properties of the whole dataset as arguments, and return
the transformed species, coordinates, and properties.
transform the data. Callables must take atomic properties and
properties as arguments, and return the transformed atomic
properties and properties.
dtype (:class:`torch.dtype`): dtype to convert the coordinates and
properties of the dataset to.
device (:class:`torch.device`): device to put tensors on when iterating.
...
...
@@ -134,7 +149,7 @@ class BatchedANIDataset(Dataset):
"""
def
__init__
(
self
,
path
,
species_tensor_converter
,
batch_size
,
shuffle
=
True
,
properties
=
[
'energies'
]
,
transform
=
(),
shuffle
=
True
,
properties
=
(
'energies'
,),
atomic_properties
=
()
,
transform
=
(),
dtype
=
torch
.
get_default_dtype
(),
device
=
default_device
):
super
(
BatchedANIDataset
,
self
).
__init__
()
self
.
properties
=
properties
...
...
@@ -153,68 +168,81 @@ class BatchedANIDataset(Dataset):
raise
ValueError
(
'Bad path'
)
# load full dataset
species_coordinat
es
=
[]
atomic_properti
es
_
=
[]
properties
=
{
k
:
[]
for
k
in
self
.
properties
}
for
f
in
files
:
for
m
in
anidataloader
(
f
):
s
=
species_tensor_converter
(
m
[
'species'
])
c
=
torch
.
from_numpy
(
m
[
'coordinates'
]).
to
(
torch
.
double
)
species_coordinates
.
append
((
s
,
c
))
atomic_properties_
.
append
(
dict
(
species
=
species_tensor_converter
(
m
[
'species'
]).
unsqueeze
(
0
),
**
{
k
:
torch
.
from_numpy
(
m
[
k
]).
to
(
torch
.
double
)
for
k
in
[
'coordinates'
]
+
list
(
atomic_properties
)
}
))
for
i
in
properties
:
p
=
torch
.
from_numpy
(
m
[
i
]).
to
(
torch
.
double
)
properties
[
i
].
append
(
p
)
species
,
coordinates
=
utils
.
pad_coordinates
(
species_coordinat
es
)
atomic_properties
=
utils
.
pad_atomic_properties
(
atomic_properti
es
_
)
for
i
in
properties
:
properties
[
i
]
=
torch
.
cat
(
properties
[
i
])
# shuffle if required
conformations
=
coordinates
.
shape
[
0
]
molecules
=
atomic_properties
[
'species'
]
.
shape
[
0
]
if
shuffle
:
indices
=
torch
.
randperm
(
conformations
)
species
=
species
.
index_select
(
0
,
indices
)
coordinates
=
coordinates
.
index_select
(
0
,
indices
)
indices
=
torch
.
randperm
(
molecules
)
for
i
in
properties
:
properties
[
i
]
=
properties
[
i
].
index_select
(
0
,
indices
)
for
i
in
atomic_properties
:
atomic_properties
[
i
]
=
atomic_properties
[
i
].
index_select
(
0
,
indices
)
# do transformations on data
for
t
in
transform
:
species
,
coordinates
,
properties
=
t
(
species
,
coordinates
,
properties
)
atomic_properties
,
properties
=
t
(
atomic_properties
,
properties
)
# convert to desired dtype
species
=
species
coordinates
=
coordinates
.
to
(
dtype
)
for
k
in
properties
:
properties
[
k
]
=
properties
[
k
].
to
(
dtype
)
for
k
in
atomic_properties
:
if
k
==
'species'
:
continue
atomic_properties
[
k
]
=
atomic_properties
[
k
].
to
(
dtype
)
# split into minibatches, and strip redundant padding
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
batches
=
[]
num_batches
=
(
conformations
+
batch_size
-
1
)
//
batch_size
# split into minibatches
for
k
in
properties
:
properties
[
k
]
=
properties
[
k
].
split
(
batch_size
)
for
k
in
atomic_properties
:
atomic_properties
[
k
]
=
atomic_properties
[
k
].
split
(
batch_size
)
# further split batch into chunks and strip redundant padding
self
.
batches
=
[]
num_batches
=
(
molecules
+
batch_size
-
1
)
//
batch_size
for
i
in
range
(
num_batches
):
start
=
i
*
batch_size
end
=
min
((
i
+
1
)
*
batch_size
,
conformations
)
natoms_batch
=
natoms
[
start
:
end
]
batch_properties
=
{
k
:
v
[
i
]
for
k
,
v
in
properties
.
items
()}
batch_atomic_properties
=
{
k
:
v
[
i
]
for
k
,
v
in
atomic_properties
.
items
()}
species
=
batch_atomic_properties
[
'species'
]
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
# sort batch by number of atoms to prepare for splitting
natoms_batch
,
indices
=
natoms_batch
.
sort
()
species_batch
=
species
[
start
:
end
,
...].
index_select
(
0
,
indices
)
coordinates_batch
=
coordinates
[
start
:
end
,
...]
\
.
index_select
(
0
,
indices
)
properties_batch
=
{
k
:
properties
[
k
][
start
:
end
,
...].
index_select
(
0
,
indices
)
.
to
(
self
.
device
)
for
k
in
properties
}
# further split batch into chunks
species_coordinates
=
split_batch
(
natoms_batch
,
species_batch
,
coordinates_batch
)
batch
=
species_coordinates
,
properties_batch
batches
.
append
(
batch
)
self
.
batches
=
batches
natoms
,
indices
=
natoms
.
sort
()
for
k
in
batch_properties
:
batch_properties
[
k
]
=
batch_properties
[
k
].
index_select
(
0
,
indices
)
for
k
in
batch_atomic_properties
:
batch_atomic_properties
[
k
]
=
batch_atomic_properties
[
k
].
index_select
(
0
,
indices
)
batch_atomic_properties
=
split_batch
(
natoms
,
batch_atomic_properties
)
self
.
batches
.
append
((
batch_atomic_properties
,
batch_properties
))
def
__getitem__
(
self
,
idx
):
species_coordinates
,
properties
=
self
.
batches
[
idx
]
species_coordinates
=
[(
s
.
to
(
self
.
device
),
c
.
to
(
self
.
device
))
for
s
,
c
in
species_coordinates
]
atomic_properties
,
properties
=
self
.
batches
[
idx
]
atomic_properties
,
properties
=
atomic_properties
.
copy
(),
properties
.
copy
()
species_coordinates
=
[]
for
chunk
in
atomic_properties
:
for
k
in
chunk
:
chunk
[
k
]
=
chunk
[
k
].
to
(
self
.
device
)
species_coordinates
.
append
((
chunk
[
'species'
],
chunk
[
'coordinates'
]))
for
k
in
properties
:
properties
[
k
]
=
properties
[
k
].
to
(
self
.
device
)
properties
[
'atomic'
]
=
atomic_properties
return
species_coordinates
,
properties
def
__len__
(
self
):
...
...
torchani/utils.py
View file @
2ec2fb6d
import
torch
import
torch.utils.data
import
math
from
collections
import
defaultdict
def
pad
(
species
):
...
...
@@ -30,41 +31,35 @@ def pad(species):
return
torch
.
cat
(
padded_species
)
def
pad_
coordinates
(
species_coordinates
):
"""Put
different species and coordinates together into a single tensor.
def
pad_
atomic_properties
(
atomic_properties
,
padding_values
=
defaultdict
(
lambda
:
0.0
,
species
=-
1
)
):
"""Put
a sequence of atomic properties together into a single tensor.
If the species and coordinates are from molecules of different number of
total atoms, then ghost atoms with atom type -1 and coordinate (0, 0, 0)
will be added to make it fit into the same shape.
Inputs are `[{'species': ..., ...}, {'species': ..., ...}, ...]` and the outputs
are `{'species': padded_tensor, ...}`
Arguments:
species_coordinates (:class:`collections.abc.Sequence`): sequence of
pairs of species and coordinates. Species must be of shape
``(N, A)`` and coordinates must be of shape ``(N, A, 3)``, where
``N`` is the number of 3D structures, ``A`` is the number of atoms.
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): Species, and
coordinates batched together.
atomic properties.
padding_values (dict): the values used to pad tensors to the same size
"""
max_atoms
=
max
([
c
.
shape
[
1
]
for
_
,
c
in
species_coordinates
])
species
=
[
]
coord
inat
es
=
[]
for
s
,
c
in
species_coordinates
:
natoms
=
c
.
shape
[
1
]
if
len
(
s
.
shape
)
==
1
:
s
=
s
.
unsqueeze
(
0
)
if
natoms
<
max_atoms
:
pad
ding
=
torch
.
full
((
s
.
shape
[
0
],
max_atoms
-
natoms
),
-
1
,
dtype
=
torch
.
long
,
device
=
s
.
device
)
s
=
torch
.
cat
([
s
,
padding
],
dim
=
1
)
padding
=
torch
.
full
((
c
.
shape
[
0
],
max_atoms
-
natoms
,
3
),
0
,
dtype
=
c
.
dtype
,
device
=
c
.
device
)
c
=
torch
.
cat
([
c
,
padding
],
dim
=
1
)
s
=
s
.
expand
(
c
.
shape
[
0
],
max_atoms
)
species
.
append
(
s
)
coordinates
.
append
(
c
)
return
torch
.
cat
(
species
),
torch
.
cat
(
coordinates
)
keys
=
list
(
atomic_properties
[
0
])
anykey
=
keys
[
0
]
max_atoms
=
max
(
x
[
anykey
].
shape
[
1
]
for
x
in
at
omic_properties
)
padded
=
{
k
:
[]
for
k
in
keys
}
for
p
in
atomic_properties
:
num_molecules
=
max
(
v
.
shape
[
0
]
for
v
in
p
.
values
())
for
k
,
v
in
p
.
items
():
shape
=
list
(
v
.
shape
)
pad
atoms
=
max_atoms
-
shape
[
1
]
shape
[
1
]
=
padatoms
padding
=
v
.
new_full
(
shape
,
padding_values
[
k
]
)
v
=
torch
.
cat
([
v
,
padding
],
dim
=
1
)
if
v
.
shape
[
0
]
<
num_molecules
:
shape
=
list
(
v
.
shape
)
shape
[
0
]
=
num_molecules
v
=
v
.
expand
(
*
shape
)
padded
[
k
]
.
append
(
v
)
return
{
k
:
torch
.
cat
(
v
)
for
k
,
v
in
padded
.
items
()}
# @torch.jit.script
...
...
@@ -84,23 +79,20 @@ def present_species(species):
return
present_species
def
strip_redundant_padding
(
species
,
coordinat
es
):
def
strip_redundant_padding
(
atomic_properti
es
):
"""Strip trailing padding atoms.
Arguments:
species (:class:`torch.Tensor`): Long tensor of shape
``(molecules, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape
``(molecules, atoms, 3)``.
atomic_properties (dict): properties to strip
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates
with redundant padding atoms stripped.
dict: same set of properties with redundant padding atoms stripped.
"""
species
=
atomic_properties
[
'species'
]
non_padding
=
(
species
>=
0
).
any
(
dim
=
0
).
nonzero
().
squeeze
()
species
=
species
.
index_select
(
1
,
non_padding
)
coordinates
=
coordinates
.
index_select
(
1
,
non_padding
)
return
species
,
coordinat
es
for
k
in
atomic_properties
:
atomic_properties
[
k
]
=
atomic_properties
[
k
]
.
index_select
(
1
,
non_padding
)
return
atomic_properti
es
def
map2central
(
cell
,
coordinates
,
pbc
):
...
...
@@ -170,15 +162,16 @@ class EnergyShifter(torch.nn.Module):
self_energies
[
species
==
-
1
]
=
0
return
self_energies
.
sum
(
dim
=
1
)
def
subtract_from_dataset
(
self
,
species
,
coordinat
es
,
properties
):
def
subtract_from_dataset
(
self
,
atomic_properti
es
,
properties
):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that
subtracts self energies.
"""
species
=
atomic_properties
[
'species'
]
energies
=
properties
[
'energies'
]
device
=
energies
.
device
energies
=
energies
.
to
(
torch
.
double
)
-
self
.
sae
(
species
).
to
(
device
)
properties
[
'energies'
]
=
energies
return
species
,
coordinat
es
,
properties
return
atomic_properti
es
,
properties
def
forward
(
self
,
species_energies
):
"""(species, molecular energies)->(species, molecular energies + sae)
...
...
@@ -263,6 +256,6 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
return
wavenumbers
,
modes
__all__
=
[
'pad'
,
'pad_
coordinat
es'
,
'present_species'
,
'hessian'
,
__all__
=
[
'pad'
,
'pad_
atomic_properti
es'
,
'present_species'
,
'hessian'
,
'vibrational_analysis'
,
'strip_redundant_padding'
,
'ChemicalSymbolsToInts'
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment