Unverified Commit 71b9c75f authored by Gao, Xiang, committed by GitHub

Add support for weight decay (#223)

parent 4cec6442
@@ -26,22 +26,22 @@ network_setup {
         nodes=160;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.0001;
+        l2norm=1;
+        l2valu=0.0001;
     ]
     layer [
         nodes=128;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.00001;
+        l2norm=1;
+        l2valu=0.00001;
     ]
     layer [
         nodes=96;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.000001;
+        l2norm=1;
+        l2valu=0.000001;
     ]
     layer [
         nodes=1;
@@ -54,22 +54,22 @@ network_setup {
         nodes=144;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.0001;
+        l2norm=1;
+        l2valu=0.0001;
     ]
     layer [
         nodes=112;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.00001;
+        l2norm=1;
+        l2valu=0.00001;
     ]
     layer [
         nodes=96;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.000001;
+        l2norm=1;
+        l2valu=0.000001;
     ]
     layer [
         nodes=1;
@@ -82,22 +82,22 @@ network_setup {
         nodes=128;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.0001;
+        l2norm=1;
+        l2valu=0.0001;
     ]
     layer [
         nodes=112;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.00001;
+        l2norm=1;
+        l2valu=0.00001;
     ]
     layer [
         nodes=96;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.000001;
+        l2norm=1;
+        l2valu=0.000001;
     ]
     layer [
         nodes=1;
@@ -110,22 +110,22 @@ network_setup {
         nodes=128;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.0001;
+        l2norm=1;
+        l2valu=0.0001;
     ]
     layer [
         nodes=112;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.00001;
+        l2norm=1;
+        l2valu=0.00001;
     ]
     layer [
         nodes=96;
         activation=9;
         type=0;
-        !l2norm=1;
-        !l2valu=0.000001;
+        l2norm=1;
+        l2valu=0.000001;
     ]
     layer [
         nodes=1;
...
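In the config change above, the leading `!` evidently acts as NeuroChem's comment marker, so dropping it activates the per-layer `l2norm`/`l2valu` settings, one decay coefficient per hidden layer. These map directly onto PyTorch's per-parameter-group optimizer options. A minimal sketch of that mapping, with hypothetical layer sizes standing in for the parsed modules:

    import torch

    # Hypothetical stand-ins for the first two hidden layers above.
    layer1 = torch.nn.Linear(384, 160)
    layer2 = torch.nn.Linear(160, 128)

    # One parameter group per layer; each group carries its own l2valu
    # as weight_decay, which the optimizer applies per group.
    parameters = [
        {'params': layer1.parameters(), 'weight_decay': 0.0001},
        {'params': layer2.parameters(), 'weight_decay': 0.00001},
    ]
    optimizer = torch.optim.Adam(parameters, lr=1e-3)

This is the structure the Python changes below accumulate in `self.parameters`.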
@@ -404,6 +404,7 @@ if sys.version_info[0] > 2:
         self.device = device
         self.aev_caching = aev_caching
         self.checkpoint_name = checkpoint_name
+        self.parameters = []
         if tqdm:
             import tqdm
             self.tqdm = tqdm.tqdm
@@ -591,7 +592,6 @@ if sys.version_info[0] > 2:
         input_size, network_setup = network_setup
         if input_size != self.aev_computer.aev_length:
             raise ValueError('AEV size and input size does not match')
-        l2reg = []
         atomic_nets = {}
         for atom_type in network_setup:
             layers = network_setup[atom_type]
@@ -611,18 +611,20 @@ if sys.version_info[0] > 2:
                 del layer['activation']
                 if 'l2norm' in layer:
                     if layer['l2norm'] == 1:
-                        # NB: The "L2" implemented in NeuroChem is actually
-                        # not L2 but weight decay. The difference of these
-                        # two is:
-                        # https://arxiv.org/pdf/1711.05101.pdf
-                        # There is a pull request on github/pytorch
-                        # implementing AdamW, etc.:
-                        # https://github.com/pytorch/pytorch/pull/4429
-                        # There is no plan to support the "L2" settings in
-                        # input file before AdamW get merged into pytorch.
-                        raise NotImplementedError('L2 not supported yet')
+                        self.parameters.append({
+                            'params': module.parameters(),
+                            'weight_decay': layer['l2valu'],
+                        })
+                    else:
+                        self.parameters.append({
+                            'params': module.parameters(),
+                        })
                     del layer['l2norm']
                     del layer['l2valu']
+                else:
+                    self.parameters.append({
+                        'params': module.parameters(),
+                    })
                 if layer:
                     raise ValueError(
                         'unrecognized parameter in layer setup')
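The comment block removed in this hunk noted that NeuroChem's "L2" is really weight decay in the sense of arXiv:1711.05101, with decoupled AdamW then still pending in pytorch/pytorch#4429. The replacement routes each layer's coefficient through the `weight_decay` key of its parameter group. One caveat worth keeping in mind: plain `torch.optim.Adam` applies `weight_decay` as a coupled L2 gradient penalty, while the decoupled variant later shipped as `torch.optim.AdamW`. A minimal sketch of the two entry points (same signature, different update rule):

    import torch

    model = torch.nn.Linear(10, 1)

    # Adam folds weight_decay into the gradient before the adaptive
    # moment estimates (coupled L2 regularization).
    adam = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # AdamW (arXiv:1711.05101) decays the weights directly, outside
    # the adaptive update; available in later PyTorch releases.
    adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)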
@@ -637,13 +639,10 @@ if sys.version_info[0] > 2:
         self.container = Container({'energies': self.nnp}).to(self.device)

         # losses
-        def l2():
-            return sum([c * (m.weight ** 2).sum() for c, m in l2reg])
-        self.mse_loss = TransformedLoss(MSELoss('energies'),
-                                        lambda x: x + l2())
+        self.mse_loss = MSELoss('energies')
         self.exp_loss = TransformedLoss(
             MSELoss('energies'),
-            lambda x: 0.5 * (torch.exp(2 * x) - 1) + l2())
+            lambda x: 0.5 * (torch.exp(2 * x) - 1))

         if params:
             raise ValueError('unrecognized parameter')
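With the closure-based `l2()` term moved into the optimizer, `mse_loss` reduces to a plain `MSELoss` and only `exp_loss` keeps its transform. Judging from the call sites alone, `TransformedLoss` appears to wrap a base loss and post-process its value; a hypothetical minimal equivalent, assuming that reading and substituting torch's functional MSE for the project's `MSELoss('energies')`:

    import torch

    def transformed_loss(base_loss, transform):
        # Wrap a loss function and apply `transform` to its value.
        def loss(output, target):
            return transform(base_loss(output, target))
        return loss

    # Mirrors the exp_loss construction above.
    exp_loss = transformed_loss(
        torch.nn.functional.mse_loss,
        lambda x: 0.5 * (torch.exp(2 * x) - 1))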
@@ -776,7 +775,7 @@ if sys.version_info[0] > 2:

         # training using mse loss first until the validation MAE decrease
         # to < 1 Hartree
-        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
+        optimizer = torch.optim.Adam(self.parameters, lr=lr)
         trainer = ignite.engine.create_supervised_trainer(
             self.container, optimizer, self.mse_loss)
         decorate(trainer)
...