OpenDAS / torchani — commit b9e2c259
Authored Aug 20, 2019 by Richard Xue; committed by Gao, Xiang on Aug 20, 2019.
New dataset API, cached dataset and shuffled dataset (#284)
Parent: f825c99e
Showing 9 changed files with 739 additions and 89 deletions.
Files changed:
  azure/install_dependencies.sh          +1   -1
  azure/install_dependencies_python2.sh  +1   -1
  docs/api.rst                           +3   -0
  setup.cfg                              +1   -1
  setup.py                               +1   -0
  tests/test_data_new.py                 +109 -0
  tools/training-benchmark.py            +152 -85
  torchani/data/__init__.py              +4   -1
  torchani/data/new.py                   +467 -0
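The headline change is the new dataset API in torchani/data/new.py: CachedDataset loads an HDF5 file once, shuffles it once, and then serves the same chunked batches every epoch, while ShuffledDataset wraps the same data in a torch DataLoader that reshuffles every epoch. A minimal usage sketch, based only on the signatures added in this commit (the dataset path and keyword values are illustrative, not part of the diff):

    import torchani

    dspath = 'dataset/ani_gdb_s03.h5'  # placeholder path to an ANI-style hdf5 file

    # cached variant: shuffled once, identical batches every epoch
    ds = torchani.data.CachedDataset(dspath, batch_size=2560, device='cpu',
                                     chunk_threshold=5, subtract_self_energies=True)
    chunks, properties = ds[0]  # ([(species, coordinates), ...], {'energies': ...})

    # shuffled variant: a DataLoader that reshuffles and prepares batches in workers
    loader = torchani.data.ShuffledDataset(dspath, batch_size=2560, num_workers=2,
                                           chunk_threshold=5, subtract_self_energies=True)
    for chunks, properties in loader:
        pass  # each batch has the same ([chunk1, chunk2, ...], {'energies': ...}) layout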
azure/install_dependencies.sh

@@ -2,5 +2,5 @@
 python -m pip install --upgrade pip
 pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-pip install tqdm pyyaml future
+pip install tqdm pyyaml future pkbar
 pip install 'ase<=3.17'
\ No newline at end of file
azure/install_dependencies_python2.sh

@@ -2,5 +2,5 @@
 python -m pip install --upgrade pip
 pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-pip install tqdm pyyaml future
+pip install tqdm pyyaml future pkbar
 pip2 install 'ase<=3.17'
\ No newline at end of file
docs/api.rst

@@ -24,6 +24,9 @@ Datasets
 ========
 .. automodule:: torchani.data
+.. autofunction:: torchani.data.find_threshold
+.. autofunction:: torchani.data.ShuffledDataset
+.. autoclass:: torchani.data.CachedDataset
 .. autofunction:: torchani.data.load_ani_dataset
 .. autofunction:: torchani.data.create_aev_cache
 .. autoclass:: torchani.data.BatchedANIDataset
setup.cfg

@@ -5,7 +5,7 @@ with-coverage=1
 cover-package=torchani
 [flake8]
-ignore = E501
+ignore = E501, W503
 exclude =
     .git,
     __pycache__,
setup.py

@@ -31,6 +31,7 @@ setup_attrs = {
         'h5py',
         'pytorch-ignite-nightly',
         'pillow',
+        'pkbar'
     ],
 }
tests/test_data_new.py (new file, 0 → 100644)

import torchani
import unittest
import pkbar
import torch
import os

path = os.path.dirname(os.path.realpath(__file__))
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s03.h5')
batch_size = 2560
chunk_threshold = 5


class TestFindThreshold(unittest.TestCase):

    def setUp(self):
        print('.. check find threshold to split chunks')

    def testFindThreshould(self):
        torchani.data.find_threshold(dspath, batch_size=batch_size, threshold_max=10)


class TestShuffledData(unittest.TestCase):

    def setUp(self):
        print('.. setup shuffle dataset')
        self.ds = torchani.data.ShuffledDataset(dspath,
                                                batch_size=batch_size,
                                                chunk_threshold=chunk_threshold,
                                                num_workers=2)
        self.chunks, self.properties = iter(self.ds).next()

    def testTensorShape(self):
        print('=> checking tensor shape')
        print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
        batch_len = 0
        for i, chunk in enumerate(self.chunks):
            print('chunk{}'.format(i + 1),
                  list(chunk[0].size()), chunk[0].dtype,
                  list(chunk[1].size()), chunk[1].dtype)
            # check dtype
            self.assertEqual(chunk[0].dtype, torch.int64)
            self.assertEqual(chunk[1].dtype, torch.float32)
            # check shape
            self.assertEqual(chunk[1].shape[2], 3)
            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
            batch_len += chunk[0].shape[0]
        for key, value in self.properties.items():
            print(key, list(value.size()), value.dtype)
            self.assertEqual(value.dtype, torch.float32)
            self.assertEqual(len(value.shape), 1)
            self.assertEqual(value.shape[0], batch_len)

    def testLoadDataset(self):
        print('=> test loading all dataset')
        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
                          + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
                          len(self.ds))
        for i, _ in enumerate(self.ds):
            pbar.update(i)

    def testNoUnnecessaryPadding(self):
        print('=> checking No Unnecessary Padding')
        for i, chunk in enumerate(self.chunks):
            species, _ = chunk
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)


class TestCachedData(unittest.TestCase):

    def setUp(self):
        print('.. setup cached dataset')
        self.ds = torchani.data.CachedDataset(dspath,
                                              batch_size=batch_size,
                                              device='cpu',
                                              chunk_threshold=chunk_threshold)
        self.chunks, self.properties = self.ds[0]

    def testTensorShape(self):
        print('=> checking tensor shape')
        print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
        batch_len = 0
        for i, chunk in enumerate(self.chunks):
            print('chunk{}'.format(i + 1),
                  list(chunk[0].size()), chunk[0].dtype,
                  list(chunk[1].size()), chunk[1].dtype)
            # check dtype
            self.assertEqual(chunk[0].dtype, torch.int64)
            self.assertEqual(chunk[1].dtype, torch.float32)
            # check shape
            self.assertEqual(chunk[1].shape[2], 3)
            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
            batch_len += chunk[0].shape[0]
        for key, value in self.properties.items():
            print(key, list(value.size()), value.dtype)
            self.assertEqual(value.dtype, torch.float32)
            self.assertEqual(len(value.shape), 1)
            self.assertEqual(value.shape[0], batch_len)

    def testLoadDataset(self):
        print('=> test loading all dataset')
        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
                          + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
                          len(self.ds))
        for i, _ in enumerate(self.ds):
            pbar.update(i)

    def testNoUnnecessaryPadding(self):
        print('=> checking No Unnecessary Padding')
        for i, chunk in enumerate(self.chunks):
            species, _ = chunk
            non_padding = (species >= 0)[:, -1].nonzero()
            self.assertGreater(non_padding.numel(), 0)


if __name__ == "__main__":
    unittest.main()
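The testNoUnnecessaryPadding check above leans on the padding convention of the new API (visible in torchani/data/new.py further down): padded species entries are -1 and every chunk is truncated to its longest molecule, so the last atom column of a chunk must contain at least one real atom. A tiny standalone illustration of that invariant, using made-up tensors rather than the dataset:

    import torch

    # toy chunk of species indices, padded with -1 up to the longest molecule
    species = torch.tensor([[0, 1, 1, -1],
                            [0, 2, 3,  1]])
    # the last column still holds a real atom for at least one molecule,
    # i.e. no column of the chunk is pure padding
    non_padding = (species >= 0)[:, -1].nonzero()
    assert non_padding.numel() > 0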
tools/training-benchmark.py

import torch
import ignite
import torchani
import time
import timeit
import tqdm
import argparse

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset, can a hdf5 file \
                          or a directory containing hdf5 files')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations of each batch',
                    default=1024, type=int)
parser = parser.parse_args()

# set up benchmark
device = torch.device(parser.device)
ani1x = torchani.models.ANI1x()
consts = ani1x.consts
aev_computer = ani1x.aev_computer
shift_energy = ani1x.energy_shifter
import pkbar


def atomic():

@@ -39,45 +19,6 @@ def atomic():
    return model


model = torchani.ANIModel([atomic() for _ in range(4)])


class Flatten(torch.nn.Module):
    def forward(self, x):
        return x[0], x[1].flatten()


nnp = torch.nn.Sequential(aev_computer, model, Flatten()).to(device)

dataset = torchani.data.load_ani_dataset(parser.dataset_path, consts.species_to_tensor,
                                         parser.batch_size, device=device,
                                         transform=[shift_energy.subtract_from_dataset])
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.MSELoss('energies'))


@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    trainer.state.tqdm.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    trainer.state.tqdm.close()


timers = {}


def time_func(key, func):
    timers[key] = 0

@@ -91,27 +32,153 @@ def time_func(key, func):
    return wrapper


# enable timers
torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine', torchani.aev.cutoff_cosine)
torchani.aev.radial_terms = time_func('torchani.aev.radial_terms', torchani.aev.radial_terms)
torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
torchani.aev.convert_pair_index = time_func('torchani.aev.convert_pair_index', torchani.aev.convert_pair_index)
torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward)

# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
for k in timers:
    if k.startswith('torchani.'):
        print(k, timers[k])
print('Total AEV:', timers['total'])
print('NN:', timers['forward'])
print('Epoch time:', elapsed)


def hartree2kcal(x):
    return 627.509 * x


if __name__ == "__main__":
    # parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_path',
                        help='Path of the dataset, can a hdf5 file \
                              or a directory containing hdf5 files')
    parser.add_argument('-d', '--device',
                        help='Device of modules and tensors',
                        default=('cuda' if torch.cuda.is_available() else 'cpu'))
    parser.add_argument('-b', '--batch_size',
                        help='Number of conformations of each batch',
                        default=2560, type=int)
    parser.add_argument('-o', '--original_dataset_api',
                        help='use original dataset api',
                        dest='dataset', action='store_const', const='original')
    parser.add_argument('-s', '--shuffle_dataset_api',
                        help='use shuffle dataset api',
                        dest='dataset', action='store_const', const='shuffle')
    parser.add_argument('-c', '--cache_dataset_api',
                        help='use cache dataset api',
                        dest='dataset', action='store_const', const='cache')
    parser.set_defaults(dataset='shuffle')
    parser.add_argument('-n', '--num_epochs',
                        help='epochs', default=1, type=int)
    parser = parser.parse_args()

    Rcr = 5.2000e+00
    Rca = 3.5000e+00
    EtaR = torch.tensor([1.6000000e+01], device=parser.device)
    ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00,
                         1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00,
                         3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00,
                         4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00],
                        device=parser.device)
    Zeta = torch.tensor([3.2000000e+01], device=parser.device)
    ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00,
                         1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00],
                        device=parser.device)
    EtaA = torch.tensor([8.0000000e+00], device=parser.device)
    ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00],
                        device=parser.device)
    num_species = 4
    aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)

    nn = torchani.ANIModel([atomic() for _ in range(4)])
    model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
    mse = torch.nn.MSELoss(reduction='none')

    timers = {}
    # enable timers
    torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine', torchani.aev.cutoff_cosine)
    torchani.aev.radial_terms = time_func('torchani.aev.radial_terms', torchani.aev.radial_terms)
    torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
    torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
    torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
    torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
    torchani.aev.convert_pair_index = time_func('torchani.aev.convert_pair_index', torchani.aev.convert_pair_index)
    torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
    torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
    torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
    model[0].forward = time_func('total', model[0].forward)
    model[1].forward = time_func('forward', model[1].forward)

    if parser.dataset == 'shuffle':
        torchani.data.ShuffledDataset = time_func('data_loading', torchani.data.ShuffledDataset)
        print('using shuffle dataset API')
        print('=> loading dataset...')
        dataset = torchani.data.ShuffledDataset(file_path=parser.dataset_path,
                                                species_order=['H', 'C', 'N', 'O'],
                                                subtract_self_energies=True,
                                                batch_size=parser.batch_size,
                                                num_workers=2)
        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
        chunks, properties = iter(dataset).next()
    elif parser.dataset == 'original':
        torchani.data.load_ani_dataset = time_func('data_loading', torchani.data.load_ani_dataset)
        print('using original dataset API')
        print('=> loading dataset...')
        energy_shifter = torchani.utils.EnergyShifter(None)
        species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
        dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
                                                 parser.batch_size, device=parser.device,
                                                 transform=[energy_shifter.subtract_from_dataset])
        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
        chunks, properties = dataset[0]
    elif parser.dataset == 'cache':
        torchani.data.CachedDataset = time_func('data_loading', torchani.data.CachedDataset)
        print('using cache dataset API')
        print('=> loading dataset...')
        dataset = torchani.data.CachedDataset(file_path=parser.dataset_path,
                                              species_order=['H', 'C', 'N', 'O'],
                                              subtract_self_energies=True,
                                              batch_size=parser.batch_size)
        print('=> caching all dataset into cpu')
        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
                          + 'batches: {}, batch_size: {}'.format(len(dataset), parser.batch_size),
                          len(dataset))
        for i, t in enumerate(dataset):
            pbar.update(i)
        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
        chunks, properties = dataset[0]

    for i, chunk in enumerate(chunks):
        print('chunk{}'.format(i + 1), list(chunk[0].size()), list(chunk[1].size()))
    print('energies', list(properties['energies'].size()))

    print('=> start training')
    start = time.time()

    for epoch in range(0, parser.num_epochs):
        print('Epoch: %d/%d' % (epoch + 1, parser.num_epochs))
        progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)
        for i, (batch_x, batch_y) in enumerate(dataset):
            true_energies = batch_y['energies'].to(parser.device)
            predicted_energies = []
            num_atoms = []
            for chunk_species, chunk_coordinates in batch_x:
                chunk_species = chunk_species.to(parser.device)
                chunk_coordinates = chunk_coordinates.to(parser.device)
                num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
                _, chunk_energies = model((chunk_species, chunk_coordinates))
                predicted_energies.append(chunk_energies)
            num_atoms = torch.cat(num_atoms)
            predicted_energies = torch.cat(predicted_energies)
            loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
            rmse = hartree2kcal((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
            loss.backward()
            optimizer.step()
            progbar.update(i, values=[("rmse", rmse)])
    stop = time.time()

    print('=> more detail about benchmark')
    for k in timers:
        if k.startswith('torchani.'):
            print('{} - {:.1f}s'.format(k, timers[k]))
    print('Total AEV - {:.1f}s'.format(timers['total']))
    print('Data Loading - {:.1f}s'.format(timers['data_loading']))
    print('NN - {:.1f}s'.format(timers['forward']))
    print('Epoch time - {:.1f}s'.format(stop - start))
torchani/data/__init__.py

@@ -11,6 +11,7 @@ import pickle
 import numpy as np
 from scipy.sparse import bsr_matrix
 import warnings
+from .new import CachedDataset, ShuffledDataset, find_threshold
 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -511,4 +512,6 @@ def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                SparseAEVCacheLoader.encode_aev, **kwargs)
-__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
+__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader',
+           'cache_aev', 'cache_sparse_aev', 'CachedDataset', 'ShuffledDataset', 'find_threshold']
torchani/data/new.py (new file, 0 → 100644)

import numpy as np
import torch
import functools
from ._pyanitools import anidataloader
import importlib
import gc

PKBAR_INSTALLED = importlib.util.find_spec('pkbar') is not None
if PKBAR_INSTALLED:
    import pkbar


def find_threshold(file_path, batch_size, threshold_max=100):
    """Find a reasonable threshold to split chunks before using ``torchani.data.CachedDataset`` or ``torchani.data.ShuffledDataset``.

    Arguments:
        file_path (str): Path to one hdf5 file.
        batch_size (int): batch size.
        threshold_max (int): maximum threshold to test.
    """
    ds = CachedDataset(file_path=file_path, batch_size=batch_size)
    ds.find_threshold(threshold_max + 1)


class CachedDataset(torch.utils.data.Dataset):
    """Cached dataset which is shuffled once; the batches stay the same at every epoch.

    Arguments:
        file_path (str): Path to one hdf5 file.
        batch_size (int): batch size.
        chunk_threshold (int): threshold used to split a batch into chunks. Setting it to
            ``None`` disables chunk splitting.
        species_order (list): a list which specifies how species are transformed to int,
            for example ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
        subtract_self_energies (bool): whether to subtract self energies from ``energies``.
        self_energies (list): if `subtract_self_energies` is True, the order should be
            the same as ``species_order``; for example ``[-0.600953, -38.08316, -54.707756, -75.194466]``
            will be converted to ``{'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}``.

    .. note::
        The resulting dataset will be
        ``([chunk1, chunk2, ...], {'energies', 'force', ...})`` in which chunk1 is a
        tuple of ``(species, coordinates)``, e.g. the shape of
        chunk1: ``[[1807, 21], [1807, 21, 3]]``
        chunk2: ``[[193, 50], [193, 50, 3]]``
        'energies': ``[2000, 1]``
    """
    def __init__(self, file_path,
                 batch_size=1000,
                 device='cpu',
                 chunk_threshold=20,
                 species_order=['H', 'C', 'N', 'O'],
                 subtract_self_energies=False,
                 self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
        super(CachedDataset, self).__init__()

        # example of what the dicts will look like:
        # species_dict: {'H': 0, 'C': 1, 'N': 2, 'O': 3}
        # self_energies_dict: {'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}
        species_dict = {}
        self_energies_dict = {}
        for i, s in enumerate(species_order):
            species_dict[s] = i
            self_energies_dict[s] = self_energies[i]

        self.data_species = []
        self.data_coordinates = []
        self.data_energies = []
        self.data_self_energies = []

        anidata = anidataloader(file_path)
        anidata_size = anidata.group_size()
        enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
        if enable_pkbar:
            pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size),
                              anidata_size)

        for i, molecule in enumerate(anidata):
            num_conformations = len(molecule['coordinates'])
            # species and coordinates
            self.data_coordinates += list(molecule['coordinates'].reshape(num_conformations, -1).astype(np.float32))
            species = np.array([species_dict[x] for x in molecule['species']])
            self.data_species += list(np.tile(species, (num_conformations, 1)))
            # energies
            self.data_energies += list(molecule['energies'].reshape((-1, 1)))
            if subtract_self_energies:
                self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                self.data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
            if enable_pkbar:
                pbar.update(i)

        if subtract_self_energies:
            self.data_energies = np.array(self.data_energies) - np.array(self.data_self_energies)
            del self.data_self_energies
            del self_energies
            gc.collect()

        self.batch_size = batch_size
        self.length = (len(self.data_species) + self.batch_size - 1) // self.batch_size
        self.device = device
        self.shuffled_index = np.arange(len(self.data_species))
        np.random.shuffle(self.shuffled_index)
        self.chunk_threshold = chunk_threshold
        if not self.chunk_threshold:
            self.chunk_threshold = np.inf

        anidata.cleanup()
        del num_conformations
        del species
        del anidata
        gc.collect()

    @functools.lru_cache(maxsize=None)
    def __getitem__(self, index):
        if index >= self.length:
            raise IndexError()

        batch_indices = slice(index * self.batch_size, (index + 1) * self.batch_size)
        batch_indices_shuffled = self.shuffled_index[batch_indices]
        batch_species = [self.data_species[i] for i in batch_indices_shuffled]
        batch_coordinates = [self.data_coordinates[i] for i in batch_indices_shuffled]
        batch_energies = [self.data_energies[i] for i in batch_indices_shuffled]

        # get sort index
        num_atoms_each_mole = [b.shape[0] for b in batch_species]
        atoms = torch.tensor(num_atoms_each_mole, dtype=torch.int32)
        sorted_atoms, sorted_atoms_idx = torch.sort(atoms)

        # sort each batch of data
        batch_species = self.sort_list_with_index(batch_species, sorted_atoms_idx.numpy())
        batch_coordinates = self.sort_list_with_index(batch_coordinates, sorted_atoms_idx.numpy())
        batch_energies = self.sort_list_with_index(batch_energies, sorted_atoms_idx.numpy())

        # get chunk size
        output, count = torch.unique(atoms, sorted=True, return_counts=True)
        counts = torch.cat((output.unsqueeze(-1).int(), count.unsqueeze(-1).int()), dim=-1)
        chunk_size_list, chunk_max_list = split_to_chunks(counts,
                                                          chunk_threshold=self.chunk_threshold * self.batch_size * 20)
        chunk_size_list = torch.stack(chunk_size_list).flatten()

        # split into chunks
        chunks_batch_species = self.split_list_with_size(batch_species, chunk_size_list.numpy())
        chunks_batch_coordinates = self.split_list_with_size(batch_coordinates, chunk_size_list.numpy())
        batch_energies = self.split_list_with_size(batch_energies, np.array([self.batch_size]))

        # padding each data
        chunks_batch_species = self.pad_and_convert_to_tensor(chunks_batch_species, padding_value=-1)
        chunks_batch_coordinates = self.pad_and_convert_to_tensor(chunks_batch_coordinates)
        batch_energies = self.pad_and_convert_to_tensor(batch_energies, no_padding=True)

        chunks = list(zip(chunks_batch_species, chunks_batch_coordinates))
        for i, _ in enumerate(chunks):
            chunks[i] = (chunks[i][0], chunks[i][1].reshape(chunks[i][1].shape[0], -1, 3))
        properties = {'energies': batch_energies[0].flatten().float()}

        # return: [chunk1, chunk2, ...], {"energies", "force", ...} in which chunk1=(species, coordinates)
        # e.g. chunk1 = [[1807, 21], [1807, 21, 3]], chunk2 = [[193, 50], [193, 50, 3]]
        # 'energies' = [2000, 1]
        return chunks, properties

    def __len__(self):
        return self.length

    @staticmethod
    def sort_list_with_index(inputs, index):
        return [inputs[i] for i in index]

    @staticmethod
    def split_list_with_size(inputs, split_size):
        output = []
        for i, _ in enumerate(split_size):
            start_index = np.sum(split_size[:i])
            stop_index = np.sum(split_size[:i + 1])
            output.append(inputs[start_index:stop_index])
        return output

    def pad_and_convert_to_tensor(self, inputs, padding_value=0, no_padding=False):
        if no_padding:
            for i, input_tmp in enumerate(inputs):
                inputs[i] = torch.from_numpy(np.stack(input_tmp)).to(self.device)
        else:
            for i, input_tmp in enumerate(inputs):
                inputs[i] = torch.nn.utils.rnn.pad_sequence([torch.from_numpy(b) for b in inputs[i]],
                                                            batch_first=True,
                                                            padding_value=padding_value).to(self.device)
        return inputs

    def find_threshold(self, threshold_max=100):
        batch_indices = slice(0, self.batch_size)
        batch_indices_shuffled = self.shuffled_index[batch_indices]
        batch_species = [self.data_species[i] for i in batch_indices_shuffled]
        num_atoms_each_mole = [b.shape[0] for b in batch_species]
        atoms = torch.tensor(num_atoms_each_mole, dtype=torch.int32)
        output, count = torch.unique(atoms, sorted=True, return_counts=True)
        counts = torch.cat((output.unsqueeze(-1).int(), count.unsqueeze(-1).int()), dim=-1)

        print('=> choose a reasonable threshold to split chunks')
        print('format is [chunk_size, chunk_max]')
        for b in range(0, threshold_max, 1):
            test_chunk_size_list, test_chunk_max_list = split_to_chunks(counts,
                                                                        chunk_threshold=b * self.batch_size * 20)
            size_max = []
            for i, _ in enumerate(test_chunk_size_list):
                size_max.append([list(test_chunk_size_list[i].numpy())[0],
                                 list(test_chunk_max_list[i].numpy())[0]])
            print('chunk_threshold = {}'.format(b))
            print(size_max)

    def release_h5(self):
        del self.data_species
        del self.data_coordinates
        del self.data_energies
        gc.collect()
def ShuffledDataset(file_path,
                    batch_size=1000,
                    num_workers=0,
                    shuffle=True,
                    chunk_threshold=20,
                    species_order=['H', 'C', 'N', 'O'],
                    subtract_self_energies=False,
                    self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
    """Shuffled dataset built on `torch.utils.data.DataLoader`; it reshuffles at every epoch.

    Arguments:
        file_path (str): Path to one hdf5 file.
        batch_size (int): batch size.
        num_workers (int): number of worker processes that prepare the next batches in the
            background while training is running.
        shuffle (bool): whether to shuffle.
        chunk_threshold (int): threshold used to split a batch into chunks. Setting it to
            ``None`` disables chunk splitting.
        species_order (list): a list which specifies how species are transformed to int,
            for example ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
        subtract_self_energies (bool): whether to subtract self energies from ``energies``.
        self_energies (list): if `subtract_self_energies` is True, the order should be
            the same as ``species_order``; for example ``[-0.600953, -38.08316, -54.707756, -75.194466]``
            will be converted to ``{'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}``.

    .. note::
        Returns a dataloader that, when iterated, yields
        ``([chunk1, chunk2, ...], {'energies', 'force', ...})`` in which chunk1 is a
        tuple of ``(species, coordinates)``, e.g. the shape of
        chunk1: ``[[1807, 21], [1807, 21, 3]]``
        chunk2: ``[[193, 50], [193, 50, 3]]``
        'energies': ``[2000, 1]``
    """
    dataset = TorchData(file_path, species_order, subtract_self_energies, self_energies)

    if not chunk_threshold:
        chunk_threshold = np.inf

    def my_collate_fn(data, chunk_threshold=chunk_threshold):
        return collate_fn(data, chunk_threshold)

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              pin_memory=False,
                                              collate_fn=my_collate_fn)
    return data_loader
class TorchData(torch.utils.data.Dataset):

    def __init__(self, file_path, species_order, subtract_self_energies, self_energies):
        super(TorchData, self).__init__()

        species_dict = {}
        self_energies_dict = {}
        for i, s in enumerate(species_order):
            species_dict[s] = i
            self_energies_dict[s] = self_energies[i]

        self.data_species = []
        self.data_coordinates = []
        self.data_energies = []
        self.data_self_energies = []

        anidata = anidataloader(file_path)
        anidata_size = anidata.group_size()
        enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
        if enable_pkbar:
            pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size),
                              anidata_size)

        for i, molecule in enumerate(anidata):
            num_conformations = len(molecule['coordinates'])
            self.data_coordinates += list(molecule['coordinates'].reshape(num_conformations, -1).astype(np.float32))
            self.data_energies += list(molecule['energies'].reshape((-1, 1)))
            species = np.array([species_dict[x] for x in molecule['species']])
            self.data_species += list(np.tile(species, (num_conformations, 1)))
            if subtract_self_energies:
                self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                self.data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
            if enable_pkbar:
                pbar.update(i)

        if subtract_self_energies:
            self.data_energies = np.array(self.data_energies) - np.array(self.data_self_energies)
            del self.data_self_energies
            del self_energies
            gc.collect()

        self.length = len(self.data_species)
        anidata.cleanup()
        del num_conformations
        del species
        del anidata
        gc.collect()

    def __getitem__(self, index):
        if index >= self.length:
            raise IndexError()
        species = torch.from_numpy(self.data_species[index])
        coordinates = torch.from_numpy(self.data_coordinates[index]).float()
        energies = torch.from_numpy(self.data_energies[index]).float()
        return [species, coordinates, energies]

    def __len__(self):
        return self.length
def collate_fn(data, chunk_threshold):
    """Creates a batch of chunked data.
    """
    # unzip a batch of molecules (each molecule is a list)
    batch_species, batch_coordinates, batch_energies = zip(*data)
    batch_size = len(batch_species)

    # padding - time: 13.2s
    batch_species = torch.nn.utils.rnn.pad_sequence(batch_species, batch_first=True, padding_value=-1)
    batch_coordinates = torch.nn.utils.rnn.pad_sequence(batch_coordinates, batch_first=True, padding_value=0)
    batch_energies = torch.stack(batch_energies)

    # sort - time: 0.7s
    atoms = torch.sum(~(batch_species == -1), dim=-1, dtype=torch.int32)
    sorted_atoms, sorted_atoms_idx = torch.sort(atoms)
    batch_species = torch.index_select(batch_species, dim=0, index=sorted_atoms_idx)
    batch_coordinates = torch.index_select(batch_coordinates, dim=0, index=sorted_atoms_idx)
    batch_energies = torch.index_select(batch_energies, dim=0, index=sorted_atoms_idx)

    # get chunk size - time: 2.1s
    output, count = torch.unique(atoms, sorted=True, return_counts=True)
    counts = torch.cat((output.unsqueeze(-1).int(), count.unsqueeze(-1).int()), dim=-1)
    chunk_size_list, chunk_max_list = split_to_chunks(counts, chunk_threshold=chunk_threshold * batch_size * 20)

    # split into chunks - time: 0.3s
    chunks_batch_species = torch.split(batch_species, chunk_size_list, dim=0)
    chunks_batch_coordinates = torch.split(batch_coordinates, chunk_size_list, dim=0)

    # truncate redundant padding - time: 1.3s
    chunks_batch_species = trunc_pad(list(chunks_batch_species), padding_value=-1)
    chunks_batch_coordinates = trunc_pad(list(chunks_batch_coordinates))
    for i, c in enumerate(chunks_batch_coordinates):
        chunks_batch_coordinates[i] = c.reshape(c.shape[0], -1, 3)

    chunks = list(zip(chunks_batch_species, chunks_batch_coordinates))
    for i, _ in enumerate(chunks):
        chunks[i] = (chunks[i][0], chunks[i][1])
    properties = {'energies': batch_energies.flatten().float()}

    # return: [chunk1, chunk2, ...], {"energies", "force", ...} in which chunk1=(species, coordinates)
    # e.g. chunk1 = [[1807, 21], [1807, 21, 3]], chunk2 = [[193, 50], [193, 50, 3]]
    # 'energies' = [2000, 1]
    return chunks, properties
def split_to_two_chunks(counts, chunk_threshold):
    counts = counts.cpu()
    # NB (@yueyericardo): In principle this dtype should be `torch.bool`, but unfortunately
    # `triu` is not implemented for bool tensor right now. This should be fixed when PyTorch
    # add support for it.
    left_mask = torch.triu(torch.ones([counts.shape[0], counts.shape[0]], dtype=torch.uint8))
    left_mask = left_mask.t()
    counts_atoms = counts[:, 0].repeat(counts.shape[0], 1)
    counts_counts = counts[:, 1].repeat(counts.shape[0], 1)
    counts_atoms_left = torch.where(left_mask, counts_atoms, torch.zeros_like(counts_atoms))
    counts_atoms_right = torch.where(~left_mask, counts_atoms, torch.zeros_like(counts_atoms))
    counts_counts_left = torch.where(left_mask, counts_counts, torch.zeros_like(counts_atoms))
    counts_counts_right = torch.where(~left_mask, counts_counts, torch.zeros_like(counts_atoms))

    # chunk max
    chunk_max_left = torch.max(counts_atoms_left, dim=-1, keepdim=True).values
    chunk_max_right = torch.max(counts_atoms_right, dim=-1, keepdim=True).values
    # chunk size
    chunk_size_left = torch.sum(counts_counts_left, dim=-1, keepdim=True, dtype=torch.int32)
    chunk_size_right = torch.sum(counts_counts_right, dim=-1, keepdim=True, dtype=torch.int32)

    # calculate cost
    min_cost_threshold = torch.tensor([chunk_threshold], dtype=torch.int32)
    cost = (torch.max(chunk_size_left * chunk_max_left * chunk_max_left, min_cost_threshold)
            + torch.max(chunk_size_right * chunk_max_right * chunk_max_right, min_cost_threshold))

    # find smallest cost
    cost_min, cost_min_index = torch.min(cost.squeeze(), dim=-1)
    # find smallest cost chunk_size, if not splitted, it will be [max_chunk_size, 0]
    final_chunk_size = [chunk_size_left[cost_min_index], chunk_size_right[cost_min_index]]
    final_chunk_max = [chunk_max_left[cost_min_index], chunk_max_right[cost_min_index]]

    # if not splitted
    if cost_min_index == (counts.shape[0] - 1):
        return False, counts, [final_chunk_size[0]], [final_chunk_max[0]], cost_min
    # if splitted
    return True, [counts[:cost_min_index + 1], counts[(cost_min_index + 1):]], \
        final_chunk_size, final_chunk_max, cost_min
def split_to_chunks(counts, chunk_threshold=np.inf):
    splitted, counts_list, chunk_size, chunk_max, cost = split_to_two_chunks(counts, chunk_threshold)
    final_chunk_size = []
    final_chunk_max = []
    if (splitted):
        for i, _ in enumerate(counts_list):
            tmp_chunk_size, tmp_chunk_max = split_to_chunks(counts_list[i], chunk_threshold)
            final_chunk_size.extend(tmp_chunk_size)
            final_chunk_max.extend(tmp_chunk_max)
        return final_chunk_size, final_chunk_max
    # if not splitted
    return chunk_size, chunk_max


def trunc_pad(chunks, padding_value=0):
    for i, _ in enumerate(chunks):
        lengths = torch.sum(~(chunks[i] == padding_value), dim=-1, dtype=torch.int32)
        chunks[i] = chunks[i][..., :lengths.max()]
    return chunks
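For reference, split_to_two_chunks above scores every cut position of the sorted counts table (rows of molecule size and molecule count): each side of a cut costs roughly chunk_size * chunk_max**2, clamped from below by the threshold, and split_to_chunks keeps recursing on both halves of the cheapest cut. A small standalone sketch of that cost rule in plain Python, with toy numbers, independent of the functions above:

    # counts rows are (atoms_per_molecule, molecule_count), sorted by molecule size
    counts = [(10, 800), (20, 900), (50, 300)]

    def cut_cost(rows, floor):
        """Cost of padding all molecules in `rows` to the largest one, clamped at `floor`."""
        size = sum(n for _, n in rows)       # molecules in the chunk
        largest = max(a for a, _ in rows)    # padding target
        return max(size * largest * largest, floor)

    floor = 5 * 2000 * 20  # chunk_threshold * batch_size * 20, as passed in by the callers above
    best = min(range(1, len(counts)),
               key=lambda i: cut_cost(counts[:i], floor) + cut_cost(counts[i:], floor))
    print('cheapest cut before row', best)
    # one single chunk would cost 2000 * 50**2 = 5,000,000, so splitting off the
    # large molecules pays for itself once the smaller chunks hit the threshold floor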