Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
3cced1e6
"include/vscode:/vscode.git/clone" did not exist on "f03a1738d93c8ffccc570e8121e0a261e9950fa6"
Unverified
Commit
3cced1e6
authored
Aug 23, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 23, 2018
Browse files
Improve training related API (#76)
parent
d7ef8182
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
98 additions
and
125 deletions
+98
-125
codefresh.yml
codefresh.yml
+3
-3
examples/energy_force.py
examples/energy_force.py
+1
-1
examples/neurochem-test.py
examples/neurochem-test.py
+4
-4
examples/nnp_training.py
examples/nnp_training.py
+37
-25
examples/training-benchmark.py
examples/training-benchmark.py
+4
-4
tests/test_data.py
tests/test_data.py
+4
-5
tests/test_ignite.py
tests/test_ignite.py
+7
-7
torchani/__init__.py
torchani/__init__.py
+3
-2
torchani/_pyanitools.py
torchani/_pyanitools.py
+0
-0
torchani/data.py
torchani/data.py
+13
-38
torchani/ignite.py
torchani/ignite.py
+20
-1
torchani/neurochem.py
torchani/neurochem.py
+2
-2
torchani/training/__init__.py
torchani/training/__init__.py
+0
-9
torchani/training/container.py
torchani/training/container.py
+0
-24
No files found.
codefresh.yml
View file @
3cced1e6
...
@@ -24,9 +24,9 @@ steps:
...
@@ -24,9 +24,9 @@ steps:
Examples
:
Examples
:
image
:
'
${{BuildTorchANI}}'
image
:
'
${{BuildTorchANI}}'
commands
:
commands
:
-
rm -rf
*.dat
*.pt
-
rm -rf *.pt
-
python examples/nnp_training.py
./
dataset/ani_gdb_s01.h5
-
python examples/nnp_training.py
dataset/ani_gdb_s01.h5
dataset/ani_gdb_s01.h5
-
python examples/nnp_training.py
./
dataset/ani_gdb_s01.h5
# run twice to test if checkpoint is working
-
python examples/nnp_training.py
dataset/ani_gdb_s01.h5
dataset/ani_gdb_s01.h5
# run twice to test if checkpoint is working
-
python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
-
python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
-
python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
# run twice to test if checkpoint is working
-
python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
# run twice to test if checkpoint is working
-
python examples/energy_force.py
-
python examples/energy_force.py
...
...
examples/energy_force.py
View file @
3cced1e6
...
@@ -22,7 +22,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
...
@@ -22,7 +22,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[
0.45554739
,
0.54289633
,
0.81170881
],
[
0.45554739
,
0.54289633
,
0.81170881
],
[
0.66091919
,
-
0.16799635
,
-
0.91037834
]]],
[
0.66091919
,
-
0.16799635
,
-
0.91037834
]]],
requires_grad
=
True
)
requires_grad
=
True
)
species
=
consts
.
species_to_tensor
(
'CHHHH'
,
device
).
unsqueeze
(
0
)
species
=
consts
.
species_to_tensor
(
'CHHHH'
).
to
(
device
).
unsqueeze
(
0
)
_
,
energy
=
model
((
species
,
coordinates
))
_
,
energy
=
model
((
species
,
coordinates
))
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
...
...
examples/neurochem-test.py
View file @
3cced1e6
...
@@ -35,15 +35,15 @@ shift_energy = torchani.neurochem.load_sae(parser.sae_file)
...
@@ -35,15 +35,15 @@ shift_energy = torchani.neurochem.load_sae(parser.sae_file)
aev_computer
=
torchani
.
AEVComputer
(
**
consts
)
aev_computer
=
torchani
.
AEVComputer
(
**
consts
)
nn
=
torchani
.
neurochem
.
load_model
(
consts
.
species
,
parser
.
network_dir
)
nn
=
torchani
.
neurochem
.
load_model
(
consts
.
species
,
parser
.
network_dir
)
model
=
torch
.
nn
.
Sequential
(
aev_computer
,
nn
)
model
=
torch
.
nn
.
Sequential
(
aev_computer
,
nn
)
container
=
torchani
.
training
.
Container
({
'energies'
:
model
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
model
})
container
=
container
.
to
(
device
)
container
=
container
.
to
(
device
)
# load datasets
# load datasets
if
parser
.
dataset_path
.
endswith
(
'.h5'
)
or
\
if
parser
.
dataset_path
.
endswith
(
'.h5'
)
or
\
parser
.
dataset_path
.
endswith
(
'.hdf5'
)
or
\
parser
.
dataset_path
.
endswith
(
'.hdf5'
)
or
\
os
.
path
.
isdir
(
parser
.
dataset_path
):
os
.
path
.
isdir
(
parser
.
dataset_path
):
dataset
=
torchani
.
training
.
BatchedANIDataset
(
dataset
=
torchani
.
data
.
BatchedANIDataset
(
parser
.
dataset_path
,
consts
.
species
,
parser
.
batch_size
,
parser
.
dataset_path
,
consts
.
species
_to_tensor
,
parser
.
batch_size
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
datasets
=
[
dataset
]
datasets
=
[
dataset
]
else
:
else
:
...
@@ -60,7 +60,7 @@ def hartree2kcal(x):
...
@@ -60,7 +60,7 @@ def hartree2kcal(x):
for
dataset
in
datasets
:
for
dataset
in
datasets
:
evaluator
=
ignite
.
engine
.
create_supervised_evaluator
(
container
,
metrics
=
{
evaluator
=
ignite
.
engine
.
create_supervised_evaluator
(
container
,
metrics
=
{
'RMSE'
:
torchani
.
training
.
RMSEMetric
(
'energies'
)
'RMSE'
:
torchani
.
ignite
.
RMSEMetric
(
'energies'
)
})
})
evaluator
.
run
(
dataset
)
evaluator
.
run
(
dataset
)
metrics
=
evaluator
.
state
.
metrics
metrics
=
evaluator
.
state
.
metrics
...
...
examples/nnp_training.py
View file @
3cced1e6
...
@@ -11,18 +11,18 @@ import json
...
@@ -11,18 +11,18 @@ import json
# parse command line arguments
# parse command line arguments
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'dataset_path'
,
parser
.
add_argument
(
'training_path'
,
help
=
'Path of the dataset, can a hdf5 file
\
help
=
'Path of the training set, can be a hdf5 file
\
or a directory containing hdf5 files'
)
parser
.
add_argument
(
'validation_path'
,
help
=
'Path of the validation set, can be a hdf5 file
\
or a directory containing hdf5 files'
)
or a directory containing hdf5 files'
)
parser
.
add_argument
(
'--dataset_checkpoint'
,
help
=
'Checkpoint file for datasets'
,
default
=
'dataset-checkpoint.dat'
)
parser
.
add_argument
(
'--model_checkpoint'
,
parser
.
add_argument
(
'--model_checkpoint'
,
help
=
'Checkpoint file for model'
,
help
=
'Checkpoint file for model'
,
default
=
'model.pt'
)
default
=
'model.pt'
)
parser
.
add_argument
(
'-m'
,
'--max_epochs'
,
parser
.
add_argument
(
'-m'
,
'--max_epochs'
,
help
=
'Maximum number of epoches'
,
help
=
'Maximum number of epoches'
,
default
=
1
00
,
type
=
int
)
default
=
3
00
,
type
=
int
)
parser
.
add_argument
(
'--training_rmse_every'
,
parser
.
add_argument
(
'--training_rmse_every'
,
help
=
'Compute training RMSE every epoches'
,
help
=
'Compute training RMSE every epoches'
,
default
=
20
,
type
=
int
)
default
=
20
,
type
=
int
)
...
@@ -53,20 +53,24 @@ start = timeit.default_timer()
...
@@ -53,20 +53,24 @@ start = timeit.default_timer()
nnp
=
model
.
get_or_create_model
(
parser
.
model_checkpoint
,
device
=
device
)
nnp
=
model
.
get_or_create_model
(
parser
.
model_checkpoint
,
device
=
device
)
shift_energy
=
torchani
.
buildins
.
energy_shifter
shift_energy
=
torchani
.
buildins
.
energy_shifter
training
,
validation
,
testing
=
torchani
.
training
.
load_or_create
(
training
=
torchani
.
data
.
BatchedANIDataset
(
parser
.
dataset_checkpoint
,
parser
.
batch_size
,
model
.
consts
.
species
,
parser
.
training_path
,
model
.
consts
.
species_to_tensor
,
parser
.
dataset_path
,
device
=
device
,
parser
.
batch_size
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
validation
=
torchani
.
data
.
BatchedANIDataset
(
parser
.
validation_path
,
model
.
consts
.
species_to_tensor
,
parser
.
batch_size
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
transform
=
[
shift_energy
.
subtract_from_dataset
])
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
parser
.
optim_args
=
json
.
loads
(
parser
.
optim_args
)
parser
.
optim_args
=
json
.
loads
(
parser
.
optim_args
)
optimizer
=
getattr
(
torch
.
optim
,
parser
.
optimizer
)
optimizer
=
getattr
(
torch
.
optim
,
parser
.
optimizer
)
optimizer
=
optimizer
(
nnp
.
parameters
(),
**
parser
.
optim_args
)
optimizer
=
optimizer
(
nnp
.
parameters
(),
**
parser
.
optim_args
)
trainer
=
ignite
.
engine
.
create_supervised_trainer
(
trainer
=
ignite
.
engine
.
create_supervised_trainer
(
container
,
optimizer
,
torchani
.
training
.
MSELoss
(
'energies'
))
container
,
optimizer
,
torchani
.
ignite
.
MSELoss
(
'energies'
))
evaluator
=
ignite
.
engine
.
create_supervised_evaluator
(
container
,
metrics
=
{
evaluator
=
ignite
.
engine
.
create_supervised_evaluator
(
container
,
metrics
=
{
'RMSE'
:
torchani
.
training
.
RMSEMetric
(
'energies'
)
'RMSE'
:
torchani
.
ignite
.
RMSEMetric
(
'energies'
)
})
})
...
@@ -97,19 +101,25 @@ def finalize_tqdm(trainer):
...
@@ -97,19 +101,25 @@ def finalize_tqdm(trainer):
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
def
validation_and_checkpoint
(
trainer
):
def
validation_and_checkpoint
(
trainer
):
def
evaluate
(
dataset
,
name
):
evaluator
=
ignite
.
engine
.
create_supervised_evaluator
(
container
,
metrics
=
{
'RMSE'
:
torchani
.
ignite
.
RMSEMetric
(
'energies'
)
}
)
evaluator
.
run
(
dataset
)
metrics
=
evaluator
.
state
.
metrics
rmse
=
hartree2kcal
(
metrics
[
'RMSE'
])
writer
.
add_scalar
(
name
,
rmse
,
trainer
.
state
.
epoch
)
return
rmse
# compute validation RMSE
# compute validation RMSE
evaluator
.
run
(
validation
)
rmse
=
evaluate
(
validation
,
'validation_rmse_vs_epoch'
)
metrics
=
evaluator
.
state
.
metrics
rmse
=
hartree2kcal
(
metrics
[
'RMSE'
])
writer
.
add_scalar
(
'validation_rmse_vs_epoch'
,
rmse
,
trainer
.
state
.
epoch
)
# compute training RMSE
# compute training RMSE
if
trainer
.
state
.
epoch
%
parser
.
training_rmse_every
==
0
:
if
trainer
.
state
.
epoch
%
parser
.
training_rmse_every
==
1
:
evaluator
.
run
(
training
)
evaluate
(
training
,
'training_rmse_vs_epoch'
)
metrics
=
evaluator
.
state
.
metrics
rmse
=
hartree2kcal
(
metrics
[
'RMSE'
])
writer
.
add_scalar
(
'training_rmse_vs_epoch'
,
rmse
,
trainer
.
state
.
epoch
)
# handle best validation RMSE
# handle best validation RMSE
if
rmse
<
trainer
.
state
.
best_validation_rmse
:
if
rmse
<
trainer
.
state
.
best_validation_rmse
:
...
@@ -120,9 +130,12 @@ def validation_and_checkpoint(trainer):
...
@@ -120,9 +130,12 @@ def validation_and_checkpoint(trainer):
torch
.
save
(
nnp
.
state_dict
(),
parser
.
model_checkpoint
)
torch
.
save
(
nnp
.
state_dict
(),
parser
.
model_checkpoint
)
else
:
else
:
trainer
.
state
.
no_improve_count
+=
1
trainer
.
state
.
no_improve_count
+=
1
writer
.
add_scalar
(
'no_improve_count_vs_epoch'
,
trainer
.
state
.
no_improve_count
,
trainer
.
state
.
epoch
)
if
trainer
.
state
.
no_improve_count
>
parser
.
early_stopping
:
if
trainer
.
state
.
no_improve_count
>
parser
.
early_stopping
:
trainer
.
terminate
()
trainer
.
terminate
()
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
...
@@ -134,8 +147,7 @@ def log_time(trainer):
...
@@ -134,8 +147,7 @@ def log_time(trainer):
@
trainer
.
on
(
ignite
.
engine
.
Events
.
ITERATION_COMPLETED
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
ITERATION_COMPLETED
)
def
log_loss_and_time
(
trainer
):
def
log_loss_and_time
(
trainer
):
iteration
=
trainer
.
state
.
iteration
iteration
=
trainer
.
state
.
iteration
rmse
=
hartree2kcal
(
math
.
sqrt
(
trainer
.
state
.
output
))
writer
.
add_scalar
(
'loss_vs_iteration'
,
trainer
.
state
.
output
,
iteration
)
writer
.
add_scalar
(
'training_atomic_rmse_vs_iteration'
,
rmse
,
iteration
)
trainer
.
run
(
training
,
max_epochs
=
parser
.
max_epochs
)
trainer
.
run
(
training
,
max_epochs
=
parser
.
max_epochs
)
examples/training-benchmark.py
View file @
3cced1e6
...
@@ -23,15 +23,15 @@ parser = parser.parse_args()
...
@@ -23,15 +23,15 @@ parser = parser.parse_args()
device
=
torch
.
device
(
parser
.
device
)
device
=
torch
.
device
(
parser
.
device
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
device
=
device
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
device
=
device
)
shift_energy
=
torchani
.
buildins
.
energy_shifter
shift_energy
=
torchani
.
buildins
.
energy_shifter
dataset
=
torchani
.
training
.
BatchedANIDataset
(
dataset
=
torchani
.
data
.
BatchedANIDataset
(
parser
.
dataset_path
,
model
.
consts
.
species
,
parser
.
dataset_path
,
model
.
consts
.
species
_to_tensor
,
parser
.
batch_size
,
device
=
device
,
parser
.
batch_size
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
transform
=
[
shift_energy
.
subtract_from_dataset
])
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
trainer
=
ignite
.
engine
.
create_supervised_trainer
(
trainer
=
ignite
.
engine
.
create_supervised_trainer
(
container
,
optimizer
,
torchani
.
training
.
MSELoss
(
'energies'
))
container
,
optimizer
,
torchani
.
ignite
.
MSELoss
(
'energies'
))
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
@
trainer
.
on
(
ignite
.
engine
.
Events
.
EPOCH_STARTED
)
...
...
tests/test_data.py
View file @
3cced1e6
...
@@ -12,9 +12,9 @@ consts = torchani.buildins.consts
...
@@ -12,9 +12,9 @@ consts = torchani.buildins.consts
class
TestData
(
unittest
.
TestCase
):
class
TestData
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
ds
=
torchani
.
training
.
BatchedANIDataset
(
dataset_path
,
self
.
ds
=
torchani
.
data
.
BatchedANIDataset
(
dataset_path
,
consts
.
species
,
consts
.
species
_to_tensor
,
batch_size
)
batch_size
)
def
_assertTensorEqual
(
self
,
t1
,
t2
):
def
_assertTensorEqual
(
self
,
t1
,
t2
):
self
.
assertEqual
((
t1
-
t2
).
abs
().
max
(),
0
)
self
.
assertEqual
((
t1
-
t2
).
abs
().
max
(),
0
)
...
@@ -32,8 +32,7 @@ class TestData(unittest.TestCase):
...
@@ -32,8 +32,7 @@ class TestData(unittest.TestCase):
(
species3
,
coordinates3
),
(
species3
,
coordinates3
),
])
])
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
chunks
=
torchani
.
training
.
data
.
split_batch
(
natoms
,
species
,
chunks
=
torchani
.
data
.
split_batch
(
natoms
,
species
,
coordinates
)
coordinates
)
start
=
0
start
=
0
last
=
None
last
=
None
for
s
,
c
in
chunks
:
for
s
,
c
in
chunks
:
...
...
tests/test_ignite.py
View file @
3cced1e6
...
@@ -5,7 +5,7 @@ import copy
...
@@ -5,7 +5,7 @@ import copy
from
ignite.engine
import
create_supervised_trainer
,
\
from
ignite.engine
import
create_supervised_trainer
,
\
create_supervised_evaluator
,
Events
create_supervised_evaluator
,
Events
import
torchani
import
torchani
import
torchani.
training
import
torchani.
ignite
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
path
=
os
.
path
.
join
(
path
,
'../dataset/ani_gdb_s01.h5'
)
path
=
os
.
path
.
join
(
path
,
'../dataset/ani_gdb_s01.h5'
)
...
@@ -19,8 +19,8 @@ class TestIgnite(unittest.TestCase):
...
@@ -19,8 +19,8 @@ class TestIgnite(unittest.TestCase):
aev_computer
=
torchani
.
buildins
.
aev_computer
aev_computer
=
torchani
.
buildins
.
aev_computer
nnp
=
copy
.
deepcopy
(
torchani
.
buildins
.
models
[
0
])
nnp
=
copy
.
deepcopy
(
torchani
.
buildins
.
models
[
0
])
shift_energy
=
torchani
.
buildins
.
energy_shifter
shift_energy
=
torchani
.
buildins
.
energy_shifter
ds
=
torchani
.
training
.
BatchedANIDataset
(
ds
=
torchani
.
data
.
BatchedANIDataset
(
path
,
torchani
.
buildins
.
consts
.
species
,
batchsize
,
path
,
torchani
.
buildins
.
consts
.
species
_to_tensor
,
batchsize
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
transform
=
[
shift_energy
.
subtract_from_dataset
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
...
@@ -29,15 +29,15 @@ class TestIgnite(unittest.TestCase):
...
@@ -29,15 +29,15 @@ class TestIgnite(unittest.TestCase):
return
x
[
0
],
x
[
1
].
flatten
()
return
x
[
0
],
x
[
1
].
flatten
()
model
=
torch
.
nn
.
Sequential
(
aev_computer
,
nnp
,
Flatten
())
model
=
torch
.
nn
.
Sequential
(
aev_computer
,
nnp
,
Flatten
())
container
=
torchani
.
training
.
Container
({
'energies'
:
model
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
model
})
optimizer
=
torch
.
optim
.
Adam
(
container
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
container
.
parameters
())
loss
=
torchani
.
training
.
TransformedLoss
(
loss
=
torchani
.
ignite
.
TransformedLoss
(
torchani
.
training
.
MSELoss
(
'energies'
),
torchani
.
ignite
.
MSELoss
(
'energies'
),
lambda
x
:
torch
.
exp
(
x
)
-
1
)
lambda
x
:
torch
.
exp
(
x
)
-
1
)
trainer
=
create_supervised_trainer
(
trainer
=
create_supervised_trainer
(
container
,
optimizer
,
loss
)
container
,
optimizer
,
loss
)
evaluator
=
create_supervised_evaluator
(
container
,
metrics
=
{
evaluator
=
create_supervised_evaluator
(
container
,
metrics
=
{
'RMSE'
:
torchani
.
training
.
RMSEMetric
(
'energies'
)
'RMSE'
:
torchani
.
ignite
.
RMSEMetric
(
'energies'
)
})
})
@
trainer
.
on
(
Events
.
COMPLETED
)
@
trainer
.
on
(
Events
.
COMPLETED
)
...
...
torchani/__init__.py
View file @
3cced1e6
from
.utils
import
EnergyShifter
from
.utils
import
EnergyShifter
from
.models
import
ANIModel
,
Ensemble
from
.models
import
ANIModel
,
Ensemble
from
.aev
import
AEVComputer
from
.aev
import
AEVComputer
from
.
import
training
from
.
import
ignite
from
.
import
utils
from
.
import
utils
from
.
import
neurochem
from
.
import
neurochem
from
.
import
data
from
.neurochem
import
buildins
from
.neurochem
import
buildins
__all__
=
[
'AEVComputer'
,
'EnergyShifter'
,
'ANIModel'
,
'Ensemble'
,
'buildins'
,
__all__
=
[
'AEVComputer'
,
'EnergyShifter'
,
'ANIModel'
,
'Ensemble'
,
'buildins'
,
'
training
'
,
'utils'
,
'neurochem'
]
'
ignite
'
,
'utils'
,
'neurochem'
,
'data'
]
torchani/
training/
pyanitools.py
→
torchani/
_
pyanitools.py
View file @
3cced1e6
File moved
torchani/
training/
data.py
→
torchani/data.py
View file @
3cced1e6
from
torch.utils.data
import
Dataset
from
torch.utils.data
import
Dataset
from
os.path
import
join
,
isfile
,
isdir
from
os.path
import
join
,
isfile
,
isdir
import
os
import
os
from
.pyanitools
import
anidataloader
from
.
_
pyanitools
import
anidataloader
import
torch
import
torch
import
torch.utils.data
as
data
from
.
import
utils
import
pickle
from
..
import
utils
def
chunk_counts
(
counts
,
split
):
def
chunk_counts
(
counts
,
split
):
...
@@ -77,14 +75,11 @@ def split_batch(natoms, species, coordinates):
...
@@ -77,14 +75,11 @@ def split_batch(natoms, species, coordinates):
class
BatchedANIDataset
(
Dataset
):
class
BatchedANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
species
,
batch_size
,
shuffle
=
Tru
e
,
def
__init__
(
self
,
path
,
species
_tensor_converter
,
batch_siz
e
,
properties
=
[
'energies'
],
transform
=
(),
shuffle
=
True
,
properties
=
[
'energies'
],
transform
=
(),
dtype
=
torch
.
get_default_dtype
(),
device
=
torch
.
device
(
'cpu'
)):
dtype
=
torch
.
get_default_dtype
(),
device
=
torch
.
device
(
'cpu'
)):
super
(
BatchedANIDataset
,
self
).
__init__
()
super
(
BatchedANIDataset
,
self
).
__init__
()
self
.
path
=
path
self
.
path
=
path
self
.
species
=
species
self
.
species_indices
=
{
self
.
species
[
i
]:
i
for
i
in
range
(
len
(
self
.
species
))}
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
self
.
properties
=
properties
self
.
properties
=
properties
...
@@ -108,13 +103,12 @@ class BatchedANIDataset(Dataset):
...
@@ -108,13 +103,12 @@ class BatchedANIDataset(Dataset):
properties
=
{
k
:
[]
for
k
in
self
.
properties
}
properties
=
{
k
:
[]
for
k
in
self
.
properties
}
for
f
in
files
:
for
f
in
files
:
for
m
in
anidataloader
(
f
):
for
m
in
anidataloader
(
f
):
species
=
m
[
'species'
]
s
=
species_tensor_converter
(
m
[
'species'
])
indices
=
[
self
.
species_indices
[
i
]
for
i
in
species
]
c
=
torch
.
from_numpy
(
m
[
'coordinates'
]).
to
(
torch
.
double
)
species
=
torch
.
tensor
(
indices
,
dtype
=
torch
.
long
)
species_coordinates
.
append
((
s
,
c
))
coordinates
=
torch
.
from_numpy
(
m
[
'coordinates'
])
species_coordinates
.
append
((
species
,
coordinates
))
for
i
in
properties
:
for
i
in
properties
:
properties
[
i
].
append
(
torch
.
from_numpy
(
m
[
i
]))
p
=
torch
.
from_numpy
(
m
[
i
]).
to
(
torch
.
double
)
properties
[
i
].
append
(
p
)
species
,
coordinates
=
utils
.
pad_and_batch
(
species_coordinates
)
species
,
coordinates
=
utils
.
pad_and_batch
(
species_coordinates
)
for
i
in
properties
:
for
i
in
properties
:
properties
[
i
]
=
torch
.
cat
(
properties
[
i
])
properties
[
i
]
=
torch
.
cat
(
properties
[
i
])
...
@@ -136,9 +130,10 @@ class BatchedANIDataset(Dataset):
...
@@ -136,9 +130,10 @@ class BatchedANIDataset(Dataset):
# convert to desired dtype
# convert to desired dtype
species
=
species
species
=
species
coordinates
=
coordinates
.
to
(
dtype
)
coordinates
=
coordinates
.
to
(
dtype
)
properties
=
{
k
:
properties
[
k
].
to
(
dtype
)
for
k
in
properties
}
for
k
in
properties
:
properties
[
k
]
=
properties
[
k
].
to
(
dtype
)
# split into minibatches, and strip redun
c
ant padding
# split into minibatches, and strip redun
d
ant padding
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
natoms
=
(
species
>=
0
).
to
(
torch
.
long
).
sum
(
1
)
batches
=
[]
batches
=
[]
num_batches
=
(
conformations
+
batch_size
-
1
)
//
batch_size
num_batches
=
(
conformations
+
batch_size
-
1
)
//
batch_size
...
@@ -146,6 +141,7 @@ class BatchedANIDataset(Dataset):
...
@@ -146,6 +141,7 @@ class BatchedANIDataset(Dataset):
start
=
i
*
batch_size
start
=
i
*
batch_size
end
=
min
((
i
+
1
)
*
batch_size
,
conformations
)
end
=
min
((
i
+
1
)
*
batch_size
,
conformations
)
natoms_batch
=
natoms
[
start
:
end
]
natoms_batch
=
natoms
[
start
:
end
]
# sort batch by number of atoms to prepare for splitting
natoms_batch
,
indices
=
natoms_batch
.
sort
()
natoms_batch
,
indices
=
natoms_batch
.
sort
()
species_batch
=
species
[
start
:
end
,
...].
index_select
(
0
,
indices
)
species_batch
=
species
[
start
:
end
,
...].
index_select
(
0
,
indices
)
coordinates_batch
=
coordinates
[
start
:
end
,
...]
\
coordinates_batch
=
coordinates
[
start
:
end
,
...]
\
...
@@ -172,24 +168,3 @@ class BatchedANIDataset(Dataset):
...
@@ -172,24 +168,3 @@ class BatchedANIDataset(Dataset):
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
batches
)
return
len
(
self
.
batches
)
def
load_or_create
(
checkpoint
,
batch_size
,
species
,
dataset_path
,
*
args
,
**
kwargs
):
"""Generate a 80-10-10 split of the dataset, and checkpoint
the resulting dataset"""
if
not
os
.
path
.
isfile
(
checkpoint
):
full_dataset
=
BatchedANIDataset
(
dataset_path
,
species
,
batch_size
,
*
args
,
**
kwargs
)
training_size
=
int
(
len
(
full_dataset
)
*
0.8
)
validation_size
=
int
(
len
(
full_dataset
)
*
0.1
)
testing_size
=
len
(
full_dataset
)
-
training_size
-
validation_size
lengths
=
[
training_size
,
validation_size
,
testing_size
]
subsets
=
data
.
random_split
(
full_dataset
,
lengths
)
with
open
(
checkpoint
,
'wb'
)
as
f
:
pickle
.
dump
(
subsets
,
f
)
# load dataset from checkpoint file
with
open
(
checkpoint
,
'rb'
)
as
f
:
training
,
validation
,
testing
=
pickle
.
load
(
f
)
return
training
,
validation
,
testing
torchani/
training/loss_metrics
.py
→
torchani/
ignite
.py
View file @
3cced1e6
import
torch
from
.
import
utils
from
torch.nn.modules.loss
import
_Loss
from
torch.nn.modules.loss
import
_Loss
from
ignite.metrics.metric
import
Metric
from
ignite.metrics.metric
import
Metric
from
ignite.metrics
import
RootMeanSquaredError
from
ignite.metrics
import
RootMeanSquaredError
import
torch
class
Container
(
torch
.
nn
.
ModuleDict
):
def
__init__
(
self
,
modules
):
super
(
Container
,
self
).
__init__
(
modules
)
def
forward
(
self
,
species_coordinates
):
results
=
{
k
:
[]
for
k
in
self
}
for
sc
in
species_coordinates
:
for
k
in
self
:
_
,
result
=
self
[
k
](
sc
)
results
[
k
].
append
(
result
)
for
k
in
self
:
results
[
k
]
=
torch
.
cat
(
results
[
k
])
results
[
'species'
],
results
[
'coordinates'
]
=
\
utils
.
pad_and_batch
(
species_coordinates
)
return
results
class
DictLoss
(
_Loss
):
class
DictLoss
(
_Loss
):
...
...
torchani/neurochem.py
View file @
3cced1e6
...
@@ -56,9 +56,9 @@ class Constants(Mapping):
...
@@ -56,9 +56,9 @@ class Constants(Mapping):
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
return
getattr
(
self
,
item
)
def
species_to_tensor
(
self
,
species
,
device
):
def
species_to_tensor
(
self
,
species
):
rev
=
[
self
.
rev_species
[
s
]
for
s
in
species
]
rev
=
[
self
.
rev_species
[
s
]
for
s
in
species
]
return
torch
.
tensor
(
rev
,
dtype
=
torch
.
long
,
device
=
device
)
return
torch
.
tensor
(
rev
,
dtype
=
torch
.
long
)
def
load_sae
(
filename
):
def
load_sae
(
filename
):
...
...
torchani/training/__init__.py
deleted
100644 → 0
View file @
d7ef8182
from
.container
import
Container
from
.data
import
BatchedANIDataset
,
load_or_create
from
.loss_metrics
import
DictLoss
,
DictMetric
,
MSELoss
,
RMSEMetric
,
\
TransformedLoss
from
.
import
pyanitools
__all__
=
[
'Container'
,
'BatchedANIDataset'
,
'load_or_create'
,
'DictLoss'
,
'DictMetric'
,
'MSELoss'
,
'RMSEMetric'
,
'TransformedLoss'
,
'pyanitools'
]
torchani/training/container.py
deleted
100644 → 0
View file @
d7ef8182
import
torch
from
..
import
utils
class
Container
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
models
):
super
(
Container
,
self
).
__init__
()
self
.
keys
=
models
.
keys
()
for
i
in
models
:
setattr
(
self
,
'model_'
+
i
,
models
[
i
])
def
forward
(
self
,
species_coordinates
):
results
=
{
k
:
[]
for
k
in
self
.
keys
}
for
sc
in
species_coordinates
:
for
k
in
self
.
keys
:
model
=
getattr
(
self
,
'model_'
+
k
)
_
,
result
=
model
(
sc
)
results
[
k
].
append
(
result
)
results
[
'species'
],
results
[
'coordinates'
]
=
\
utils
.
pad_and_batch
(
species_coordinates
)
for
k
in
self
.
keys
:
results
[
k
]
=
torch
.
cat
(
results
[
k
])
return
results
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment