Unverified Commit 0ed9b3f3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

temporarily disable L2 support for now, planning to re-enable it after AdamW get...

temporarily disable L2 support for now, planning to re-enable it after AdamW gets merged into pytorch (#80)
parent bfa36346
...@@ -17,7 +17,11 @@ setup_attrs = { ...@@ -17,7 +17,11 @@ setup_attrs = {
'h5py', 'h5py',
], ],
'test_suite': 'nose.collector', 'test_suite': 'nose.collector',
'tests_require': ['nose'], 'tests_require': [
'nose',
'tensorboardX',
'tqdm',
],
} }
setup(**setup_attrs) setup(**setup_attrs)
No preview for this file type
...@@ -508,8 +508,16 @@ class Trainer: ...@@ -508,8 +508,16 @@ class Trainer:
del layer['activation'] del layer['activation']
if 'l2norm' in layer: if 'l2norm' in layer:
if layer['l2norm'] == 1: if layer['l2norm'] == 1:
c = layer['l2valu'] # NB: The "L2" implemented in NeuroChem is actually not
l2reg.append((c, module.weight)) # L2 but weight decay. The difference of these two is:
# https://arxiv.org/pdf/1711.05101.pdf
# There is a pull request on github/pytorch
# implementing AdamW, etc.:
# https://github.com/pytorch/pytorch/pull/4429
# There is no plan to support the "L2" settings in
                            # input file before AdamW gets merged into pytorch.
raise NotImplementedError('L2 not supported yet')
l2reg.append((0.5 * layer['l2valu'], module))
del layer['l2norm'] del layer['l2norm']
del layer['l2valu'] del layer['l2valu']
if len(layer) > 0: if len(layer) > 0:
...@@ -522,7 +530,7 @@ class Trainer: ...@@ -522,7 +530,7 @@ class Trainer:
# losses # losses
def l2(): def l2():
return sum([c * w.norm(2) for c, w in l2reg]) return sum([c * (m.weight ** 2).sum() for c, m in l2reg])
self.mse_loss = TransformedLoss(MSELoss('energies'), self.mse_loss = TransformedLoss(MSELoss('energies'),
lambda x: x + l2()) lambda x: x + l2())
self.exp_loss = TransformedLoss( self.exp_loss = TransformedLoss(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment