Unverified Commit 45f25ec3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

assert over-fitting in test_ignite (#33)

parent b3744935
...@@ -18,7 +18,7 @@ if sys.version_info.major >= 3: ...@@ -18,7 +18,7 @@ if sys.version_info.major >= 3:
class TestBatch(unittest.TestCase): class TestBatch(unittest.TestCase):
def testBatchLoadAndInference(self): def testBatchLoadAndInference(self):
ds = torchani.data.ANIDataset(path, chunksize) ds = torchani.data.ANIDataset(path, chunksize, device=device)
loader = torchani.data.dataloader(ds, batch_chunks) loader = torchani.data.dataloader(ds, batch_chunks)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device) aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer) nnp = torchani.models.NeuroChemNNP(aev_computer)
......
...@@ -5,14 +5,14 @@ if sys.version_info.major >= 3: ...@@ -5,14 +5,14 @@ if sys.version_info.major >= 3:
import unittest import unittest
import torch import torch
from ignite.engine import create_supervised_trainer, \ from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator create_supervised_evaluator, Events
import torchani import torchani
import torchani.data import torchani.data
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset/ani_gdb_s01.h5') path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
chunksize = 32 chunksize = 16
batch_chunks = 32 threshold = 1e-5
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device('cpu')
...@@ -21,9 +21,10 @@ if sys.version_info.major >= 3: ...@@ -21,9 +21,10 @@ if sys.version_info.major >= 3:
def testIgnite(self): def testIgnite(self):
shift_energy = torchani.EnergyShifter() shift_energy = torchani.EnergyShifter()
ds = torchani.data.ANIDataset( ds = torchani.data.ANIDataset(
path, chunksize, path, chunksize, device=device,
transform=[shift_energy.dataset_subtract_sae]) transform=[shift_energy.dataset_subtract_sae])
loader = torchani.data.dataloader(ds, batch_chunks) ds = torch.utils.data.Subset(ds, [0])
loader = torchani.data.dataloader(ds, 1)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device) aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer) nnp = torchani.models.NeuroChemNNP(aev_computer)
...@@ -35,18 +36,25 @@ if sys.version_info.major >= 3: ...@@ -35,18 +36,25 @@ if sys.version_info.major >= 3:
def forward(self, *input): def forward(self, *input):
return self.model(*input).flatten() return self.model(*input).flatten()
nnp = Flatten(nnp) nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp) batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp}) container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.SGD(container.parameters(), optimizer = torch.optim.Adam(container.parameters())
lr=0.001, momentum=0.8)
trainer = create_supervised_trainer( trainer = create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss) container, optimizer, torchani.ignite.energy_mse_loss)
trainer.run(loader, max_epochs=10)
evaluator = create_supervised_evaluator(container, metrics={ evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.energy_rmse_metric 'RMSE': torchani.ignite.energy_rmse_metric
}) })
evaluator.run(loader)
@trainer.on(Events.COMPLETED)
def completes(trainer):
evaluator.run(loader)
metrics = evaluator.state.metrics
self.assertLess(metrics['RMSE'], threshold)
self.assertLess(trainer.state.output, threshold)
trainer.run(loader, max_epochs=1000)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -221,13 +221,13 @@ class NeuroChemAtomicNetwork(torch.nn.Module): ...@@ -221,13 +221,13 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
w = struct.unpack('{}f'.format(wsize), fw.read()) w = struct.unpack('{}f'.format(wsize), fw.read())
w = torch.tensor(w, dtype=self.dtype, device=self.device).view( w = torch.tensor(w, dtype=self.dtype, device=self.device).view(
out_size, in_size) out_size, in_size)
linear.weight = torch.nn.parameter.Parameter(w, requires_grad=True) linear.weight.data = w
fw.close() fw.close()
fb = open(bfn, 'rb') fb = open(bfn, 'rb')
b = struct.unpack('{}f'.format(out_size), fb.read()) b = struct.unpack('{}f'.format(out_size), fb.read())
b = torch.tensor(b, dtype=self.dtype, b = torch.tensor(b, dtype=self.dtype,
device=self.device).view(out_size) device=self.device).view(out_size)
linear.bias = torch.nn.parameter.Parameter(b, requires_grad=True) linear.bias.data = b
fb.close() fb.close()
def get_activations(self, aev, layer): def get_activations(self, aev, layer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment