Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
1d8bba37
Unverified
Commit
1d8bba37
authored
Aug 02, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 02, 2018
Browse files
add argparser for nnp_training.py (#46)
parent
861580e3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
19 deletions
+51
-19
.gitignore
.gitignore
+1
-0
examples/nnp_training.py
examples/nnp_training.py
+50
-19
No files found.
.gitignore
View file @
1d8bba37
...
@@ -15,4 +15,5 @@ a.out
...
@@ -15,4 +15,5 @@ a.out
benchmark_xyz
benchmark_xyz
*.pyc
*.pyc
*checkpoint*
*checkpoint*
*.pt
/runs
/runs
\ No newline at end of file
examples/nnp_training.py
View file @
1d8bba37
"""Training-setup section of examples/nnp_training.py (reconstructed from the
commit diff): parses command-line options, then builds the device, summary
writer, datasets, model, optimizer and ignite trainer used below."""
import torch
import ignite
import torchani
import tqdm
import timeit
import tensorboardX
import math
import argparse
import json

# Parse command line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset, can be an hdf5 file '
                         'or a directory containing hdf5 files')
parser.add_argument('--dataset_checkpoint',
                    help='Checkpoint file for datasets',
                    default='dataset-checkpoint.dat')
parser.add_argument('--model_checkpoint',
                    help='Checkpoint file for model',
                    default='model.pt')
parser.add_argument('-m', '--max_epochs',
                    help='Maximum number of epochs',
                    default=10, type=int)
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
                    help='Number of conformations of each chunk',
                    default=256, type=int)
parser.add_argument('--batch_chunks',
                    help='Number of chunks in each minibatch',
                    default=4, type=int)
parser.add_argument('--log',
                    help='Log directory for tensorboardX',
                    default=None)
parser.add_argument('--optimizer',
                    help='Optimizer used to train the model',
                    default='Adam')
parser.add_argument('--optim_args',
                    help='Arguments to optimizers, in the format of json',
                    default='{}')
# NOTE(review): the parsed Namespace deliberately rebinds the name `parser`
# (code below and elsewhere in the file reads e.g. `parser.device`,
# `parser.max_epochs`); kept as-is for compatibility with the elided sections.
parser = parser.parse_args()

# Set up the training.
device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()
# Shift energies by subtracting self atomic energies during dataset loading.
shift_energy = torchani.EnergyShifter()
training, validation, testing = torchani.data.load_or_create(
    parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
    device=device, transform=[shift_energy.dataset_subtract_sae])
training = torchani.data.dataloader(training, parser.batch_chunks)
validation = torchani.data.dataloader(validation, parser.batch_chunks)
# `model` is imported in a part of the file elided from this diff —
# TODO(review): confirm against the full file.
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
# Build the optimizer by name from torch.optim, with JSON-encoded kwargs,
# e.g. --optimizer SGD --optim_args '{"lr": 1e-3}'.
parser.optim_args = json.loads(parser.optim_args)
optimizer = getattr(torch.optim, parser.optimizer)
optimizer = optimizer(nnp.parameters(), **parser.optim_args)
trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.energy_mse_loss)
...
@@ -78,4 +109,4 @@ def log_loss_and_time(trainer):
...
@@ -78,4 +109,4 @@ def log_loss_and_time(trainer):
writer
.
add_scalar
(
'training_rmse_vs_iteration'
,
rmse
,
iteration
)
writer
.
add_scalar
(
'training_rmse_vs_iteration'
,
rmse
,
iteration
)
# Run the training loop for the configured number of epochs.
trainer.run(training, max_epochs=parser.max_epochs)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment