Unverified Commit 83107add authored by Gao, Xiang, committed by GitHub

clean up flatten in tutorials, make neurochem trainer support aev caching (#92)

parent e69e59a4
@@ -92,15 +92,10 @@ else:
     torch.save(nn.state_dict(), model_checkpoint)

-class Flatten(torch.nn.Module):
-    def forward(self, x):
-        return x[0], x[1].flatten()
-
 ###############################################################################
 # Except that at here we do not include aev computer into our pipeline, because
 # the cache loader will load computed AEVs from disk.
-model = torch.nn.Sequential(nn, Flatten()).to(device)
+model = nn.to(device)

 ###############################################################################
 # This part is also a line by line copy
...
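The point of this hunk is that, in the cached-AEV tutorial, the network is applied directly to AEVs loaded from disk instead of being chained behind an AEV computer. Below is a minimal, self-contained sketch of that idea; the AEV width and the tiny per-atom network are illustrative placeholders, not the tutorial's actual model or data pipeline.

    import torch

    # Illustrative stand-in: a per-atom network mapping AEV features to energies.
    aev_dim = 384
    per_atom_net = torch.nn.Sequential(
        torch.nn.Linear(aev_dim, 64), torch.nn.CELU(), torch.nn.Linear(64, 1))

    # Dummy batch standing in for what the cache loader would read from disk:
    # 8 molecules, 20 atoms each, with precomputed AEVs (real batches also
    # carry species indices).
    aevs = torch.randn(8, 20, aev_dim)

    # No AEV computer in the pipeline: the network consumes the cached AEVs
    # directly, and per-atom energies are summed into one energy per molecule.
    atomic_energies = per_atom_net(aevs).squeeze(-1)
    energies = atomic_energies.sum(dim=1)  # shape (8,)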
@@ -95,19 +95,7 @@ if os.path.isfile(model_checkpoint):
 else:
     torch.save(nn.state_dict(), model_checkpoint)

-###############################################################################
-# The output energy tensor has shape ``(N, 1)`` where ``N`` is the number of
-# different structures in each minibatch. However, in the dataset, the label
-# has shape ``(N,)``. To make it possible to subtract these two tensors, we
-# need to flatten the output tensor.
-class Flatten(torch.nn.Module):
-    def forward(self, x):
-        return x[0], x[1].flatten()
-
-model = torch.nn.Sequential(aev_computer, nn, Flatten()).to(device)
+model = torch.nn.Sequential(aev_computer, nn).to(device)

 ###############################################################################
 # Now setup tensorboardX.
...
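The comment deleted above explains why the Flatten wrapper existed in the first place: subtracting an (N,) label tensor from an (N, 1) prediction tensor broadcasts into an (N, N) matrix instead of producing elementwise differences. A quick PyTorch check of that pitfall, independent of torchani:

    import torch

    pred = torch.randn(4, 1)   # model output shaped (N, 1)
    label = torch.randn(4)     # dataset labels shaped (N,)
    print((pred - label).shape)            # torch.Size([4, 4]) -- broadcast, silently wrong
    print((pred.flatten() - label).shape)  # torch.Size([4])    -- elementwise, as intended

After this cleanup the pipeline's energy output presumably already has shape (N,), so the tutorials no longer need the ad-hoc Flatten module.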
@@ -304,6 +304,7 @@ def hartree2kcal(x):
 from ..data import BatchedANIDataset  # noqa: E402
+from ..data import AEVCacheLoader  # noqa: E402


 class Trainer:
@@ -315,12 +316,14 @@ class Trainer:
         tqdm (bool): whether to enable tqdm
         tensorboard (str): Directory to store tensorboard log file, set to\
             ``None`` to disable tensorboardX.
+        aev_caching (bool): Whether to use AEV caching.
     """

     def __init__(self, filename, device=torch.device('cuda'),
-                 tqdm=False, tensorboard=None):
+                 tqdm=False, tensorboard=None, aev_caching=False):
         self.filename = filename
         self.device = device
+        self.aev_caching = aev_caching
         if tqdm:
             import tqdm
             self.tqdm = tqdm.tqdm
@@ -528,7 +531,10 @@ class Trainer:
             i = o
             atomic_nets[atom_type] = torch.nn.Sequential(*modules)
         self.model = ANIModel([atomic_nets[s] for s in self.consts.species])
-        self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
+        if self.aev_caching:
+            self.nnp = self.model
+        else:
+            self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
         self.container = Container({'energies': self.nnp}).to(self.device)
         # losses
@@ -561,15 +567,23 @@ class Trainer:
         return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE'])

     def load_data(self, training_path, validation_path):
-        """Load training and validation dataset from file"""
-        self.training_set = BatchedANIDataset(
-            training_path, self.consts.species_to_tensor,
-            self.training_batch_size, device=self.device,
-            transform=[self.shift_energy.subtract_from_dataset])
-        self.validation_set = BatchedANIDataset(
-            validation_path, self.consts.species_to_tensor,
-            self.validation_batch_size, device=self.device,
-            transform=[self.shift_energy.subtract_from_dataset])
+        """Load training and validation dataset from file.
+
+        If AEV caching is enabled, then the arguments are path to the cache
+        directory, otherwise it should be path to the dataset.
+        """
+        if self.aev_caching:
+            self.training_set = AEVCacheLoader(training_path)
+            self.validation_set = AEVCacheLoader(validation_path)
+        else:
+            self.training_set = BatchedANIDataset(
+                training_path, self.consts.species_to_tensor,
+                self.training_batch_size, device=self.device,
+                transform=[self.shift_energy.subtract_from_dataset])
+            self.validation_set = BatchedANIDataset(
+                validation_path, self.consts.species_to_tensor,
+                self.validation_batch_size, device=self.device,
+                transform=[self.shift_energy.subtract_from_dataset])

     def run(self):
         """Run the training"""
...
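Putting the Trainer changes together, here is a hedged usage sketch based only on the signatures visible in this diff, assuming the class shown here is importable as torchani.neurochem.Trainer; the training config file name and cache directory paths are hypothetical placeholders.

    import torch
    import torchani

    # Hypothetical config file and paths; constructor arguments mirror the
    # updated __init__ signature shown in the diff above.
    trainer = torchani.neurochem.Trainer('inputtrain.ipt',
                                         device=torch.device('cuda'),
                                         tensorboard=None,
                                         aev_caching=True)

    # With aev_caching=True, load_data expects cache directories (read by
    # AEVCacheLoader) rather than raw dataset files, per the new docstring.
    trainer.load_data('aev_cache/training', 'aev_cache/validation')
    trainer.run()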