Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
965efee2
Unverified
Commit
965efee2
authored
Jul 28, 2018
by
Gao, Xiang
Committed by
GitHub
Jul 28, 2018
Browse files
helper function to create dataloader, helper module that handle batch (#27)
parent
194e88ff
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
91 additions
and
11 deletions
+91
-11
tests/test_batch.py
tests/test_batch.py
+32
-0
tests/test_data.py
tests/test_data.py
+1
-1
torchani/data.py
torchani/data.py
+40
-9
torchani/models/__init__.py
torchani/models/__init__.py
+2
-1
torchani/models/batch.py
torchani/models/batch.py
+16
-0
No files found.
tests/test_batch.py
0 → 100644
View file @
965efee2
import
sys
if
sys
.
version_info
.
major
>=
3
:
import
os
import
unittest
import
torch
import
torchani
import
torchani.data
import
itertools
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
path
=
os
.
path
.
join
(
path
,
'dataset'
)
chunksize
=
32
batch_chunks
=
32
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestBatch
(
unittest
.
TestCase
):
def
testBatchLoadAndInference
(
self
):
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
batch_chunks
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
for
batch_input
,
batch_output
in
itertools
.
islice
(
loader
,
10
):
batch_output_
=
batch_nnp
(
batch_input
).
squeeze
()
self
.
assertListEqual
(
list
(
batch_output_
.
shape
),
list
(
batch_output
[
'energies'
].
shape
))
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/test_data.py
View file @
965efee2
...
@@ -13,7 +13,7 @@ if sys.version_info.major >= 3:
...
@@ -13,7 +13,7 @@ if sys.version_info.major >= 3:
def
_test_chunksize
(
self
,
chunksize
):
def
_test_chunksize
(
self
,
chunksize
):
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
)
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
)
for
i
in
ds
:
for
i
in
ds
:
self
.
assertLessEqual
(
i
[
0
].
shape
[
0
],
chunksize
)
self
.
assertLessEqual
(
i
[
'coordinates'
].
shape
[
0
],
chunksize
)
def
testChunk64
(
self
):
def
testChunk64
(
self
):
self
.
_test_chunksize
(
64
)
self
.
_test_chunksize
(
64
)
...
...
torchani/data.py
View file @
965efee2
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
,
DataLoader
from
os.path
import
join
,
isfile
,
isdir
from
os.path
import
join
,
isfile
,
isdir
from
os
import
listdir
from
os
import
listdir
from
.pyanitools
import
anidataloader
from
.pyanitools
import
anidataloader
...
@@ -7,8 +7,13 @@ import torch
...
@@ -7,8 +7,13 @@ import torch
class
ANIDataset
(
Dataset
):
class
ANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
chunk_size
,
randomize_chunk
=
True
):
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
]):
super
(
ANIDataset
,
self
).
__init__
()
super
(
ANIDataset
,
self
).
__init__
()
self
.
path
=
path
self
.
chunks_size
=
chunk_size
self
.
shuffle
=
shuffle
self
.
properties
=
properties
# get name of files storing data
# get name of files storing data
files
=
[]
files
=
[]
...
@@ -26,11 +31,14 @@ class ANIDataset(Dataset):
...
@@ -26,11 +31,14 @@ class ANIDataset(Dataset):
chunks
=
[]
chunks
=
[]
for
f
in
files
:
for
f
in
files
:
for
m
in
anidataloader
(
f
):
for
m
in
anidataloader
(
f
):
xyz
=
torch
.
from_numpy
(
m
[
'coordinates'
])
full
=
{
conformations
=
xyz
.
shape
[
0
]
'coordinates'
:
torch
.
from_numpy
(
m
[
'coordinates'
])
energies
=
torch
.
from_numpy
(
m
[
'energies'
])
}
conformations
=
full
[
'coordinates'
].
shape
[
0
]
for
i
in
properties
:
full
[
i
]
=
torch
.
from_numpy
(
m
[
i
])
species
=
m
[
'species'
]
species
=
m
[
'species'
]
if
randomize_chunk
:
if
shuffle
:
indices
=
torch
.
randperm
(
conformations
)
indices
=
torch
.
randperm
(
conformations
)
else
:
else
:
indices
=
torch
.
arange
(
conformations
,
dtype
=
torch
.
int64
)
indices
=
torch
.
arange
(
conformations
,
dtype
=
torch
.
int64
)
...
@@ -39,9 +47,11 @@ class ANIDataset(Dataset):
...
@@ -39,9 +47,11 @@ class ANIDataset(Dataset):
chunk_start
=
i
*
chunk_size
chunk_start
=
i
*
chunk_size
chunk_end
=
min
(
chunk_start
+
chunk_size
,
conformations
)
chunk_end
=
min
(
chunk_start
+
chunk_size
,
conformations
)
chunk_indices
=
indices
[
chunk_start
:
chunk_end
]
chunk_indices
=
indices
[
chunk_start
:
chunk_end
]
chunk_xyz
=
xyz
.
index_select
(
0
,
chunk_indices
)
chunk
=
{}
chunk_energies
=
energies
.
index_select
(
0
,
chunk_indices
)
for
j
in
full
:
chunks
.
append
((
chunk_xyz
,
chunk_energies
,
species
))
chunk
[
j
]
=
full
[
j
].
index_select
(
0
,
chunk_indices
)
chunk
[
'species'
]
=
species
chunks
.
append
(
chunk
)
self
.
chunks
=
chunks
self
.
chunks
=
chunks
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
...
@@ -49,3 +59,24 @@ class ANIDataset(Dataset):
...
@@ -49,3 +59,24 @@ class ANIDataset(Dataset):
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
chunks
)
return
len
(
self
.
chunks
)
def
_collate
(
batch
):
input_keys
=
[
'coordinates'
,
'species'
]
inputs
=
[{
k
:
i
[
k
]
for
k
in
input_keys
}
for
i
in
batch
]
outputs
=
{}
for
i
in
batch
:
for
j
in
i
:
if
j
in
input_keys
:
continue
if
j
not
in
outputs
:
outputs
[
j
]
=
[]
outputs
[
j
].
append
(
i
[
j
])
for
i
in
outputs
:
outputs
[
i
]
=
torch
.
cat
(
outputs
[
i
])
return
inputs
,
outputs
def
dataloader
(
dataset
,
batch_chunks
,
**
kwargs
):
return
DataLoader
(
dataset
,
batch_chunks
,
dataset
.
shuffle
,
collate_fn
=
_collate
,
**
kwargs
)
torchani/models/__init__.py
View file @
965efee2
from
.custom
import
CustomModel
from
.custom
import
CustomModel
from
.neurochem_nnp
import
NeuroChemNNP
from
.neurochem_nnp
import
NeuroChemNNP
from
.batch
import
BatchModel
__all__
=
[
'CustomModel'
,
'NeuroChemNNP'
]
__all__
=
[
'CustomModel'
,
'NeuroChemNNP'
,
'BatchModel'
]
torchani/models/batch.py
0 → 100644
View file @
965efee2
import
torch
class
BatchModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model
):
super
(
BatchModel
,
self
).
__init__
()
self
.
model
=
model
def
forward
(
self
,
batch
):
results
=
[]
for
i
in
batch
:
coordinates
=
i
[
'coordinates'
]
species
=
i
[
'species'
]
results
.
append
(
self
.
model
(
coordinates
,
species
))
return
torch
.
cat
(
results
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment