OpenDAS / torchani / Commits / 83107add

Unverified commit 83107add, authored Sep 06, 2018 by Gao, Xiang, committed by GitHub on Sep 06, 2018.

    clean up flatten in tutorials, make neurochem trainer support aev caching (#92)

parent e69e59a4
Showing 3 changed files with 27 additions and 30 deletions:

- examples/cache_aev.py (+1, -6)
- examples/nnp_training.py (+1, -13)
- torchani/neurochem/__init__.py (+25, -11)
examples/cache_aev.py

@@ -92,15 +92,10 @@ else:
     torch.save(nn.state_dict(), model_checkpoint)
 
-class Flatten(torch.nn.Module):
-    def forward(self, x):
-        return x[0], x[1].flatten()
-
 ###############################################################################
 # Except that at here we do not include aev computer into our pipeline, because
 # the cache loader will load computed AEVs from disk.
-model = torch.nn.Sequential(nn, Flatten()).to(device)
+model = nn.to(device)
 
 ###############################################################################
 # This part is also a line by line copy
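For orientation, the change above drops the AEV computer (and the now-unneeded Flatten stage) from the cached-AEV pipeline. A minimal sketch of the two pipeline shapes, using plain torch.nn.Linear placeholders rather than the real AEVComputer and ANI network classes:

import torch

# Placeholders only; the real tutorial uses torchani's AEV computer and
# the ANI network in these positions.
aev_computer = torch.nn.Linear(8, 4)
nn = torch.nn.Linear(4, 1)

# Without caching, AEVs are computed on the fly inside the pipeline:
model = torch.nn.Sequential(aev_computer, nn)

# With caching, AEVs are precomputed and read from disk, so the
# pipeline reduces to the network itself:
model = nn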
examples/nnp_training.py

@@ -95,19 +95,7 @@ if os.path.isfile(model_checkpoint):
 else:
     torch.save(nn.state_dict(), model_checkpoint)
 
-###############################################################################
-# The output energy tensor has shape ``(N, 1)`` where ``N`` is the number of
-# different structures in each minibatch. However, in the dataset, the label
-# has shape ``(N,)``. To make it possible to subtract these two tensors, we
-# need to flatten the output tensor.
-class Flatten(torch.nn.Module):
-    def forward(self, x):
-        return x[0], x[1].flatten()
-
-model = torch.nn.Sequential(aev_computer, nn, Flatten()).to(device)
+model = torch.nn.Sequential(aev_computer, nn).to(device)
 
 ###############################################################################
 # Now setup tensorboardX.
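The removed comment explains why Flatten existed in the first place: subtracting an ``(N,)`` label from an ``(N, 1)`` output broadcasts to ``(N, N)`` instead of producing N elementwise differences. A self-contained demonstration in plain PyTorch, independent of torchani:

import torch

predicted = torch.randn(5, 1)  # model output of shape (N, 1)
target = torch.randn(5)        # dataset labels of shape (N,)

# Without flattening, broadcasting silently yields an (N, N) tensor:
print((predicted - target).shape)            # torch.Size([5, 5])

# Flattening restores the intended elementwise difference:
print((predicted.flatten() - target).shape)  # torch.Size([5])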
torchani/neurochem/__init__.py

@@ -304,6 +304,7 @@ def hartree2kcal(x):
 from ..data import BatchedANIDataset  # noqa: E402
+from ..data import AEVCacheLoader  # noqa: E402
 
 
 class Trainer:

@@ -315,12 +316,14 @@ class Trainer:
         tqdm (bool): whether to enable tqdm
         tensorboard (str): Directory to store tensorboard log file, set to \
             ``None`` to disable tensorboardX.
+        aev_caching (bool): Whether to use AEV caching.
     """
 
     def __init__(self, filename, device=torch.device('cuda'),
-                 tqdm=False, tensorboard=None):
+                 tqdm=False, tensorboard=None, aev_caching=False):
         self.filename = filename
         self.device = device
+        self.aev_caching = aev_caching
         if tqdm:
             import tqdm
             self.tqdm = tqdm.tqdm

@@ -528,7 +531,10 @@ class Trainer:
                 i = o
             atomic_nets[atom_type] = torch.nn.Sequential(*modules)
         self.model = ANIModel([atomic_nets[s] for s in self.consts.species])
-        self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
+        if self.aev_caching:
+            self.nnp = self.model
+        else:
+            self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
         self.container = Container({'energies': self.nnp}).to(self.device)
 
         # losses

@@ -561,15 +567,23 @@ class Trainer:
         return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE'])
 
     def load_data(self, training_path, validation_path):
-        """Load training and validation dataset from file"""
-        self.training_set = BatchedANIDataset(
-            training_path, self.consts.species_to_tensor,
-            self.training_batch_size, device=self.device,
-            transform=[self.shift_energy.subtract_from_dataset])
-        self.validation_set = BatchedANIDataset(
-            validation_path, self.consts.species_to_tensor,
-            self.validation_batch_size, device=self.device,
-            transform=[self.shift_energy.subtract_from_dataset])
+        """Load training and validation dataset from file.
+
+        If AEV caching is enabled, then the arguments are path to the cache
+        directory, otherwise it should be path to the dataset.
+        """
+        if self.aev_caching:
+            self.training_set = AEVCacheLoader(training_path)
+            self.validation_set = AEVCacheLoader(validation_path)
+        else:
+            self.training_set = BatchedANIDataset(
+                training_path, self.consts.species_to_tensor,
+                self.training_batch_size, device=self.device,
+                transform=[self.shift_energy.subtract_from_dataset])
+            self.validation_set = BatchedANIDataset(
+                validation_path, self.consts.species_to_tensor,
+                self.validation_batch_size, device=self.device,
+                transform=[self.shift_energy.subtract_from_dataset])
 
     def run(self):
         """Run the training"""
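Put together, these changes let the NeuroChem trainer consume precomputed AEVs end to end. A usage sketch based only on the signatures above; the input file name and cache directories are hypothetical:

import torch
import torchani

# 'training.ipt' stands in for a NeuroChem training input file.
trainer = torchani.neurochem.Trainer('training.ipt',
                                     device=torch.device('cuda'),
                                     aev_caching=True)

# With aev_caching=True, load_data expects cache directories rather than
# dataset files (see the updated docstring above).
trainer.load_data('aev_cache/training', 'aev_cache/validation')
trainer.run()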