OpenDAS / torchani

Commit 3cced1e6 (unverified), parent d7ef8182
Authored Aug 23, 2018 by Gao, Xiang; committed by GitHub on Aug 23, 2018

Improve training related API (#76)

Showing 14 changed files, with 98 additions and 125 deletions:
  codefresh.yml                   +3   −3
  examples/energy_force.py        +1   −1
  examples/neurochem-test.py      +4   −4
  examples/nnp_training.py        +37  −25
  examples/training-benchmark.py  +4   −4
  tests/test_data.py              +4   −5
  tests/test_ignite.py            +7   −7
  torchani/__init__.py            +3   −2
  torchani/_pyanitools.py         +0   −0
  torchani/data.py                +13  −38
  torchani/ignite.py              +20  −1
  torchani/neurochem.py           +2   −2
  torchani/training/__init__.py   +0   −9
  torchani/training/container.py  +0   −24
codefresh.yml
@@ -24,9 +24,9 @@ steps:
   Examples:
     image: '${{BuildTorchANI}}'
     commands:
-      - rm -rf *.dat *.pt
-      - python examples/nnp_training.py ./dataset/ani_gdb_s01.h5
-      - python examples/nnp_training.py ./dataset/ani_gdb_s01.h5  # run twice to test if checkpoint is working
+      - rm -rf *.pt
+      - python examples/nnp_training.py dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5
+      - python examples/nnp_training.py dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5  # run twice to test if checkpoint is working
       - python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
       - python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5  # run twice to test if checkpoint is working
       - python examples/energy_force.py
examples/energy_force.py
@@ -22,7 +22,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                              [0.45554739, 0.54289633, 0.81170881],
                              [0.66091919, -0.16799635, -0.91037834]]],
                             requires_grad=True)
-species = consts.species_to_tensor('CHHHH', device).unsqueeze(0)
+species = consts.species_to_tensor('CHHHH').to(device).unsqueeze(0)
 _, energy = model((species, coordinates))
 derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
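
Note on the change above: species_to_tensor no longer takes a device argument (see the torchani/neurochem.py hunk further down), so callers build the tensor first and move it themselves. A minimal sketch of the new call pattern, assuming the built-in constants of this era of the API (torchani.buildins.consts) and an optional CUDA device; illustrative, not part of the commit:

    import torch
    import torchani

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    consts = torchani.buildins.consts

    # Before this commit: consts.species_to_tensor('CHHHH', device)
    # After: the converter is device-agnostic; move and batch explicitly.
    species = consts.species_to_tensor('CHHHH').to(device).unsqueeze(0)
    print(species.shape)  # torch.Size([1, 5]): one molecule, five atoms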
examples/neurochem-test.py
@@ -35,15 +35,15 @@ shift_energy = torchani.neurochem.load_sae(parser.sae_file)
 aev_computer = torchani.AEVComputer(**consts)
 nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
 model = torch.nn.Sequential(aev_computer, nn)
-container = torchani.training.Container({'energies': model})
+container = torchani.ignite.Container({'energies': model})
 container = container.to(device)

 # load datasets
 if parser.dataset_path.endswith('.h5') or \
    parser.dataset_path.endswith('.hdf5') or \
    os.path.isdir(parser.dataset_path):
-    dataset = torchani.training.BatchedANIDataset(
-        parser.dataset_path, consts.species, parser.batch_size,
+    dataset = torchani.data.BatchedANIDataset(
+        parser.dataset_path, consts.species_to_tensor, parser.batch_size,
         device=device, transform=[shift_energy.subtract_from_dataset])
     datasets = [dataset]
 else:
@@ -60,7 +60,7 @@ def hartree2kcal(x):
 for dataset in datasets:
     evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-        'RMSE': torchani.training.RMSEMetric('energies')
+        'RMSE': torchani.ignite.RMSEMetric('energies')
     })
     evaluator.run(dataset)
     metrics = evaluator.state.metrics
examples/nnp_training.py
@@ -11,18 +11,18 @@ import json
 # parse command line arguments
 parser = argparse.ArgumentParser()
-parser.add_argument('dataset_path', help='Path of the dataset, can a hdf5 file \
+parser.add_argument('training_path', help='Path of the training set, can be a hdf5 file \
                     or a directory containing hdf5 files')
+parser.add_argument('validation_path', help='Path of the validation set, can be a hdf5 file \
+                    or a directory containing hdf5 files')
-parser.add_argument('--dataset_checkpoint', help='Checkpoint file for datasets',
-                    default='dataset-checkpoint.dat')
 parser.add_argument('--model_checkpoint', help='Checkpoint file for model',
                     default='model.pt')
 parser.add_argument('-m', '--max_epochs', help='Maximum number of epoches',
-                    default=100, type=int)
+                    default=300, type=int)
 parser.add_argument('--training_rmse_every', help='Compute training RMSE every epoches',
                     default=20, type=int)
@@ -53,20 +53,24 @@ start = timeit.default_timer()
 nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
+shift_energy = torchani.buildins.energy_shifter
-training, validation, testing = torchani.training.load_or_create(
-    parser.dataset_checkpoint, parser.batch_size, model.consts.species,
-    parser.dataset_path, device=device)
+training = torchani.data.BatchedANIDataset(
+    parser.training_path, model.consts.species_to_tensor, parser.batch_size,
+    device=device, transform=[shift_energy.subtract_from_dataset])
+validation = torchani.data.BatchedANIDataset(
+    parser.validation_path, model.consts.species_to_tensor, parser.batch_size,
+    device=device, transform=[shift_energy.subtract_from_dataset])
-container = torchani.training.Container({'energies': nnp})
+container = torchani.ignite.Container({'energies': nnp})
 parser.optim_args = json.loads(parser.optim_args)
 optimizer = getattr(torch.optim, parser.optimizer)
 optimizer = optimizer(nnp.parameters(), **parser.optim_args)
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.training.MSELoss('energies'))
+    container, optimizer, torchani.ignite.MSELoss('energies'))
 evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-    'RMSE': torchani.training.RMSEMetric('energies')
+    'RMSE': torchani.ignite.RMSEMetric('energies')
 })
@@ -97,19 +101,25 @@ def finalize_tqdm(trainer):
 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
 def validation_and_checkpoint(trainer):
+    def evaluate(dataset, name):
+        evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
+            'RMSE': torchani.ignite.RMSEMetric('energies')
+        })
+        evaluator.run(dataset)
+        metrics = evaluator.state.metrics
+        rmse = hartree2kcal(metrics['RMSE'])
+        writer.add_scalar(name, rmse, trainer.state.epoch)
+        return rmse

     # compute validation RMSE
-    evaluator.run(validation)
-    metrics = evaluator.state.metrics
-    rmse = hartree2kcal(metrics['RMSE'])
-    writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch)
+    rmse = evaluate(validation, 'validation_rmse_vs_epoch')

     # compute training RMSE
-    if trainer.state.epoch % parser.training_rmse_every == 0:
-        evaluator.run(training)
-        metrics = evaluator.state.metrics
-        rmse = hartree2kcal(metrics['RMSE'])
-        writer.add_scalar('training_rmse_vs_epoch', rmse, trainer.state.epoch)
+    if trainer.state.epoch % parser.training_rmse_every == 1:
+        evaluate(training, 'training_rmse_vs_epoch')

     # handle best validation RMSE
     if rmse < trainer.state.best_validation_rmse:
@@ -120,9 +130,12 @@ def validation_and_checkpoint(trainer):
         torch.save(nnp.state_dict(), parser.model_checkpoint)
     else:
         trainer.state.no_improve_count += 1
+    writer.add_scalar('no_improve_count_vs_epoch',
+                      trainer.state.no_improve_count, trainer.state.epoch)
+
     if trainer.state.no_improve_count > parser.early_stopping:
         trainer.terminate()

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
@@ -134,8 +147,7 @@ def log_time(trainer):
 @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
 def log_loss_and_time(trainer):
     iteration = trainer.state.iteration
-    rmse = hartree2kcal(math.sqrt(trainer.state.output))
-    writer.add_scalar('training_atomic_rmse_vs_iteration', rmse, iteration)
+    writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration)

 trainer.run(training, max_epochs=parser.max_epochs)
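
The deleted load_or_create checkpointing is replaced by constructing one BatchedANIDataset per split, with self energies subtracted up front by the transform. A hedged sketch of the new idiom (the file paths and batch size are placeholders; the torchani.buildins names follow this commit's era of the API):

    import torch
    import torchani

    shift_energy = torchani.buildins.energy_shifter
    consts = torchani.buildins.consts
    device = torch.device('cpu')

    # One dataset object per split; no pickle checkpoint file involved.
    training = torchani.data.BatchedANIDataset(
        'dataset/train.h5', consts.species_to_tensor, 256,
        device=device, transform=[shift_energy.subtract_from_dataset])
    validation = torchani.data.BatchedANIDataset(
        'dataset/validate.h5', consts.species_to_tensor, 256,
        device=device, transform=[shift_energy.subtract_from_dataset])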
examples/training-benchmark.py
@@ -23,15 +23,15 @@ parser = parser.parse_args()
 device = torch.device(parser.device)
 nnp = model.get_or_create_model('/tmp/model.pt', device=device)
 shift_energy = torchani.buildins.energy_shifter
-dataset = torchani.training.BatchedANIDataset(
-    parser.dataset_path, model.consts.species,
+dataset = torchani.data.BatchedANIDataset(
+    parser.dataset_path, model.consts.species_to_tensor,
     parser.batch_size, device=device,
     transform=[shift_energy.subtract_from_dataset])
-container = torchani.training.Container({'energies': nnp})
+container = torchani.ignite.Container({'energies': nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.training.MSELoss('energies'))
+    container, optimizer, torchani.ignite.MSELoss('energies'))

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
tests/test_data.py
@@ -12,9 +12,9 @@ consts = torchani.buildins.consts
 class TestData(unittest.TestCase):

     def setUp(self):
-        self.ds = torchani.training.BatchedANIDataset(
-            dataset_path, consts.species, batch_size)
+        self.ds = torchani.data.BatchedANIDataset(
+            dataset_path, consts.species_to_tensor, batch_size)

     def _assertTensorEqual(self, t1, t2):
         self.assertEqual((t1 - t2).abs().max(), 0)
@@ -32,8 +32,7 @@ class TestData(unittest.TestCase):
             (species3, coordinates3),
         ])
         natoms = (species >= 0).to(torch.long).sum(1)
-        chunks = torchani.training.data.split_batch(natoms, species,
-                                                    coordinates)
+        chunks = torchani.data.split_batch(natoms, species, coordinates)
         start = 0
         last = None
         for s, c in chunks:
tests/test_ignite.py
@@ -5,7 +5,7 @@ import copy
 from ignite.engine import create_supervised_trainer, \
     create_supervised_evaluator, Events
 import torchani
-import torchani.training
+import torchani.ignite

 path = os.path.dirname(os.path.realpath(__file__))
 path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
@@ -19,8 +19,8 @@ class TestIgnite(unittest.TestCase):
         aev_computer = torchani.buildins.aev_computer
         nnp = copy.deepcopy(torchani.buildins.models[0])
         shift_energy = torchani.buildins.energy_shifter
-        ds = torchani.training.BatchedANIDataset(
-            path, torchani.buildins.consts.species, batchsize,
+        ds = torchani.data.BatchedANIDataset(
+            path, torchani.buildins.consts.species_to_tensor, batchsize,
             transform=[shift_energy.subtract_from_dataset])
         ds = torch.utils.data.Subset(ds, [0])
@@ -29,15 +29,15 @@ class TestIgnite(unittest.TestCase):
                 return x[0], x[1].flatten()

         model = torch.nn.Sequential(aev_computer, nnp, Flatten())
-        container = torchani.training.Container({'energies': model})
+        container = torchani.ignite.Container({'energies': model})
         optimizer = torch.optim.Adam(container.parameters())
-        loss = torchani.training.TransformedLoss(
-            torchani.training.MSELoss('energies'),
+        loss = torchani.ignite.TransformedLoss(
+            torchani.ignite.MSELoss('energies'),
             lambda x: torch.exp(x) - 1)
         trainer = create_supervised_trainer(container, optimizer, loss)
         evaluator = create_supervised_evaluator(container, metrics={
-            'RMSE': torchani.training.RMSEMetric('energies')
+            'RMSE': torchani.ignite.RMSEMetric('energies')
         })

         @trainer.on(Events.COMPLETED)
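
The test above exercises TransformedLoss, which (judging by its use here) applies a function to the value of a wrapped loss, so training minimizes exp(MSE) − 1 instead of raw MSE. A self-contained stand-in showing the idea with plain PyTorch; TransformedLossSketch below is hypothetical, not torchani's class:

    import torch
    from torch.nn.modules.loss import _Loss

    class TransformedLossSketch(_Loss):
        # Hypothetical stand-in: apply `transform` to a base loss value.
        def __init__(self, base_loss, transform):
            super(TransformedLossSketch, self).__init__()
            self.base_loss = base_loss
            self.transform = transform

        def forward(self, input, target):
            return self.transform(self.base_loss(input, target))

    loss = TransformedLossSketch(torch.nn.MSELoss(),
                                 lambda x: torch.exp(x) - 1)
    print(loss(torch.zeros(3), torch.ones(3)))  # exp(1) - 1, about 1.7183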
torchani/__init__.py
 from .utils import EnergyShifter
 from .models import ANIModel, Ensemble
 from .aev import AEVComputer
-from . import training
+from . import ignite
 from . import utils
 from . import neurochem
+from . import data
 from .neurochem import buildins

 __all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'buildins',
-           'training', 'utils', 'neurochem']
+           'ignite', 'utils', 'neurochem', 'data']
torchani/training/pyanitools.py → torchani/_pyanitools.py
File moved
torchani/training/data.py → torchani/data.py
 from torch.utils.data import Dataset
 from os.path import join, isfile, isdir
 import os
-from .pyanitools import anidataloader
+from ._pyanitools import anidataloader
 import torch
 import torch.utils.data as data
-import pickle
-from .. import utils
+from . import utils


 def chunk_counts(counts, split):
@@ -77,14 +75,11 @@ def split_batch(natoms, species, coordinates):
 class BatchedANIDataset(Dataset):

-    def __init__(self, path, species, batch_size, shuffle=True,
-                 properties=['energies'], transform=(),
+    def __init__(self, path, species_tensor_converter, batch_size,
+                 shuffle=True, properties=['energies'], transform=(),
                  dtype=torch.get_default_dtype(), device=torch.device('cpu')):
         super(BatchedANIDataset, self).__init__()
         self.path = path
-        self.species = species
-        self.species_indices = {self.species[i]: i
-                                for i in range(len(self.species))}
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.properties = properties
@@ -108,13 +103,12 @@ class BatchedANIDataset(Dataset):
         properties = {k: [] for k in self.properties}
         for f in files:
             for m in anidataloader(f):
-                species = m['species']
-                indices = [self.species_indices[i] for i in species]
-                species = torch.tensor(indices, dtype=torch.long)
-                coordinates = torch.from_numpy(m['coordinates'])
-                species_coordinates.append((species, coordinates))
+                s = species_tensor_converter(m['species'])
+                c = torch.from_numpy(m['coordinates']).to(torch.double)
+                species_coordinates.append((s, c))
                 for i in properties:
-                    properties[i].append(torch.from_numpy(m[i]))
+                    p = torch.from_numpy(m[i]).to(torch.double)
+                    properties[i].append(p)
         species, coordinates = utils.pad_and_batch(species_coordinates)
         for i in properties:
             properties[i] = torch.cat(properties[i])
@@ -136,9 +130,10 @@ class BatchedANIDataset(Dataset):
         # convert to desired dtype
         species = species
         coordinates = coordinates.to(dtype)
-        properties = {k: properties[k].to(dtype)
-                      for k in properties}
+        for k in properties:
+            properties[k] = properties[k].to(dtype)

-        # split into minibatches, and strip reduncant padding
+        # split into minibatches, and strip redundant padding
         natoms = (species >= 0).to(torch.long).sum(1)
         batches = []
         num_batches = (conformations + batch_size - 1) // batch_size
@@ -146,6 +141,7 @@ class BatchedANIDataset(Dataset):
             start = i * batch_size
             end = min((i + 1) * batch_size, conformations)
             natoms_batch = natoms[start:end]
+            # sort batch by number of atoms to prepare for splitting
             natoms_batch, indices = natoms_batch.sort()
             species_batch = species[start:end, ...].index_select(0, indices)
             coordinates_batch = coordinates[start:end, ...] \
@@ -172,24 +168,3 @@ class BatchedANIDataset(Dataset):

     def __len__(self):
         return len(self.batches)
-
-
-def load_or_create(checkpoint, batch_size, species, dataset_path,
-                   *args, **kwargs):
-    """Generate a 80-10-10 split of the dataset, and checkpoint
-    the resulting dataset"""
-    if not os.path.isfile(checkpoint):
-        full_dataset = BatchedANIDataset(dataset_path, species, batch_size,
-                                         *args, **kwargs)
-        training_size = int(len(full_dataset) * 0.8)
-        validation_size = int(len(full_dataset) * 0.1)
-        testing_size = len(full_dataset) - training_size - validation_size
-        lengths = [training_size, validation_size, testing_size]
-        subsets = data.random_split(full_dataset, lengths)
-        with open(checkpoint, 'wb') as f:
-            pickle.dump(subsets, f)
-    # load dataset from checkpoint file
-    with open(checkpoint, 'rb') as f:
-        training, validation, testing = pickle.load(f)
-    return training, validation, testing
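
The key signature change above: BatchedANIDataset now takes a species-to-tensor callable instead of a species list, decoupling the dataset from any fixed element ordering. Any callable mapping element symbols to a long tensor will do; a hypothetical converter for illustration (real code passes Constants.species_to_tensor from torchani.neurochem, and the index order below is made up):

    import torch

    def species_tensor_converter(species):
        index = {'H': 0, 'C': 1, 'N': 2, 'O': 3}  # assumed order, illustration only
        return torch.tensor([index[s] for s in species], dtype=torch.long)

    print(species_tensor_converter('CHHHH'))  # tensor([1, 0, 0, 0, 0])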
torchani/training/loss_metrics.py → torchani/ignite.py
+import torch
+from . import utils
 from torch.nn.modules.loss import _Loss
 from ignite.metrics.metric import Metric
 from ignite.metrics import RootMeanSquaredError
-import torch
+
+
+class Container(torch.nn.ModuleDict):
+
+    def __init__(self, modules):
+        super(Container, self).__init__(modules)
+
+    def forward(self, species_coordinates):
+        results = {k: [] for k in self}
+        for sc in species_coordinates:
+            for k in self:
+                _, result = self[k](sc)
+                results[k].append(result)
+        for k in self:
+            results[k] = torch.cat(results[k])
+        results['species'], results['coordinates'] = \
+            utils.pad_and_batch(species_coordinates)
+        return results


 class DictLoss(_Loss):
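
The rewritten Container subclasses torch.nn.ModuleDict, so submodules are registered automatically and .to(device) / .parameters() work without the manual setattr bookkeeping of the deleted torchani/training/container.py below. A minimal sketch of its dict-in, dict-out contract, using a dummy module in place of a real NNP and assuming utils.pad_and_batch pads and concatenates chunks as in this era of torchani:

    import torch
    import torchani

    class DummyEnergies(torch.nn.Module):
        # Mimics the NNP interface: (species, coordinates) -> (species, energies)
        def forward(self, species_coordinates):
            species, coordinates = species_coordinates
            return species, coordinates.sum(dim=(1, 2))

    container = torchani.ignite.Container({'energies': DummyEnergies()})
    species = torch.zeros(2, 4, dtype=torch.long)  # 2 conformations, 4 atoms
    coordinates = torch.randn(2, 4, 3)
    out = container([(species, coordinates)])      # a batch is a list of chunks
    print(sorted(out.keys()))  # ['coordinates', 'energies', 'species']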
torchani/neurochem.py
@@ -56,9 +56,9 @@ class Constants(Mapping):
     def __getitem__(self, item):
         return getattr(self, item)

-    def species_to_tensor(self, species, device):
+    def species_to_tensor(self, species):
         rev = [self.rev_species[s] for s in species]
-        return torch.tensor(rev, dtype=torch.long, device=device)
+        return torch.tensor(rev, dtype=torch.long)


 def load_sae(filename):
torchani/training/__init__.py (deleted, 100644 → 0)
-from .container import Container
-from .data import BatchedANIDataset, load_or_create
-from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric, \
-    TransformedLoss
-from . import pyanitools
-
-__all__ = ['Container', 'BatchedANIDataset', 'load_or_create', 'DictLoss',
-           'DictMetric', 'MSELoss', 'RMSEMetric', 'TransformedLoss',
-           'pyanitools']
torchani/training/container.py (deleted, 100644 → 0)
-import torch
-from .. import utils
-
-
-class Container(torch.nn.Module):
-
-    def __init__(self, models):
-        super(Container, self).__init__()
-        self.keys = models.keys()
-        for i in models:
-            setattr(self, 'model_' + i, models[i])
-
-    def forward(self, species_coordinates):
-        results = {k: [] for k in self.keys}
-        for sc in species_coordinates:
-            for k in self.keys:
-                model = getattr(self, 'model_' + k)
-                _, result = model(sc)
-                results[k].append(result)
-        results['species'], results['coordinates'] = \
-            utils.pad_and_batch(species_coordinates)
-        for k in self.keys:
-            results[k] = torch.cat(results[k])
-        return results