Unverified Commit f73d8ae1 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Use ignite.contrib.handlers.ProgressBar (#151)

parent fa852d5d
...@@ -15,11 +15,10 @@ AEVs. This example shows how to use disk cache to boost training ...@@ -15,11 +15,10 @@ AEVs. This example shows how to use disk cache to boost training
import torch import torch
import ignite import ignite
import torchani import torchani
import tqdm
import timeit import timeit
import tensorboardX import tensorboardX
import os import os
import sys import ignite.contrib.handlers
# training and validation set # training and validation set
...@@ -119,20 +118,10 @@ evaluator = ignite.engine.create_supervised_evaluator(container, metrics={ ...@@ -119,20 +118,10 @@ evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
}) })
@trainer.on(ignite.engine.Events.EPOCH_STARTED) ###############################################################################
def init_tqdm(trainer): # Let's add a progress bar for the trainer
trainer.state.tqdm = tqdm.tqdm(total=len(training), pbar = ignite.contrib.handlers.ProgressBar()
file=sys.stdout, desc='epoch') pbar.attach(trainer)
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
trainer.state.tqdm.update(1)
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
trainer.state.tqdm.close()
def hartree2kcal(x): def hartree2kcal(x):
......
...@@ -13,11 +13,10 @@ This example shows how to use TorchANI train your own neural network potential. ...@@ -13,11 +13,10 @@ This example shows how to use TorchANI train your own neural network potential.
import torch import torch
import ignite import ignite
import torchani import torchani
import tqdm
import timeit import timeit
import tensorboardX import tensorboardX
import os import os
import sys import ignite.contrib.handlers
############################################################################### ###############################################################################
...@@ -153,21 +152,9 @@ evaluator = ignite.engine.create_supervised_evaluator(container, metrics={ ...@@ -153,21 +152,9 @@ evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
############################################################################### ###############################################################################
# Now let's register some event handlers to work with tqdm to display progress: # Let's add a progress bar for the trainer
@trainer.on(ignite.engine.Events.EPOCH_STARTED) pbar = ignite.contrib.handlers.ProgressBar()
def init_tqdm(trainer): pbar.attach(trainer)
trainer.state.tqdm = tqdm.tqdm(total=len(training),
file=sys.stdout, desc='epoch')
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
trainer.state.tqdm.update(1)
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
trainer.state.tqdm.close()
############################################################################### ###############################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment