Unverified Commit 45f25ec3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

assert over-fitting in test_ignite (#33)

parent b3744935
......@@ -18,7 +18,7 @@ if sys.version_info.major >= 3:
class TestBatch(unittest.TestCase):
def testBatchLoadAndInference(self):
ds = torchani.data.ANIDataset(path, chunksize)
ds = torchani.data.ANIDataset(path, chunksize, device=device)
loader = torchani.data.dataloader(ds, batch_chunks)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer)
......
......@@ -5,14 +5,14 @@ if sys.version_info.major >= 3:
import unittest
import torch
from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator
create_supervised_evaluator, Events
import torchani
import torchani.data
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
chunksize = 32
batch_chunks = 32
chunksize = 16
threshold = 1e-5
dtype = torch.float32
device = torch.device('cpu')
......@@ -21,9 +21,10 @@ if sys.version_info.major >= 3:
def testIgnite(self):
shift_energy = torchani.EnergyShifter()
ds = torchani.data.ANIDataset(
path, chunksize,
path, chunksize, device=device,
transform=[shift_energy.dataset_subtract_sae])
loader = torchani.data.dataloader(ds, batch_chunks)
ds = torch.utils.data.Subset(ds, [0])
loader = torchani.data.dataloader(ds, 1)
aev_computer = torchani.SortedAEV(dtype=dtype, device=device)
nnp = torchani.models.NeuroChemNNP(aev_computer)
......@@ -35,18 +36,25 @@ if sys.version_info.major >= 3:
def forward(self, *input):
return self.model(*input).flatten()
nnp = Flatten(nnp)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.SGD(container.parameters(),
lr=0.001, momentum=0.8)
optimizer = torch.optim.Adam(container.parameters())
trainer = create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
trainer.run(loader, max_epochs=10)
evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.energy_rmse_metric
})
evaluator.run(loader)
@trainer.on(Events.COMPLETED)
def completes(trainer):
evaluator.run(loader)
metrics = evaluator.state.metrics
self.assertLess(metrics['RMSE'], threshold)
self.assertLess(trainer.state.output, threshold)
trainer.run(loader, max_epochs=1000)
if __name__ == '__main__':
unittest.main()
......@@ -221,13 +221,13 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
w = struct.unpack('{}f'.format(wsize), fw.read())
w = torch.tensor(w, dtype=self.dtype, device=self.device).view(
out_size, in_size)
linear.weight = torch.nn.parameter.Parameter(w, requires_grad=True)
linear.weight.data = w
fw.close()
fb = open(bfn, 'rb')
b = struct.unpack('{}f'.format(out_size), fb.read())
b = torch.tensor(b, dtype=self.dtype,
device=self.device).view(out_size)
linear.bias = torch.nn.parameter.Parameter(b, requires_grad=True)
linear.bias.data = b
fb.close()
def get_activations(self, aev, layer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment