Commit 1090525e authored by andyjpaddle

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

parents cd21ea87 12aa5e80
ppocr/utils/loggers/__init__.py:

from .vdl_logger import VDLLogger
from .wandb_logger import WandbLogger
from .loggers import Loggers
ppocr/utils/loggers/base_logger.py:

import os
from abc import ABC, abstractmethod


class BaseLogger(ABC):
    def __init__(self, save_dir):
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    @abstractmethod
    def log_metrics(self, metrics, prefix=None):
        pass

    @abstractmethod
    def close(self):
        pass
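A minimal sketch of what a new backend would need in order to plug into this interface; the CSVLogger name and file layout here are hypothetical and not part of this commit:

# Hypothetical example, not part of this commit: a CSV-backed logger built on
# the BaseLogger interface above.
import csv
import os

from ppocr.utils.loggers.base_logger import BaseLogger


class CSVLogger(BaseLogger):
    def __init__(self, save_dir):
        super().__init__(save_dir)  # BaseLogger creates save_dir if needed
        self.path = os.path.join(save_dir, "metrics.csv")

    def log_metrics(self, metrics, prefix=None, step=None):
        prefix = prefix or ""
        with open(self.path, "a", newline="") as f:
            writer = csv.writer(f)
            for k, v in metrics.items():
                writer.writerow([step, prefix + "/" + k, v])

    def log_model(self, is_best, prefix, metadata=None):
        pass  # checkpoint upload is not meaningful for a plain CSV backend

    def close(self):
        pass  # nothing to release; the file is closed after each write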
ppocr/utils/loggers/loggers.py:

from .wandb_logger import WandbLogger


class Loggers(object):
    def __init__(self, loggers):
        super().__init__()
        self.loggers = loggers

    def log_metrics(self, metrics, prefix=None, step=None):
        for logger in self.loggers:
            logger.log_metrics(metrics, prefix=prefix, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        for logger in self.loggers:
            logger.log_model(is_best=is_best, prefix=prefix, metadata=metadata)

    def close(self):
        for logger in self.loggers:
            logger.close()
ppocr/utils/loggers/vdl_logger.py:

from .base_logger import BaseLogger
from visualdl import LogWriter


class VDLLogger(BaseLogger):
    def __init__(self, save_dir):
        super().__init__(save_dir)
        self.vdl_writer = LogWriter(logdir=save_dir)

    def log_metrics(self, metrics, prefix=None, step=None):
        if not prefix:
            prefix = ""
        updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
        for k, v in updated_metrics.items():
            self.vdl_writer.add_scalar(k, v, step)

    def log_model(self, is_best, prefix, metadata=None):
        pass

    def close(self):
        self.vdl_writer.close()
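A short usage sketch, assuming visualdl is installed; the directory, metric names, and values are illustrative only:

# Illustrative usage only; the directory and metric values below are made up.
from ppocr.utils.loggers import VDLLogger

vdl = VDLLogger("./output/rec/vdl")
vdl.log_metrics({"loss": 1.23, "acc": 0.87}, prefix="TRAIN", step=100)
vdl.close()  # flush and close the underlying VisualDL LogWriter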
ppocr/utils/loggers/wandb_logger.py:

import os
from .base_logger import BaseLogger
from ppocr.utils.logging import get_logger

logger = get_logger()


class WandbLogger(BaseLogger):
    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`")

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow")
        self._wandb_init.update(**kwargs)

        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        if self._run is None:
            if self.wandb.run is not None:
                logger.info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()`"
                    " before instantiating `WandbLogger`.")
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        if not prefix:
            prefix = ""
        updated_metrics = {
            prefix.lower() + "/" + k: v
            for k, v in metrics.items()
        }
        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            'model-{}'.format(self.run.id), type='model', metadata=metadata)
        artifact.add_file(model_path, name="model_ckpt.pdparams")
        aliases = [prefix]
        if is_best:
            aliases.append("best")
        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        self.run.finish()
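A minimal sketch of how the pieces compose, assuming wandb is installed and authenticated; the project name, directories, and metric values are illustrative and not part of this commit:

# Illustrative composition only; project, directories, and values are made up.
from ppocr.utils.loggers import Loggers, VDLLogger, WandbLogger

vdl = VDLLogger("./output/rec/vdl")
wb = WandbLogger(project="my-paddleocr-project", save_dir="./output/rec")
log_writer = Loggers([vdl, wb])

log_writer.log_metrics({"loss": 0.42, "acc": 0.91}, prefix="TRAIN", step=1000)
# WandbLogger.log_model expects <save_dir>/<prefix>.pdparams to already exist.
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata={"acc": 0.91})
log_writer.close()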
tools/program.py:

@@ -31,6 +31,7 @@ from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
 from ppocr.utils.utility import print_dict, AverageMeter
 from ppocr.utils.logging import get_logger
+from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
 from ppocr.utils import profiler
 from ppocr.data import build_dataloader

@@ -161,7 +162,7 @@ def train(config,
           eval_class,
           pre_best_model_dict,
           logger,
-          vdl_writer=None,
+          log_writer=None,
           scaler=None):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)

@@ -300,10 +301,8 @@ def train(config,
             stats['lr'] = lr
             train_stats.update(stats)

-            if vdl_writer is not None and dist.get_rank() == 0:
-                for k, v in train_stats.get().items():
-                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
-                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
+            if log_writer is not None and dist.get_rank() == 0:
+                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)

             if dist.get_rank() == 0 and (
                 (global_step > 0 and global_step % print_batch_step == 0) or

@@ -349,11 +348,9 @@ def train(config,
                 logger.info(cur_metric_str)

                 # logger metric
-                if vdl_writer is not None:
-                    for k, v in cur_metric.items():
-                        if isinstance(v, (float, int)):
-                            vdl_writer.add_scalar('EVAL/{}'.format(k),
-                                                  cur_metric[k], global_step)
+                if log_writer is not None:
+                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+
                 if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
                     best_model_dict.update(cur_metric)

@@ -374,10 +371,12 @@ def train(config,
                     ]))
                     logger.info(best_str)
                     # logger best metric
-                    if vdl_writer is not None:
-                        vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
-                                              best_model_dict[main_indicator],
-                                              global_step)
+                    if log_writer is not None:
+                        log_writer.log_metrics(metrics={
+                            "best_{}".format(main_indicator): best_model_dict[main_indicator]
+                        }, prefix="EVAL", step=global_step)
+
+                        log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
             reader_start = time.time()
         if dist.get_rank() == 0:

@@ -392,6 +391,10 @@ def train(config,
                 best_model_dict=best_model_dict,
                 epoch=epoch,
                 global_step=global_step)
+
+            if log_writer is not None:
+                log_writer.log_model(is_best=False, prefix="latest")
+
         if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
             save_model(
                 model,

@@ -404,11 +407,14 @@ def train(config,
                 best_model_dict=best_model_dict,
                 epoch=epoch,
                 global_step=global_step)
+
+            if log_writer is not None:
+                log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
     logger.info(best_str)
-    if dist.get_rank() == 0 and vdl_writer is not None:
-        vdl_writer.close()
+    if dist.get_rank() == 0 and log_writer is not None:
+        log_writer.close()
     return

@@ -565,15 +571,32 @@ def preprocess(is_train=False):
     config['Global']['distributed'] = dist.get_world_size() != 1

-    if config['Global']['use_visualdl'] and dist.get_rank() == 0:
-        from visualdl import LogWriter
+    loggers = []
+
+    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
         save_model_dir = config['Global']['save_model_dir']
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
-        os.makedirs(vdl_writer_path, exist_ok=True)
-        vdl_writer = LogWriter(logdir=vdl_writer_path)
+        log_writer = VDLLogger(save_model_dir)
+        loggers.append(log_writer)
+
+    if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+        save_dir = config['Global']['save_model_dir']
+        wandb_writer_path = "{}/wandb".format(save_dir)
+        if "wandb" in config:
+            wandb_params = config['wandb']
+        else:
+            wandb_params = dict()
+        wandb_params.update({'save_dir': save_model_dir})
+        log_writer = WandbLogger(**wandb_params, config=config)
+        loggers.append(log_writer)
     else:
-        vdl_writer = None
+        log_writer = None
+
     print_dict(config, logger)
+
+    if loggers:
+        log_writer = Loggers(loggers)
+    else:
+        log_writer = None
+
     logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                             device))
-    return config, device, logger, vdl_writer
+    return config, device, logger, log_writer
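For context, whether these loggers are created is driven by the training config that preprocess() reads. An illustrative YAML fragment, assuming only the key names checked above; the project and entity values are placeholders, not part of this commit:

# Illustrative config fragment; key names follow the checks in preprocess(),
# but the project/entity values are placeholders.
Global:
  save_model_dir: ./output/rec/
  use_visualdl: True     # adds a VDLLogger writing under save_model_dir
  use_wandb: True        # adds a WandbLogger

wandb:                   # optional; passed as keyword arguments to WandbLogger
  project: my-paddleocr-project
  entity: my-team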