OpenDAS / torchani · Commits

Commit 327a9b20 (unverified), authored Jul 24, 2018 by Gao, Xiang, committed by GitHub on Jul 24, 2018.

rewrite dataloader API to improve performance and reduce code complexity (#18)

parent f6ef4ebb

Showing 3 changed files with 57 additions and 325 deletions (+57 −325). Note that the inline diff view interleaves removed and added lines, so the listings below contain both the old and the new code.
tests/test_data.py      +14 −146
torchani/__init__.py    +2 −3
torchani/data.py        +41 −176
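For orientation, a minimal sketch of the API change, inferred from the tests and from torchani/data.py below; the relative 'dataset' path and the chunk size of 64 are example values, not part of the commit:

import torchani.data

# old API (removed): build a concatenated dataset, then batch it with a sampler
ds = torchani.data.load_dataset('dataset')       # items are (molecule_id, xyz, energies)
sampler = torchani.data.BatchSampler(ds, 64, 4)  # chunks of up to 64 conformations, 4 chunks per batch
for batch_indices in sampler:                    # each batch is a list of dataset indices
    batch = [ds[i] for i in batch_indices]

# new API (added): the dataset itself yields pre-built chunks
ds = torchani.data.ANIDataset('dataset', 64)     # items are (xyz, energies, species)
for xyz, energies, species in ds:
    assert xyz.shape[0] <= 64                    # at most chunk_size conformations per chunk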
tests/test_data.py · View file @ 327a9b20

import sys

if sys.version_info.major >= 3:
    import torchani
    import unittest
    import tempfile
    import os
    import torch
    import torchani.pyanitools as pyanitools
    import torchani.data
    from math import ceil
    from bisect import bisect
    from pickle import dump, load

    path = os.path.dirname(os.path.realpath(__file__))
    dataset_dir = os.path.join(path, 'dataset')
    path = os.path.join(path, 'dataset')

    class TestDataset(unittest.TestCase):

        def setUp(self, data_path=dataset_dir):
            self.data_path = data_path
            self.ds = torchani.data.load_dataset(data_path)

        def testLen(self):
            # compute data length using Dataset
            l1 = len(self.ds)
            # compute data length using pyanitools
            l2 = 0
            for f in os.listdir(self.data_path):
                f = os.path.join(self.data_path, f)
                if os.path.isfile(f) and \
                        (f.endswith('.h5') or f.endswith('.hdf5')):
                    for j in pyanitools.anidataloader(f):
                        l2 += j['energies'].shape[0]
            # compute data length using iterator
            l3 = len(list(self.ds))
            # these lengths should match
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testNumChunks(self):
            chunksize = 64
            # compute number of chunks using batch sampler
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            l1 = len(bs)
            # compute number of chunks using pyanitools
            l2 = 0
            for f in os.listdir(self.data_path):
                f = os.path.join(self.data_path, f)
                if os.path.isfile(f) and \
                        (f.endswith('.h5') or f.endswith('.hdf5')):
                    for j in pyanitools.anidataloader(f):
                        conformations = j['energies'].shape[0]
                        l2 += ceil(conformations / chunksize)
            # compute number of chunks using iterator
            l3 = len(list(bs))
            # these lengths should match
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testNumBatches(self):
            chunksize = 64
            batch_chunks = 4
            # compute number of batches using batch sampler
            bs = torchani.data.BatchSampler(self.ds, chunksize, batch_chunks)
            l1 = len(bs)
            # compute number of batches by simple math
            bs2 = torchani.data.BatchSampler(self.ds, chunksize, 1)
            l2 = ceil(len(bs2) / batch_chunks)
            # compute number of batches using iterator
            l3 = len(list(bs))
            # these lengths should match
            self.assertEqual(l1, l2)
            self.assertEqual(l1, l3)

        def testBatchSize1(self):
            bs = torchani.data.BatchSampler(self.ds, 1, 1)
            self.assertEqual(len(bs), len(self.ds))

        def testSplitSize(self):
            chunksize = 64
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [200, chunks - 200], chunksize)
            bs1 = torchani.data.BatchSampler(ds1, chunksize, 1)
            bs2 = torchani.data.BatchSampler(ds2, chunksize, 1)
            self.assertEqual(len(bs1), 200)
            self.assertEqual(len(bs2), chunks - 200)

        def testSplitNoOverlap(self):
            chunksize = 64
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [200, chunks - 200], chunksize)
            indices1 = ds1.dataset.indices
            indices2 = ds2.dataset.indices
            self.assertEqual(len(indices1), len(ds1))
            self.assertEqual(len(indices2), len(ds2))
            self.assertEqual(len(indices1), len(set(indices1)))
            self.assertEqual(len(indices2), len(set(indices2)))
            self.assertEqual(len(self.ds), len(set(indices1 + indices2)))

        def _testMolSizes(self, ds):
            for i in range(len(ds)):
                left = bisect(ds.cumulative_sizes, i)
                moli = ds[i][0].item()
                for j in range(len(ds)):
                    left2 = bisect(ds.cumulative_sizes, j)
                    molj = ds[j][0].item()
                    if left == left2:
                        self.assertEqual(moli, molj)
                    else:
                        if moli == molj:
                            print(i, j)
                        self.assertNotEqual(moli, molj)

        def testMolSizes(self):
            chunksize = 8
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [50, chunks - 50], chunksize)
            self._testMolSizes(ds1)

        def testSaveLoad(self):
            chunksize = 8
            bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
            chunks = len(bs)
            ds1, ds2 = torchani.data.random_split(
                self.ds, [50, chunks - 50], chunksize)
            tmpdir = tempfile.TemporaryDirectory()
            tmpdirname = tmpdir.name
            filename = os.path.join(tmpdirname, 'test.obj')
            with open(filename, 'wb') as f:
                dump(ds1, f)
            with open(filename, 'rb') as f:
                ds1_loaded = load(f)
            self.assertEqual(len(ds1), len(ds1_loaded))
            self.assertListEqual(ds1.sizes, ds1_loaded.sizes)
            self.assertIsInstance(ds1_loaded, torchani.data.ANIDataset)
            for i in range(len(ds1)):
                i1 = ds1[i]
                i2 = ds1_loaded[i]
                molid1 = i1[0].item()
                molid2 = i2[0].item()
                self.assertEqual(molid1, molid2)
                xyz1 = i1[1]
                xyz2 = i2[1]
                maxdiff = torch.max(torch.abs(xyz1 - xyz2)).item()
                self.assertEqual(maxdiff, 0)
                e1 = i1[2].item()
                e2 = i2[2].item()
                self.assertEqual(e1, e2)

        # tests for the new chunk-based ANIDataset API
        def _test_chunksize(self, chunksize):
            ds = torchani.data.ANIDataset(path, chunksize)
            for i in ds:
                self.assertLessEqual(i[0].shape[0], chunksize)

        def testChunk64(self):
            self._test_chunksize(64)

        def testChunk128(self):
            self._test_chunksize(128)

        def testChunk32(self):
            self._test_chunksize(32)

        def testChunk256(self):
            self._test_chunksize(256)

if __name__ == '__main__':
    unittest.main()
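A usage note, not part of the diff: since the module ends with the standard unittest entry point, the suite can be run directly, e.g. `python tests/test_data.py`.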
torchani/__init__.py · View file @ 327a9b20

@@ -6,6 +6,5 @@ from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \

The hunk removes 'buildin_dataset_dir' from __all__:

Old:

__all__ = ['SortedAEV', 'EnergyShifter', 'ModelOnAEV', 'PerSpeciesFromNeuroChem',
           'data', 'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
           'buildin_dataset_dir', 'buildin_model_prefix', 'buildin_ensembles',
           'default_dtype', 'default_device']

New:

__all__ = ['SortedAEV', 'EnergyShifter', 'ModelOnAEV', 'PerSpeciesFromNeuroChem',
           'data', 'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
           'buildin_model_prefix', 'buildin_ensembles', 'default_dtype',
           'default_device']
torchani/data.py · View file @ 327a9b20

from .pyanitools import anidataloader
from os import listdir
from torch.utils.data import Dataset
from os.path import join, isfile, isdir
from torch import tensor, full_like, long
from torch.utils.data import Dataset, Subset, TensorDataset, ConcatDataset
from torch.utils.data.dataloader import default_collate
from math import ceil
from . import default_dtype
from random import shuffle
from itertools import chain, accumulate
import torch


class ANIDataset(Dataset):
    """Dataset with extra information for ANI applications

    Attributes
    ----------
    dataset : Dataset
        The dataset
    sizes : sequence
        Number of conformations for each molecule
    cumulative_sizes : sequence
        Cumulative sizes
    """

    # old constructor: wraps an existing dataset plus per-molecule sizes
    def __init__(self, dataset, sizes, species):
        super(ANIDataset, self).__init__()
        self.dataset = dataset
        self.sizes = sizes
        self.cumulative_sizes = list(accumulate(sizes))
        self.species = species

    # new constructor: reads the HDF5 files under `path` and pre-builds
    # chunks of at most `chunk_size` conformations
    def __init__(self, path, chunk_size, randomize_chunk=True):
        # get name of files storing data
        files = []
        if isdir(path):
            for f in listdir(path):
                f = join(path, f)
                if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
                    files.append(f)
        elif isfile(path):
            files = [path]
        else:
            raise ValueError('Bad path')
        # generate chunks
        chunks = []
        for f in files:
            for m in anidataloader(f):
                xyz = torch.from_numpy(m['coordinates'])
                conformations = xyz.shape[0]
                energies = torch.from_numpy(m['energies'])
                species = m['species']
                if randomize_chunk:
                    indices = torch.randperm(conformations)
                else:
                    indices = torch.arange(conformations, dtype=torch.int64)
                num_chunks = (conformations + chunk_size - 1) // chunk_size
                for i in range(num_chunks):
                    chunk_start = i * chunk_size
                    chunk_end = min(chunk_start + chunk_size, conformations)
                    chunk_indices = indices[chunk_start:chunk_end]
                    chunk_xyz = xyz.index_select(0, chunk_indices)
                    chunk_energies = energies.index_select(0, chunk_indices)
                    chunks.append((chunk_xyz, chunk_energies, species))
        self.chunks = chunks

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)


def load_dataset(path, dtype=default_dtype):
    """The returned dataset has cumulative_sizes and molecule_sizes"""
    # get name of files storing data
    files = []
    if isdir(path):
        for f in listdir(path):
            f = join(path, f)
            if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
                files.append(f)
    elif isfile(path):
        files = [path]
    else:
        raise ValueError('Bad path')
    # read tensors from file and build a dataset
    species = []
    molecule_id = 0
    datasets = []
    for f in files:
        for m in anidataloader(f):
            coordinates = tensor(m['coordinates'], dtype=dtype)
            energies = tensor(m['energies'], dtype=dtype)
            _molecule_id = full_like(energies, molecule_id).type(long)
            datasets.append(TensorDataset(_molecule_id, coordinates, energies))
            species.append(m['species'])
            molecule_id += 1
    dataset = ConcatDataset(datasets)
    sizes = [len(x) for x in dataset.datasets]
    return ANIDataset(dataset, sizes, species)
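An aside, not from the commit: the integer arithmetic used above to count chunks is ordinary ceiling division, i.e. the same quantity the tests and BatchSampler.__len__ compute with math.ceil. For example (the values 100 and 64 are arbitrary):

from math import ceil

conformations, chunk_size = 100, 64
# (n + c - 1) // c is exactly ceil(n / c) for positive integers
assert (conformations + chunk_size - 1) // chunk_size == ceil(conformations / chunk_size) == 2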
class BatchSampler(object):

    def __init__(self, source, chunk_size, batch_chunks):
        if not isinstance(source, ANIDataset):
            raise ValueError("BatchSampler must take ANIDataset as input")
        self.source = source
        self.chunk_size = chunk_size
        self.batch_chunks = batch_chunks

    def _concated_index(self, molecule, conformation):
        """
        Get the index in the dataset of the specified conformation
        of the specified molecule.
        """
        src = self.source
        cumulative_sizes = [0] + src.cumulative_sizes
        return cumulative_sizes[molecule] + conformation

    def __iter__(self):
        molecules = len(self.source.sizes)
        sizes = self.source.sizes
        """Number of conformations of each molecule"""
        unfinished = list(zip(range(molecules), [0] * molecules))
        """List of pairs (molecule, progress) storing the current progress
        of iterating each molecule."""
        batch = []
        batch_molecules = 0
        """The number of molecules already in batch"""
        while len(unfinished) > 0:
            new_unfinished = []
            for molecule, progress in unfinished:
                size = sizes[molecule]
                # the last incomplete chunk is not dropped
                end = min(progress + self.chunk_size, size)
                if end < size:
                    new_unfinished.append((molecule, end))
                batch += [self._concated_index(molecule, x)
                          for x in range(progress, end)]
                batch_molecules += 1
                if batch_molecules >= self.batch_chunks:
                    yield batch
                    batch = []
                    batch_molecules = 0
            unfinished = new_unfinished
        # the last incomplete batch is not dropped
        if len(batch) > 0:
            yield batch

    # note: the next method is the new ANIDataset.__getitem__; the inline diff
    # view interleaves it with the BatchSampler code
    def __getitem__(self, idx):
        return self.chunks[idx]

    def __len__(self):
        sizes = self.source.sizes
        chunks = [ceil(x / self.chunk_size) for x in sizes]
        chunks = sum(chunks)
        return ceil(chunks / self.batch_chunks)
def collate(batch):
    by_molecules = {}
    for molecule_id, xyz, energy in batch:
        molecule_id = molecule_id.item()
        if molecule_id not in by_molecules:
            by_molecules[molecule_id] = []
        by_molecules[molecule_id].append((xyz, energy))
    for i in by_molecules:
        by_molecules[i] = default_collate(by_molecules[i])
    return by_molecules
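For context, a sketch of how the removed sampler and collate function could be wired into torch.utils.data.DataLoader; this wiring is inferred from the signatures above and is not code from the commit (the 'dataset' path is an example):

from torch.utils.data import DataLoader

ds = load_dataset('dataset')                      # ANIDataset over (molecule_id, xyz, energies)
sampler = BatchSampler(ds, chunk_size=64, batch_chunks=4)
loader = DataLoader(ds, batch_sampler=sampler, collate_fn=collate)
for batch in loader:                              # dict: molecule_id -> [stacked xyz, stacked energies]
    for molecule_id, (xyz, energies) in batch.items():
        pass  # e.g. run a forward pass on each molecule's stacked conformations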
def random_split(dataset, num_chunks, chunk_size):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    The splitting is done by chunk, which makes batching possible: the whole
    dataset is first split into chunks of the specified size, where each chunk
    holds different conformations of the same isomer/molecule; these chunks are
    then randomly shuffled and split according to the given `num_chunks`. After
    splitting, chunks that belong to the same molecule/isomer within the same
    subset are merged to allow larger batches.

    Parameters
    ----------
    dataset : Dataset
        Dataset to be split
    num_chunks : sequence
        Number of chunks in each split to be produced
    chunk_size : integer
        Size of each chunk
    """
    chunks = list(BatchSampler(dataset, chunk_size, 1))
    shuffle(chunks)
    if sum(num_chunks) != len(chunks):
        raise ValueError(
            """Sum of input number of chunks does not equal the length of the
            total dataset!""")
    offset = 0
    subsets = []
    for i in num_chunks:
        _chunks = chunks[offset:offset + i]
        offset += i
        # merge chunks by molecule
        by_molecules = {}
        for chunk in _chunks:
            molecule_id = dataset[chunk[0]][0].item()
            if molecule_id not in by_molecules:
                by_molecules[molecule_id] = []
            by_molecules[molecule_id] += chunk
        _chunks = list(by_molecules.values())
        shuffle(_chunks)
        # construct subset
        sizes = [len(j) for j in _chunks]
        indices = list(chain.from_iterable(_chunks))
        _dataset = Subset(dataset, indices)
        _dataset = ANIDataset(_dataset, sizes, dataset.species)
        subsets.append(_dataset)
    return subsets
    # note: body of the new ANIDataset.__len__, placed at the end of the file
    # by the inline diff view
    def __len__(self):
        return len(self.chunks)
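Finally, a small usage sketch of the removed random_split, mirroring testSplitSize above (the 200-chunk split and the chunk size of 64 are the values used in that test; the 'dataset' path is an example):

ds = load_dataset('dataset')
total_chunks = len(BatchSampler(ds, 64, 1))
ds1, ds2 = random_split(ds, [200, total_chunks - 200], 64)
# each subset is again an ANIDataset, so it can be chunked and batched as before
assert len(BatchSampler(ds1, 64, 1)) == 200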