Unverified Commit a9df4b41 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

reduce time on neurochem trainer on CI (#208)

parent 276a886d
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
sflparamsfile=rHCNO-5.2R_16-3.5A_a4-8.params sflparamsfile=rHCNO-5.2R_16-3.5A_a4-8.params
ntwkStoreDir=networks/ ntwkStoreDir=networks/
atomEnergyFile=sae_linfit.dat atomEnergyFile=sae_linfit.dat
nmax=0! Maximum number of iterations (0 = inf) nmax=10! Maximum number of iterations (0 = inf)
tolr=1! Tolerance - early stopping tolr=1! Tolerance - early stopping
emult=0.1!Multiplier by eta after tol switch emult=0.1!Multiplier by eta after tol switch
eta=0.001! Eta -- Learning rate eta=0.001! Eta -- Learning rate
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
sflparamsfile: 'rHCNO-5.2R_16-3.5A_a4-8.params' sflparamsfile: 'rHCNO-5.2R_16-3.5A_a4-8.params'
ntwkStoreDir: 'networks/' ntwkStoreDir: 'networks/'
atomEnergyFile: 'sae_linfit.dat' atomEnergyFile: 'sae_linfit.dat'
nmax: 0 # Maximum number of iterations (0 = inf) nmax: 10 # Maximum number of iterations (0 = inf)
tolr: 1 # Tolerance - early stopping tolr: 1 # Tolerance - early stopping
emult: 0.1 # Multiplier by eta after tol switch emult: 0.1 # Multiplier by eta after tol switch
eta: 0.001 # Eta -- Learning rate eta: 0.001 # Eta -- Learning rate
......
...@@ -556,7 +556,6 @@ if sys.version_info[0] > 2: ...@@ -556,7 +556,6 @@ if sys.version_info[0] > 2:
assert_param('runtype', 'ANNP_CREATE_HDNN_AND_TRAIN') assert_param('runtype', 'ANNP_CREATE_HDNN_AND_TRAIN')
assert_param('adptlrn', 'OFF') assert_param('adptlrn', 'OFF')
assert_param('tmax', 0) assert_param('tmax', 0)
assert_param('nmax', 0)
assert_param('ntwshr', 0) assert_param('ntwshr', 0)
# load parameters # load parameters
...@@ -585,6 +584,8 @@ if sys.version_info[0] > 2: ...@@ -585,6 +584,8 @@ if sys.version_info[0] > 2:
del params['tbtchsz'] del params['tbtchsz']
self.validation_batch_size = params['vbtchsz'] self.validation_batch_size = params['vbtchsz']
del params['vbtchsz'] del params['vbtchsz']
self.nmax = params['nmax']
del params['nmax']
# construct networks # construct networks
input_size, network_setup = network_setup input_size, network_setup = network_setup
...@@ -700,6 +701,12 @@ if sys.version_info[0] > 2: ...@@ -700,6 +701,12 @@ if sys.version_info[0] > 2:
self.global_epoch = trainer.state.epoch self.global_epoch = trainer.state.epoch
self.global_iteration = trainer.state.iteration self.global_iteration = trainer.state.iteration
if self.nmax > 0:
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def terminate_when_nmax_reaches(trainer):
if trainer.state.epoch >= self.nmax:
trainer.terminate()
if self.tqdm is not None: if self.tqdm is not None:
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer): def init_tqdm(trainer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment