Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
8f1cb0c7
Unverified
Commit
8f1cb0c7
authored
Sep 11, 2018
by
Gao, Xiang
Committed by
GitHub
Sep 11, 2018
Browse files
Allow manually specifying checkpoint filename (#95)
parent
615f8144
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
5 deletions
+12
-5
torchani/neurochem/__init__.py
torchani/neurochem/__init__.py
+8
-4
torchani/neurochem/trainer.py
torchani/neurochem/trainer.py
+4
-1
No files found.
torchani/neurochem/__init__.py
View file @
8f1cb0c7
...
@@ -314,16 +314,20 @@ class Trainer:
...
@@ -314,16 +314,20 @@ class Trainer:
filename (str): Input file name
filename (str): Input file name
device (:class:`torch.device`): device to train the model
device (:class:`torch.device`): device to train the model
tqdm (bool): whether to enable tqdm
tqdm (bool): whether to enable tqdm
tensorboard (str): Directory to store tensorboard log file, set to
\
tensorboard (str): Directory to store tensorboard log file, set to
``None`` to disable tensorboardX.
``None`` to disable tensorboardX.
aev_caching (bool): Whether to use AEV caching.
aev_caching (bool): Whether to use AEV caching.
checkpoint_name (str): Name of the checkpoint file, checkpoints will be
stored in the network directory with this file name.
"""
"""
def
__init__
(
self
,
filename
,
device
=
torch
.
device
(
'cuda'
),
def
__init__
(
self
,
filename
,
device
=
torch
.
device
(
'cuda'
),
tqdm
=
False
,
tqdm
=
False
,
tensorboard
=
None
,
aev_caching
=
False
):
tensorboard
=
None
,
aev_caching
=
False
,
checkpoint_name
=
'model.pt'
):
self
.
filename
=
filename
self
.
filename
=
filename
self
.
device
=
device
self
.
device
=
device
self
.
aev_caching
=
aev_caching
self
.
aev_caching
=
aev_caching
self
.
checkpoint_name
=
checkpoint_name
if
tqdm
:
if
tqdm
:
import
tqdm
import
tqdm
self
.
tqdm
=
tqdm
.
tqdm
self
.
tqdm
=
tqdm
.
tqdm
...
@@ -475,7 +479,7 @@ class Trainer:
...
@@ -475,7 +479,7 @@ class Trainer:
network_dir
=
os
.
path
.
join
(
dir
,
params
[
'ntwkStoreDir'
])
network_dir
=
os
.
path
.
join
(
dir
,
params
[
'ntwkStoreDir'
])
if
not
os
.
path
.
exists
(
network_dir
):
if
not
os
.
path
.
exists
(
network_dir
):
os
.
makedirs
(
network_dir
)
os
.
makedirs
(
network_dir
)
self
.
model_checkpoint
=
os
.
path
.
join
(
network_dir
,
'model.pt'
)
self
.
model_checkpoint
=
os
.
path
.
join
(
network_dir
,
self
.
checkpoint_name
)
del
params
[
'ntwkStoreDir'
]
del
params
[
'ntwkStoreDir'
]
self
.
max_nonimprove
=
params
[
'tolr'
]
self
.
max_nonimprove
=
params
[
'tolr'
]
del
params
[
'tolr'
]
del
params
[
'tolr'
]
...
...
torchani/neurochem/trainer.py
View file @
8f1cb0c7
...
@@ -28,10 +28,13 @@ if __name__ == '__main__':
...
@@ -28,10 +28,13 @@ if __name__ == '__main__':
default
=
None
)
default
=
None
)
parser
.
add_argument
(
'--cache-aev'
,
dest
=
'cache_aev'
,
action
=
'store_true'
,
parser
.
add_argument
(
'--cache-aev'
,
dest
=
'cache_aev'
,
action
=
'store_true'
,
help
=
'Whether to cache AEV'
,
default
=
None
)
help
=
'Whether to cache AEV'
,
default
=
None
)
parser
.
add_argument
(
'--checkpoint_name'
,
help
=
'Name of checkpoint file'
,
default
=
'model.pt'
)
parser
=
parser
.
parse_args
()
parser
=
parser
.
parse_args
()
d
=
torch
.
device
(
parser
.
device
)
d
=
torch
.
device
(
parser
.
device
)
trainer
=
Trainer
(
parser
.
config_path
,
d
,
parser
.
tqdm
,
parser
.
tensorboard
,
trainer
=
Trainer
(
parser
.
config_path
,
d
,
parser
.
tqdm
,
parser
.
tensorboard
,
parser
.
cache_aev
)
parser
.
cache_aev
,
parser
.
checkpoint_name
)
trainer
.
load_data
(
parser
.
training_path
,
parser
.
validation_path
)
trainer
.
load_data
(
parser
.
training_path
,
parser
.
validation_path
)
trainer
.
run
()
trainer
.
run
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment