Unverified Commit 83107add authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

clean up flatten in tutorials, make neurochem trainer support aev caching (#92)

parent e69e59a4
......@@ -92,15 +92,10 @@ else:
torch.save(nn.state_dict(), model_checkpoint)
class Flatten(torch.nn.Module):
    """Flatten the energies in a ``(species, energies)`` pair.

    The model pipeline emits energies shaped ``(N, 1)``; the dataset
    labels are shaped ``(N,)``. This module flattens the energy tensor
    so the two can be subtracted, passing the species tensor through
    untouched.
    """

    def forward(self, x):
        species, energies = x
        return species, energies.flatten()
###############################################################################
# Except that here we do not include the AEV computer in our pipeline, because
# the cache loader will load precomputed AEVs from disk.
model = torch.nn.Sequential(nn, Flatten()).to(device)
model = nn.to(device)
###############################################################################
# This part is also a line by line copy
......
......@@ -95,19 +95,7 @@ if os.path.isfile(model_checkpoint):
else:
torch.save(nn.state_dict(), model_checkpoint)
###############################################################################
# The output energy tensor has shape ``(N, 1)`` where ``N`` is the number of
# different structures in each minibatch. However, in the dataset, the label
# has shape ``(N,)``. To make it possible to subtract these two tensors, we
# need to flatten the output tensor.
class Flatten(torch.nn.Module):
    """Reshape the energy output of the pipeline.

    Takes a ``(species, energies)`` tuple, returns the species tensor
    unchanged and the energies flattened from ``(N, 1)`` to ``(N,)`` so
    they match the shape of the dataset labels.
    """

    def forward(self, x):
        return x[0], torch.flatten(x[1])
model = torch.nn.Sequential(aev_computer, nn, Flatten()).to(device)
model = torch.nn.Sequential(aev_computer, nn).to(device)
###############################################################################
# Now setup tensorboardX.
......
......@@ -304,6 +304,7 @@ def hartree2kcal(x):
from ..data import BatchedANIDataset # noqa: E402
from ..data import AEVCacheLoader # noqa: E402
class Trainer:
......@@ -315,12 +316,14 @@ class Trainer:
tqdm (bool): whether to enable tqdm
tensorboard (str): Directory to store tensorboard log file, set to\
``None`` to disable tensorboardX.
aev_caching (bool): Whether to use AEV caching.
"""
def __init__(self, filename, device=torch.device('cuda'),
tqdm=False, tensorboard=None):
tqdm=False, tensorboard=None, aev_caching=False):
self.filename = filename
self.device = device
self.aev_caching = aev_caching
if tqdm:
import tqdm
self.tqdm = tqdm.tqdm
......@@ -528,7 +531,10 @@ class Trainer:
i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules)
self.model = ANIModel([atomic_nets[s] for s in self.consts.species])
self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
if self.aev_caching:
self.nnp = self.model
else:
self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
self.container = Container({'energies': self.nnp}).to(self.device)
# losses
......@@ -561,15 +567,23 @@ class Trainer:
return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE'])
def load_data(self, training_path, validation_path):
    """Load the training and validation datasets.

    If AEV caching is enabled (``self.aev_caching``), each argument is
    the path to a cache directory of precomputed AEVs; otherwise each
    argument is the path to a raw dataset file.

    Args:
        training_path (str): Path to the training dataset (or AEV cache
            directory when caching is enabled).
        validation_path (str): Path to the validation dataset (or AEV
            cache directory when caching is enabled).
    """
    # NOTE(review): the original span contained both the pre- and
    # post-commit bodies interleaved (a stripped diff); only the
    # post-commit behavior is kept here.
    if self.aev_caching:
        self.training_set = AEVCacheLoader(training_path)
        self.validation_set = AEVCacheLoader(validation_path)
    else:
        # Energies are shifted (self-atomic energies subtracted) as the
        # batches are constructed, so the network trains on residuals.
        self.training_set = BatchedANIDataset(
            training_path, self.consts.species_to_tensor,
            self.training_batch_size, device=self.device,
            transform=[self.shift_energy.subtract_from_dataset])
        self.validation_set = BatchedANIDataset(
            validation_path, self.consts.species_to_tensor,
            self.validation_batch_size, device=self.device,
            transform=[self.shift_energy.subtract_from_dataset])
def run(self):
"""Run the training"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment