Unverified Commit 83107add authored by Gao, Xiang, committed by GitHub

clean up flatten in tutorials, make neurochem trainer support aev caching (#92)

parent e69e59a4
@@ -92,15 +92,10 @@ else:
     torch.save(nn.state_dict(), model_checkpoint)
 
-class Flatten(torch.nn.Module):
-    def forward(self, x):
-        return x[0], x[1].flatten()
-
 ###############################################################################
 # Except that here we do not include the aev computer in our pipeline, because
 # the cache loader will load precomputed AEVs from disk.
-model = torch.nn.Sequential(nn, Flatten()).to(device)
+model = nn.to(device)
 
 ###############################################################################
 # This part is also a line-by-line copy
...
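The change above works because TorchANI modules pass a (species, tensor) pair from stage to stage (visible in the deleted Flatten.forward, which returns x[0], x[1].flatten()), so stages compose with torch.nn.Sequential and the AEV stage can simply be omitted when cached AEVs are fed in directly. Below is a runnable toy analogue of that convention; TupleStage and all shapes are illustrative stand-ins, not TorchANI's actual classes:

import torch

class TupleStage(torch.nn.Module):
    """Toy analogue of TorchANI modules: takes and returns a
    (species, tensor) pair so stages chain through nn.Sequential."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, species_input):
        species, x = species_input
        return species, self.fn(x)

aev_stage = TupleStage(lambda x: x.flatten(1))   # stands in for the AEV computer
net_stage = TupleStage(lambda x: x.sum(dim=1))   # stands in for nn

species = torch.zeros(2, 4, dtype=torch.long)
coordinates = torch.randn(2, 4, 3)

# Full pipeline: coordinates in, energies out.
full = torch.nn.Sequential(aev_stage, net_stage)
_, energies = full((species, coordinates))

# Cached pipeline: precomputed "AEVs" in, energies out. The first
# stage is simply omitted, which is all the tutorial change does.
_, aevs = aev_stage((species, coordinates))
_, energies_again = net_stage((species, aevs))
print(torch.allclose(energies, energies_again))  # True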
@@ -95,19 +95,7 @@ if os.path.isfile(model_checkpoint):
 else:
     torch.save(nn.state_dict(), model_checkpoint)
 
-###############################################################################
-# The output energy tensor has shape ``(N, 1)`` where ``N`` is the number of
-# different structures in each minibatch. However, in the dataset, the label
-# has shape ``(N,)``. To make it possible to subtract these two tensors, we
-# need to flatten the output tensor.
-class Flatten(torch.nn.Module):
-    def forward(self, x):
-        return x[0], x[1].flatten()
-
-model = torch.nn.Sequential(aev_computer, nn, Flatten()).to(device)
+model = torch.nn.Sequential(aev_computer, nn).to(device)
 
 ###############################################################################
 # Now setup tensorboardX.
...
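The comment block deleted above described a genuine broadcasting pitfall: subtracting an (N,) label from an (N, 1) prediction broadcasts to an (N, N) matrix rather than an elementwise difference, silently corrupting the loss. The deletion presumably means the pipeline's energy output now already matches the label shape. The pitfall itself is easy to reproduce in plain PyTorch:

import torch

N = 4
predicted = torch.randn(N, 1)   # shape (N, 1), like the old model output
target = torch.randn(N)         # shape (N,), like the dataset labels

# Broadcasting turns the difference into an (N, N) matrix -- wrong!
print((predicted - target).shape)            # torch.Size([4, 4])

# Flattening the prediction restores elementwise subtraction.
print((predicted.flatten() - target).shape)  # torch.Size([4])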
@@ -304,6 +304,7 @@ def hartree2kcal(x):
 from ..data import BatchedANIDataset  # noqa: E402
+from ..data import AEVCacheLoader  # noqa: E402
 
 
 class Trainer:
@@ -315,12 +316,14 @@ class Trainer:
         tqdm (bool): whether to enable tqdm
         tensorboard (str): Directory to store tensorboard log file, set to\
             ``None`` to disable tensorboardX.
+        aev_caching (bool): Whether to use AEV caching.
     """
 
     def __init__(self, filename, device=torch.device('cuda'),
-                 tqdm=False, tensorboard=None):
+                 tqdm=False, tensorboard=None, aev_caching=False):
         self.filename = filename
         self.device = device
+        self.aev_caching = aev_caching
         if tqdm:
             import tqdm
             self.tqdm = tqdm.tqdm
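With the new keyword argument, callers opt in to caching at construction time. A hypothetical usage sketch follows; the config filename is a placeholder, not something introduced by this commit:

import torch
from torchani import neurochem

# 'training.ipt' is a placeholder NeuroChem training config file.
trainer = neurochem.Trainer('training.ipt',
                            device=torch.device('cuda'),
                            tqdm=False, tensorboard=None,
                            aev_caching=True)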
@@ -528,6 +531,9 @@ class Trainer:
             i = o
             atomic_nets[atom_type] = torch.nn.Sequential(*modules)
         self.model = ANIModel([atomic_nets[s] for s in self.consts.species])
-        self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
+        if self.aev_caching:
+            self.nnp = self.model
+        else:
+            self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
         self.container = Container({'energies': self.nnp}).to(self.device)
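The reason a one-line branch buys a real speedup: AEVs depend only on the fixed input geometries, never on the weights being trained, so they can be computed once up front and reused every epoch, letting self.nnp skip the AEV stage entirely. A self-contained sketch of that principle, with toy modules standing in for AEVComputer and ANIModel (all names hypothetical):

import torch

featurizer = torch.nn.Linear(3, 64)  # plays the role of AEVComputer
head = torch.nn.Linear(64, 1)        # plays the role of ANIModel
coords = torch.randn(100, 3)

# Precompute features once, outside the training loop: they do not
# depend on the weights being trained, so no graph is needed.
with torch.no_grad():
    cached = featurizer(coords)

optimizer = torch.optim.SGD(head.parameters(), lr=1e-3)
for epoch in range(10):
    energies = head(cached)           # featurizer never re-runs
    loss = energies.pow(2).mean()     # dummy loss for the sketch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()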
@@ -561,7 +567,15 @@ class Trainer:
         return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE'])
 
     def load_data(self, training_path, validation_path):
-        """Load training and validation dataset from file"""
-        self.training_set = BatchedANIDataset(
-            training_path, self.consts.species_to_tensor,
-            self.training_batch_size, device=self.device,
+        """Load the training and validation datasets from file.
+
+        If AEV caching is enabled, the arguments are paths to the cache
+        directories; otherwise, they are paths to the datasets.
+        """
+        if self.aev_caching:
+            self.training_set = AEVCacheLoader(training_path)
+            self.validation_set = AEVCacheLoader(validation_path)
+        else:
+            self.training_set = BatchedANIDataset(
+                training_path, self.consts.species_to_tensor,
+                self.training_batch_size, device=self.device,
...
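Continuing the hypothetical example from above: with aev_caching=True the two arguments name cache directories of precomputed AEVs rather than dataset files (both paths below are placeholders):

# Directories of precomputed AEVs; with aev_caching=False these
# would instead be paths to the datasets themselves.
trainer.load_data('aev_cache/train', 'aev_cache/validation')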