OpenDAS / torchani · Commits

Commit 3cced1e6 (unverified)
Authored Aug 23, 2018 by Gao, Xiang; committed via GitHub on Aug 23, 2018
Parent: d7ef8182

Improve training related API (#76)

Showing 14 changed files with 98 additions and 125 deletions (+98 −125)
codefresh.yml                     +3   −3
examples/energy_force.py          +1   −1
examples/neurochem-test.py        +4   −4
examples/nnp_training.py          +37  −25
examples/training-benchmark.py    +4   −4
tests/test_data.py                +4   −5
tests/test_ignite.py              +7   −7
torchani/__init__.py              +3   −2
torchani/_pyanitools.py           +0   −0
torchani/data.py                  +13  −38
torchani/ignite.py                +20  −1
torchani/neurochem.py             +2   −2
torchani/training/__init__.py     +0   −9
torchani/training/container.py    +0   −24
codefresh.yml
@@ -24,9 +24,9 @@ steps:
   Examples:
     image: '${{BuildTorchANI}}'
     commands:
-      - rm -rf *.dat *.pt
-      - python examples/nnp_training.py ./dataset/ani_gdb_s01.h5
-      - python examples/nnp_training.py ./dataset/ani_gdb_s01.h5  # run twice to test if checkpoint is working
+      - rm -rf *.pt
+      - python examples/nnp_training.py dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5
+      - python examples/nnp_training.py dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5  # run twice to test if checkpoint is working
       - python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
       - python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5  # run twice to test if checkpoint is working
       - python examples/energy_force.py
examples/energy_force.py
@@ -22,7 +22,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                              [0.45554739, 0.54289633, 0.81170881],
                              [0.66091919, -0.16799635, -0.91037834]]],
                            requires_grad=True)
-species = consts.species_to_tensor('CHHHH', device).unsqueeze(0)
+species = consts.species_to_tensor('CHHHH').to(device).unsqueeze(0)
 _, energy = model((species, coordinates))
 derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
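This one-line change tracks the new signature of Constants.species_to_tensor (see the torchani/neurochem.py diff below): the method no longer takes a device argument and always builds the index tensor on the CPU, so call sites move it explicitly. A minimal sketch of the new call pattern, assuming the consts, device, and coordinates objects defined earlier in this example:

    # Build the species index tensor on CPU, move it to the target device,
    # then add a batch dimension to match the coordinates tensor.
    species = consts.species_to_tensor('CHHHH').to(device).unsqueeze(0)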
examples/neurochem-test.py
@@ -35,15 +35,15 @@ shift_energy = torchani.neurochem.load_sae(parser.sae_file)
 aev_computer = torchani.AEVComputer(**consts)
 nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
 model = torch.nn.Sequential(aev_computer, nn)
-container = torchani.training.Container({'energies': model})
+container = torchani.ignite.Container({'energies': model})
 container = container.to(device)

 # load datasets
 if parser.dataset_path.endswith('.h5') or \
    parser.dataset_path.endswith('.hdf5') or \
    os.path.isdir(parser.dataset_path):
-    dataset = torchani.training.BatchedANIDataset(
-        parser.dataset_path, consts.species, parser.batch_size,
+    dataset = torchani.data.BatchedANIDataset(
+        parser.dataset_path, consts.species_to_tensor, parser.batch_size,
         device=device, transform=[shift_energy.subtract_from_dataset])
     datasets = [dataset]
 else:
@@ -60,7 +60,7 @@ def hartree2kcal(x):
 for dataset in datasets:
     evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-        'RMSE': torchani.training.RMSEMetric('energies')
+        'RMSE': torchani.ignite.RMSEMetric('energies')
     })
     evaluator.run(dataset)
     metrics = evaluator.state.metrics
examples/nnp_training.py
@@ -11,18 +11,18 @@ import json

 # parse command line arguments
 parser = argparse.ArgumentParser()
-parser.add_argument('dataset_path',
-                    help='Path of the dataset, can a hdf5 file \
-                          or a directory containing hdf5 files')
-parser.add_argument('--dataset_checkpoint',
-                    help='Checkpoint file for datasets',
-                    default='dataset-checkpoint.dat')
+parser.add_argument('training_path',
+                    help='Path of the training set, can be a hdf5 file \
+                          or a directory containing hdf5 files')
+parser.add_argument('validation_path',
+                    help='Path of the validation set, can be a hdf5 file \
+                          or a directory containing hdf5 files')
 parser.add_argument('--model_checkpoint',
                     help='Checkpoint file for model',
                     default='model.pt')
 parser.add_argument('-m', '--max_epochs',
                     help='Maximum number of epoches',
-                    default=100, type=int)
+                    default=300, type=int)
 parser.add_argument('--training_rmse_every',
                     help='Compute training RMSE every epoches',
                     default=20, type=int)

@@ -53,20 +53,24 @@ start = timeit.default_timer()
 nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
 shift_energy = torchani.buildins.energy_shifter
-training, validation, testing = torchani.training.load_or_create(
-    parser.dataset_checkpoint, parser.batch_size, model.consts.species,
-    parser.dataset_path, device=device,
-    transform=[shift_energy.subtract_from_dataset])
+training = torchani.data.BatchedANIDataset(
+    parser.training_path, model.consts.species_to_tensor,
+    parser.batch_size, device=device,
+    transform=[shift_energy.subtract_from_dataset])
+validation = torchani.data.BatchedANIDataset(
+    parser.validation_path, model.consts.species_to_tensor,
+    parser.batch_size, device=device,
+    transform=[shift_energy.subtract_from_dataset])
-container = torchani.training.Container({'energies': nnp})
+container = torchani.ignite.Container({'energies': nnp})
 parser.optim_args = json.loads(parser.optim_args)
 optimizer = getattr(torch.optim, parser.optimizer)
 optimizer = optimizer(nnp.parameters(), **parser.optim_args)
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.training.MSELoss('energies'))
+    container, optimizer, torchani.ignite.MSELoss('energies'))
 evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-    'RMSE': torchani.training.RMSEMetric('energies')
+    'RMSE': torchani.ignite.RMSEMetric('energies')
 })

@@ -97,19 +101,25 @@ def finalize_tqdm(trainer):
 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
 def validation_and_checkpoint(trainer):
+    def evaluate(dataset, name):
+        evaluator = ignite.engine.create_supervised_evaluator(
+            container,
+            metrics={
+                'RMSE': torchani.ignite.RMSEMetric('energies')
+            })
+        evaluator.run(dataset)
+        metrics = evaluator.state.metrics
+        rmse = hartree2kcal(metrics['RMSE'])
+        writer.add_scalar(name, rmse, trainer.state.epoch)
+        return rmse
+
     # compute validation RMSE
-    evaluator.run(validation)
-    metrics = evaluator.state.metrics
-    rmse = hartree2kcal(metrics['RMSE'])
-    writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch)
+    rmse = evaluate(validation, 'validation_rmse_vs_epoch')

     # compute training RMSE
-    if trainer.state.epoch % parser.training_rmse_every == 0:
-        evaluator.run(training)
-        metrics = evaluator.state.metrics
-        rmse = hartree2kcal(metrics['RMSE'])
-        writer.add_scalar('training_rmse_vs_epoch', rmse, trainer.state.epoch)
+    if trainer.state.epoch % parser.training_rmse_every == 1:
+        evaluate(training, 'training_rmse_vs_epoch')

     # handle best validation RMSE
     if rmse < trainer.state.best_validation_rmse:

@@ -120,9 +130,12 @@ def validation_and_checkpoint(trainer):
         torch.save(nnp.state_dict(), parser.model_checkpoint)
     else:
         trainer.state.no_improve_count += 1
+        writer.add_scalar('no_improve_count_vs_epoch',
+                          trainer.state.no_improve_count, trainer.state.epoch)
         if trainer.state.no_improve_count > parser.early_stopping:
             trainer.terminate()

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)

@@ -134,8 +147,7 @@ def log_time(trainer):
 @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
 def log_loss_and_time(trainer):
     iteration = trainer.state.iteration
-    rmse = hartree2kcal(math.sqrt(trainer.state.output))
-    writer.add_scalar('training_atomic_rmse_vs_iteration', rmse, iteration)
+    writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration)

 trainer.run(training, max_epochs=parser.max_epochs)
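A behavioral detail in this file: the training-RMSE guard changed from epoch % training_rmse_every == 0 to == 1. Since the handler fires on EPOCH_STARTED, the == 1 form also evaluates on the very first epoch rather than waiting a full cycle. A small runnable sketch of which epochs fire under each condition, assuming the default training_rmse_every=20:

    training_rmse_every = 20  # default from the argument parser above
    epochs = range(1, 61)
    old = [e for e in epochs if e % training_rmse_every == 0]  # [20, 40, 60]
    new = [e for e in epochs if e % training_rmse_every == 1]  # [1, 21, 41]
    print(old, new)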
examples/training-benchmark.py
@@ -23,15 +23,15 @@ parser = parser.parse_args()
 device = torch.device(parser.device)
 nnp = model.get_or_create_model('/tmp/model.pt', device=device)
 shift_energy = torchani.buildins.energy_shifter
-dataset = torchani.training.BatchedANIDataset(
-    parser.dataset_path, model.consts.species,
+dataset = torchani.data.BatchedANIDataset(
+    parser.dataset_path, model.consts.species_to_tensor,
     parser.batch_size, device=device,
     transform=[shift_energy.subtract_from_dataset])
-container = torchani.training.Container({'energies': nnp})
+container = torchani.ignite.Container({'energies': nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.training.MSELoss('energies'))
+    container, optimizer, torchani.ignite.MSELoss('energies'))

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
tests/test_data.py
@@ -12,9 +12,9 @@ consts = torchani.buildins.consts
 class TestData(unittest.TestCase):

     def setUp(self):
-        self.ds = torchani.training.BatchedANIDataset(dataset_path,
-                                                      consts.species,
-                                                      batch_size)
+        self.ds = torchani.data.BatchedANIDataset(dataset_path,
+                                                  consts.species_to_tensor,
+                                                  batch_size)

     def _assertTensorEqual(self, t1, t2):
         self.assertEqual((t1 - t2).abs().max(), 0)
@@ -32,8 +32,7 @@ class TestData(unittest.TestCase):
             (species3, coordinates3),
         ])
         natoms = (species >= 0).to(torch.long).sum(1)
-        chunks = torchani.training.data.split_batch(natoms, species,
-                                                    coordinates)
+        chunks = torchani.data.split_batch(natoms, species, coordinates)
         start = 0
         last = None
         for s, c in chunks:
tests/test_ignite.py
@@ -5,7 +5,7 @@ import copy
 from ignite.engine import create_supervised_trainer, \
     create_supervised_evaluator, Events
 import torchani
-import torchani.training
+import torchani.ignite

 path = os.path.dirname(os.path.realpath(__file__))
 path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
@@ -19,8 +19,8 @@ class TestIgnite(unittest.TestCase):
         aev_computer = torchani.buildins.aev_computer
         nnp = copy.deepcopy(torchani.buildins.models[0])
         shift_energy = torchani.buildins.energy_shifter
-        ds = torchani.training.BatchedANIDataset(
-            path, torchani.buildins.consts.species, batchsize,
+        ds = torchani.data.BatchedANIDataset(
+            path, torchani.buildins.consts.species_to_tensor, batchsize,
             transform=[shift_energy.subtract_from_dataset])
         ds = torch.utils.data.Subset(ds, [0])
@@ -29,15 +29,15 @@ class TestIgnite(unittest.TestCase):
                 return x[0], x[1].flatten()
         model = torch.nn.Sequential(aev_computer, nnp, Flatten())
-        container = torchani.training.Container({'energies': model})
+        container = torchani.ignite.Container({'energies': model})
         optimizer = torch.optim.Adam(container.parameters())
-        loss = torchani.training.TransformedLoss(
-            torchani.training.MSELoss('energies'),
+        loss = torchani.ignite.TransformedLoss(
+            torchani.ignite.MSELoss('energies'),
             lambda x: torch.exp(x) - 1)
         trainer = create_supervised_trainer(
             container, optimizer, loss)
         evaluator = create_supervised_evaluator(container, metrics={
-            'RMSE': torchani.training.RMSEMetric('energies')
+            'RMSE': torchani.ignite.RMSEMetric('energies')
         })

         @trainer.on(Events.COMPLETED)
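For readers skimming this test: TransformedLoss (kept through the rename) wraps a base loss and applies an extra scalar transform to its value; with exp(x) - 1 the transformed loss is approximately x near zero but grows sharply for large losses. A hedged sketch of the idea, not the torchani implementation:

    import torch

    def transformed_loss(base_loss, transform=lambda x: torch.exp(x) - 1):
        # exp(x) - 1 ≈ x for small x, so small losses pass through nearly
        # unchanged while large losses are amplified.
        return transform(base_loss)

    print(transformed_loss(torch.tensor(0.01)))  # ~0.01005
    print(transformed_loss(torch.tensor(2.0)))   # ~6.39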
torchani/__init__.py
 from .utils import EnergyShifter
 from .models import ANIModel, Ensemble
 from .aev import AEVComputer
-from . import training
+from . import ignite
 from . import utils
 from . import neurochem
+from . import data
 from .neurochem import buildins

 __all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'buildins',
-           'training', 'utils', 'neurochem']
+           'ignite', 'utils', 'neurochem', 'data']
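Taken together with the file moves below, the public layout after this commit is: the ignite integration (Container, losses, metrics) lives in torchani.ignite, the dataset utilities live in torchani.data, and pyanitools becomes the private torchani._pyanitools. In summary:

    torchani.training.Container          →  torchani.ignite.Container
    torchani.training.MSELoss            →  torchani.ignite.MSELoss
    torchani.training.RMSEMetric         →  torchani.ignite.RMSEMetric
    torchani.training.TransformedLoss    →  torchani.ignite.TransformedLoss
    torchani.training.BatchedANIDataset  →  torchani.data.BatchedANIDataset
    torchani.training.load_or_create     →  removed (construct datasets directly)
    torchani.training.pyanitools         →  torchani._pyanitools (now private)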
torchani/training/pyanitools.py → torchani/_pyanitools.py
File moved
torchani/training/data.py → torchani/data.py
 from torch.utils.data import Dataset
 from os.path import join, isfile, isdir
 import os
-from .pyanitools import anidataloader
+from ._pyanitools import anidataloader
 import torch
-import torch.utils.data as data
-import pickle
-from .. import utils
+from . import utils


 def chunk_counts(counts, split):
@@ -77,14 +75,11 @@ def split_batch(natoms, species, coordinates):

 class BatchedANIDataset(Dataset):

-    def __init__(self, path, species, batch_size, shuffle=True,
-                 properties=['energies'], transform=(),
+    def __init__(self, path, species_tensor_converter, batch_size,
+                 shuffle=True, properties=['energies'], transform=(),
                  dtype=torch.get_default_dtype(), device=torch.device('cpu')):
         super(BatchedANIDataset, self).__init__()
         self.path = path
-        self.species = species
-        self.species_indices = {self.species[i]: i
-                                for i in range(len(self.species))}
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.properties = properties
@@ -108,13 +103,12 @@ class BatchedANIDataset(Dataset):
         properties = {k: [] for k in self.properties}
         for f in files:
             for m in anidataloader(f):
-                species = m['species']
-                indices = [self.species_indices[i] for i in species]
-                species = torch.tensor(indices, dtype=torch.long)
-                coordinates = torch.from_numpy(m['coordinates'])
-                species_coordinates.append((species, coordinates))
+                s = species_tensor_converter(m['species'])
+                c = torch.from_numpy(m['coordinates']).to(torch.double)
+                species_coordinates.append((s, c))
                 for i in properties:
-                    properties[i].append(torch.from_numpy(m[i]))
+                    p = torch.from_numpy(m[i]).to(torch.double)
+                    properties[i].append(p)
         species, coordinates = utils.pad_and_batch(species_coordinates)
         for i in properties:
             properties[i] = torch.cat(properties[i])
@@ -136,9 +130,10 @@ class BatchedANIDataset(Dataset):
         # convert to desired dtype
         species = species
         coordinates = coordinates.to(dtype)
-        properties = {k: properties[k].to(dtype) for k in properties}
+        for k in properties:
+            properties[k] = properties[k].to(dtype)

-        # split into minibatches, and strip reduncant padding
+        # split into minibatches, and strip redundant padding
         natoms = (species >= 0).to(torch.long).sum(1)
         batches = []
         num_batches = (conformations + batch_size - 1) // batch_size
@@ -146,6 +141,7 @@ class BatchedANIDataset(Dataset):
             start = i * batch_size
             end = min((i + 1) * batch_size, conformations)
             natoms_batch = natoms[start:end]
+            # sort batch by number of atoms to prepare for splitting
            natoms_batch, indices = natoms_batch.sort()
             species_batch = species[start:end, ...].index_select(0, indices)
             coordinates_batch = coordinates[start:end, ...] \
@@ -172,24 +168,3 @@ class BatchedANIDataset(Dataset):

     def __len__(self):
         return len(self.batches)
-
-
-def load_or_create(checkpoint, batch_size, species, dataset_path,
-                   *args, **kwargs):
-    """Generate a 80-10-10 split of the dataset, and checkpoint
-    the resulting dataset"""
-    if not os.path.isfile(checkpoint):
-        full_dataset = BatchedANIDataset(dataset_path, species, batch_size,
-                                         *args, **kwargs)
-        training_size = int(len(full_dataset) * 0.8)
-        validation_size = int(len(full_dataset) * 0.1)
-        testing_size = len(full_dataset) - training_size - validation_size
-        lengths = [training_size, validation_size, testing_size]
-        subsets = data.random_split(full_dataset, lengths)
-        with open(checkpoint, 'wb') as f:
-            pickle.dump(subsets, f)
-    # load dataset from checkpoint file
-    with open(checkpoint, 'rb') as f:
-        training, validation, testing = pickle.load(f)
-    return training, validation, testing
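With load_or_create gone, the library no longer performs the 80-10-10 split and pickle checkpoint; the nnp_training.py example above now takes separate training and validation paths instead. Callers who want the old behavior can rebuild it in a few lines. A minimal sketch under the new API (the path, batch size, and checkpoint name are hypothetical stand-ins):

    import pickle
    import torch.utils.data as tud
    import torchani

    converter = torchani.buildins.consts.species_to_tensor
    full = torchani.data.BatchedANIDataset(
        'dataset/ani_gdb_s01.h5', converter, 256)

    # Reproduce the removed helper's 80-10-10 split.
    n_train = int(len(full) * 0.8)
    n_val = int(len(full) * 0.1)
    n_test = len(full) - n_train - n_val
    training, validation, testing = tud.random_split(
        full, [n_train, n_val, n_test])

    # Optionally checkpoint the subsets, as load_or_create used to do.
    with open('dataset-checkpoint.dat', 'wb') as f:
        pickle.dump((training, validation, testing), f)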
torchani/training/loss_metrics.py → torchani/ignite.py
+import torch
+from . import utils
 from torch.nn.modules.loss import _Loss
 from ignite.metrics.metric import Metric
 from ignite.metrics import RootMeanSquaredError
-import torch
+
+
+class Container(torch.nn.ModuleDict):
+
+    def __init__(self, modules):
+        super(Container, self).__init__(modules)
+
+    def forward(self, species_coordinates):
+        results = {k: [] for k in self}
+        for sc in species_coordinates:
+            for k in self:
+                _, result = self[k](sc)
+                results[k].append(result)
+        for k in self:
+            results[k] = torch.cat(results[k])
+        results['species'], results['coordinates'] = \
+            utils.pad_and_batch(species_coordinates)
+        return results


 class DictLoss(_Loss):
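The new Container is the adapter between torchani models and ignite's supervised training loop: a batch from BatchedANIDataset is a list of (species, coordinates) chunks, and Container runs every registered model on every chunk, concatenates the per-chunk outputs, and returns a dict keyed like its constructor argument, plus the padded species and coordinates. A hedged usage sketch based on this commit's API (the Sequential pipeline mirrors the examples above):

    import torch
    import torchani

    model = torch.nn.Sequential(torchani.buildins.aev_computer,
                                torchani.buildins.models[0])
    container = torchani.ignite.Container({'energies': model})

    # `batch` would come from iterating a BatchedANIDataset; it is a list
    # of (species, coordinates) chunks of similar molecule sizes.
    # results = container(batch)
    # results['energies'], results['species'], results['coordinates']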
torchani/neurochem.py
@@ -56,9 +56,9 @@ class Constants(Mapping):
     def __getitem__(self, item):
         return getattr(self, item)

-    def species_to_tensor(self, species, device):
+    def species_to_tensor(self, species):
         rev = [self.rev_species[s] for s in species]
-        return torch.tensor(rev, dtype=torch.long, device=device)
+        return torch.tensor(rev, dtype=torch.long)


 def load_sae(filename):
torchani/training/__init__.py deleted (100644 → 0)
-from .container import Container
-from .data import BatchedANIDataset, load_or_create
-from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric, \
-    TransformedLoss
-from . import pyanitools
-
-__all__ = ['Container', 'BatchedANIDataset', 'load_or_create', 'DictLoss',
-           'DictMetric', 'MSELoss', 'RMSEMetric', 'TransformedLoss',
-           'pyanitools']
torchani/training/container.py deleted (100644 → 0)
-import torch
-from .. import utils
-
-
-class Container(torch.nn.Module):
-
-    def __init__(self, models):
-        super(Container, self).__init__()
-        self.keys = models.keys()
-        for i in models:
-            setattr(self, 'model_' + i, models[i])
-
-    def forward(self, species_coordinates):
-        results = {k: [] for k in self.keys}
-        for sc in species_coordinates:
-            for k in self.keys:
-                model = getattr(self, 'model_' + k)
-                _, result = model(sc)
-                results[k].append(result)
-        results['species'], results['coordinates'] = \
-            utils.pad_and_batch(species_coordinates)
-        for k in self.keys:
-            results[k] = torch.cat(results[k])
-        return results
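Design note on this deletion: the old Container registered submodules via setattr on a plain torch.nn.Module and kept a separate keys list for iteration; its replacement in torchani/ignite.py inherits from torch.nn.ModuleDict, which provides registration, key iteration, and self[k] lookup in one step. A tiny sketch of that property with a toy module (not torchani code):

    import torch

    # ModuleDict registers its values as child modules: their parameters
    # appear in .parameters(), and iterating yields the keys.
    container = torch.nn.ModuleDict({'energies': torch.nn.Linear(4, 1)})
    print(list(container))                                 # ['energies']
    print(sum(p.numel() for p in container.parameters()))  # 5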