Commit c5cfa898 authored by Stefan Doerr's avatar Stefan Doerr Committed by Gao, Xiang
Browse files

Added a yaml construct method for Trainer (#144)

* added a yaml construct method for Trainer

* used safe_load

* reformatted the yaml structure to use the same _construct method

* added test for YAML config loader

* added pyyaml dep
parent cf192be4
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install tqdm ase tensorboardX pip install tqdm ase tensorboardX pyyaml
pip install pytorch-ignite --no-deps pip install pytorch-ignite --no-deps
...@@ -18,5 +18,8 @@ steps: ...@@ -18,5 +18,8 @@ steps:
- script: 'python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.ipt dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5' - script: 'python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.ipt dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5'
displayName: NeuroChem Trainer displayName: NeuroChem Trainer
- script: 'python -m torchani.neurochem.trainer --tqdm tests/test_data/inputtrain.yaml dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5'
displayName: NeuroChem Trainer YAML config
- script: 'python -m torchani.data.cache_aev tmp dataset/ani_gdb_s01.h5 256' - script: 'python -m torchani.data.cache_aev tmp dataset/ani_gdb_s01.h5 256'
displayName: Cache AEV displayName: Cache AEV
# InputFile for Force Prediction Network
sflparamsfile: 'rHCNO-5.2R_16-3.5A_a4-8.params'
ntwkStoreDir: 'networks/'
atomEnergyFile: 'sae_linfit.dat'
nmax: 0 # Maximum number of iterations (0 = inf)
tolr: 1 # Tolerance - early stopping
emult: 0.1 # Multiplier by eta after tol switch
eta: 0.001 # Eta -- Learning rate
tcrit: 1.0E-5 # Eta termination criterion
tmax: 0 # Maximum time (0 = inf)
tbtchsz: 2560
vbtchsz: 2560
gpuid: 0
ntwshr: 0
nkde: 2
energy: 1
force: 0
fmult: 0.0
pbc: 0
cmult: 0.001
runtype: 'ANNP_CREATE_HDNN_AND_TRAIN' # Create and train a HDN network
network_setup:
inputsize: 384
atom_net:
H:
- nodes: 160
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.0001
- nodes: 128
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.00001
- nodes: 96
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.000001
- nodes: 1
activation: 6
type: 0
C:
- nodes: 144
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.0001
- nodes: 112
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.00001
- nodes: 96
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.000001
- nodes: 1
activation: 6
type: 0
N:
- nodes: 128
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.0001
- nodes: 112
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.00001
- nodes: 96
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.000001
- nodes: 1
activation: 6
type: 0
O:
- nodes: 128
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.0001
- nodes: 112
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.00001
- nodes: 96
activation: 9
type: 0
# l2norm: 1
# l2valu: 0.000001
- nodes: 1
activation: 6
type: 0
adptlrn: 'OFF' # Adaptive learning (OFF,RMSPROP)
decrate: 0.9 # Decay rate of RMSPROP
moment: 'ADAM' # Turn on momentum or nesterov momentum (OFF,CNSTTEMP,TMANNEAL,REGULAR,NESTEROV)
mu: 0.99 # Mu factor for momentum
...@@ -337,8 +337,12 @@ class Trainer: ...@@ -337,8 +337,12 @@ class Trainer:
self.training_eval_every = 20 self.training_eval_every = 20
else: else:
self.tensorboard = None self.tensorboard = None
with open(filename, 'r') as f: with open(filename, 'r') as f:
network_setup, params = self._parse(f.read()) if filename.endswith('.yaml') or filename.endswith('.yml'):
network_setup, params = self._parse_yaml(f)
else:
network_setup, params = self._parse(f.read())
self._construct(network_setup, params) self._construct(network_setup, params)
def _parse(self, txt): def _parse(self, txt):
...@@ -438,6 +442,14 @@ class Trainer: ...@@ -438,6 +442,14 @@ class Trainer:
return TreeExec().transform(tree) return TreeExec().transform(tree)
def _parse_yaml(self, f):
import yaml
params = yaml.safe_load(f)
network_setup = params['network_setup']
del params['network_setup']
network_setup = (network_setup['inputsize'], network_setup['atom_net'])
return network_setup, params
def _construct(self, network_setup, params): def _construct(self, network_setup, params):
dir_ = os.path.dirname(os.path.abspath(self.filename)) dir_ = os.path.dirname(os.path.abspath(self.filename))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment