Unverified Commit 4452f68d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add warning for weight decay (#227)

parent 245614f7
......@@ -182,6 +182,13 @@ model = torch.nn.Sequential(aev_computer, nn).to(device)
# Also note that the weight decay only applies to weight in the training
# of ANI models, not bias.
#
# .. warning::
#
# Currently TorchANI training with weight decay can not reproduce the training
# result of NeuroChem with the same training setup. If you really want to use
# weight decay, consider smaller rates and and make sure you do enough validation
# to check if you get expected result.
#
# .. _Decoupled Weight Decay Regularization:
# https://arxiv.org/abs/1711.05101
optimizer = torchani.optim.AdamW([
......
......@@ -17,6 +17,8 @@ from ..nn import ANIModel, Ensemble, Gaussian
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..optim import AdamW
import warnings
import textwrap
class Constants(collections.abc.Mapping):
......@@ -422,6 +424,8 @@ if sys.version_info[0] > 2:
self.imports.BatchedANIDataset = BatchedANIDataset
self.imports.AEVCacheLoader = AEVCacheLoader
self.warned = False
self.filename = filename
self.device = device
self.aev_caching = aev_caching
......@@ -632,6 +636,13 @@ if sys.version_info[0] > 2:
modules.append(activation)
del layer['activation']
if 'l2norm' in layer:
if not self.warned:
warnings.warn(textwrap.dedent("""
Currently TorchANI training with weight decay can not reproduce the training
result of NeuroChem with the same training setup. If you really want to use
weight decay, consider smaller rates and and make sure you do enough validation
to check if you get expected result."""))
self.warned = True
if layer['l2norm'] == 1:
self.parameters.append({
'params': [module.weight],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment