"doc/vscode:/vscode.git/clone" did not exist on "056b7606f6da22b84d53899553e06d78b3b3cc03"
Unverified commit 96c91907, authored by dyning, committed by GitHub

Merge pull request #1105 from dyning/dygraph

update structure of dygraph
parents 7d09cd19 1ae37919
@@ -24,8 +24,8 @@ __all__ = ['build_post_process']

 def build_post_process(config, global_config=None):
     from .db_postprocess import DBPostProcess
     from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
     support_dict = ['DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode']

     config = copy.deepcopy(config)
...
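For orientation, a minimal sketch of how this factory is typically driven; the 'name' key is an assumption, since the hunk only shows the supported class list and the config copy:

    # minimal sketch, not part of the diff: the 'name' key is an assumption
    post_config = {'name': 'CTCLabelDecode'}
    post_process = build_post_process(post_config, global_config=None)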
@@ -52,7 +52,6 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.INFO):
     stream_handler = logging.StreamHandler(stream=sys.stdout)
     stream_handler.setFormatter(formatter)
     logger.addHandler(stream_handler)

     if log_file is not None and dist.get_rank() == 0:
         log_file_folder = os.path.split(log_file)[0]
         os.makedirs(log_file_folder, exist_ok=True)
...
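The signature in the hunk header makes typical use straightforward; a sketch, noting that only rank 0 attaches the file handler when running distributed:

    import logging
    from ppocr.utils.logging import get_logger

    # sketch: console logging everywhere, file logging on rank 0 only
    logger = get_logger(name='ppocr',
                        log_file='./output/train.log',
                        log_level=logging.INFO)
    logger.info('logger ready')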
@@ -42,16 +42,12 @@ def _mkdir_if_not_exist(path, logger):
         raise OSError('Failed to mkdir {}'.format(path))

-def load_dygraph_pretrain(
-        model,
-        logger,
-        path=None,
-        load_static_weights=False, ):
+def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
     if load_static_weights:
-        pre_state_dict = paddle.io.load_program_state(path)
+        pre_state_dict = paddle.static.load_program_state(path)
         param_state_dict = {}
         model_dict = model.state_dict()
         for key in model_dict.keys():
@@ -110,19 +106,14 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
         load_static_weights = gloabl_config.get('load_static_weights', False)
-        if pretrained_model:
-            if not isinstance(pretrained_model, list):
-                pretrained_model = [pretrained_model]
-            if not isinstance(load_static_weights, list):
-                load_static_weights = [load_static_weights] * len(
-                    pretrained_model)
-            for idx, pretrained in enumerate(pretrained_model):
-                load_static = load_static_weights[idx]
-                load_dygraph_pretrain(
-                    model,
-                    logger,
-                    path=pretrained,
-                    load_static_weights=load_static)
+        if not isinstance(pretrained_model, list):
+            pretrained_model = [pretrained_model]
+        if not isinstance(load_static_weights, list):
+            load_static_weights = [load_static_weights] * len(pretrained_model)
+        for idx, pretrained in enumerate(pretrained_model):
+            load_static = load_static_weights[idx]
+            load_dygraph_pretrain(
+                model, logger, path=pretrained, load_static_weights=load_static)
         logger.info("load pretrained model from {}".format(
             pretrained_model))
     else:
...
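The renamed call matches Paddle 2.0rc's module layout (load_program_state moved from paddle.io to paddle.static). A sketch of the two loading paths, with a hypothetical checkpoint prefix:

    import paddle

    path = './pretrain/best_accuracy'  # hypothetical checkpoint prefix
    load_static_weights = False

    if load_static_weights:
        # static-graph checkpoint: per-parameter files under a directory
        pre_state_dict = paddle.static.load_program_state(path)
    else:
        # dygraph checkpoint: a single <path>.pdparams file
        pre_state_dict = paddle.load(path + '.pdparams')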
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import paddle
from paddle.jit import to_static

from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.logging import get_logger
from tools.program import load_config


def parse_args():
    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", help="configuration file to use")
    parser.add_argument(
        "-o", "--output_path", type=str, default='./output/infer/')
    return parser.parse_args()


class Model(paddle.nn.Layer):
    """Wraps a trained dygraph model so it can be traced to a static graph."""

    def __init__(self, model):
        super(Model, self).__init__()
        self.pre_model = model

    # Please modify the 'shape' according to actual needs
    @to_static(input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 32, None], dtype='float32')
    ])
    def forward(self, inputs):
        x = self.pre_model(inputs)
        return x


def main():
    FLAGS = parse_args()
    config = load_config(FLAGS.config)
    logger = get_logger()

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            config['Global'])

    # build model; for rec algorithms the head width comes from the charset
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])
    init_model(config, model, logger)
    model.eval()

    # trace the wrapped model to a static graph and save the inference model
    model = Model(model)
    paddle.jit.save(model, FLAGS.output_path)


if __name__ == "__main__":
    main()
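Once exported, the saved program can be reloaded for inference without any ppocr model-building code; a minimal sketch, assuming the script's default --output_path is used as the file prefix:

    import paddle

    # reload the static-graph model written by paddle.jit.save above
    model = paddle.jit.load('./output/infer')
    model.eval()

    # dummy recognition-style input: N x 3 x 32 x W (W=100 is an assumption)
    x = paddle.randn([1, 3, 32, 100], dtype='float32')
    preds = model(x)
    print(preds.shape)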
@@ -28,6 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
 from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
+from ppocr.utils.utility import print_dict
+from ppocr.utils.logging import get_logger
+from ppocr.data import build_dataloader
+import numpy as np


 class ArgsParser(ArgumentParser):
@@ -136,18 +140,18 @@ def check_gpu(use_gpu):

 def train(config,
+          train_dataloader,
+          valid_dataloader,
+          device,
           model,
           loss_class,
           optimizer,
           lr_scheduler,
-          train_dataloader,
-          valid_dataloader,
           post_process_class,
           eval_class,
           pre_best_model_dict,
           logger,
           vdl_writer=None):
-    global_step = 0
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
...
@@ -156,6 +160,7 @@ def train(config,
     print_batch_step = config['Global']['print_batch_step']
     eval_batch_step = config['Global']['eval_batch_step']
+    global_step = 0
     start_eval_step = 0
     if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
         start_eval_step = eval_batch_step[0]
@@ -179,26 +184,24 @@ def train(config,
     start_epoch = 0
     for epoch in range(start_epoch, epoch_num):
+        if epoch > 0:
+            train_loader = build_dataloader(config, 'Train', device)
         for idx, batch in enumerate(train_dataloader):
             if idx >= len(train_dataloader):
                 break
-            if not isinstance(lr_scheduler, float):
-                lr_scheduler.step()
             lr = optimizer.get_lr()
             t1 = time.time()
-            batch = [paddle.to_variable(x) for x in batch]
+            batch = [paddle.to_tensor(x) for x in batch]
             images = batch[0]
             preds = model(images)
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
-            if config['Global']['distributed']:
-                avg_loss = model.scale_loss(avg_loss)
-                avg_loss.backward()
-                model.apply_collective_grads()
-            else:
-                avg_loss.backward()
+            avg_loss.backward()
             optimizer.step()
             optimizer.clear_grad()
+            if not isinstance(lr_scheduler, float):
+                lr_scheduler.step()

             # logger and visualdl
             stats = {k: v.numpy().mean() for k, v in loss.items()}
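Distilled, the new loop body follows the Paddle 2.0 step order: plain backward() (paddle.DataParallel now aggregates gradients without scale_loss/apply_collective_grads), then optimizer, then scheduler. A sketch of one iteration with the diff's names:

    # sketch of a single step; model, loss_class, optimizer, lr_scheduler
    # and batch are assumed from the surrounding train() function
    batch = [paddle.to_tensor(x) for x in batch]  # to_variable is retired
    preds = model(batch[0])
    avg_loss = loss_class(preds, batch)['loss']
    avg_loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    if not isinstance(lr_scheduler, float):  # a float means a constant LR
        lr_scheduler.step()                  # scheduler steps after optimizer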
@@ -220,7 +223,8 @@ def train(config,
                     vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                 vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

-            if global_step > 0 and global_step % print_batch_step == 0:
+            if dist.get_rank(
+            ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
                 logs = train_stats.log()
                 strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
                     epoch, epoch_num, global_step, logs, train_batch_elapse)
@@ -229,7 +233,7 @@ def train(config,
             if global_step > start_eval_step and \
                     (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
-                cur_metirc = eval(model, valid_dataloader, post_process_class,
-                                  eval_class)
+                cur_metirc = eval(model, valid_dataloader, post_process_class,
+                                  eval_class, logger, print_batch_step)
                 cur_metirc_str = 'cur metirc, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
                 logger.info(cur_metirc_str)
@@ -291,16 +295,17 @@ def train(config,
     return


-def eval(model, valid_dataloader, post_process_class, eval_class):
+def eval(model, valid_dataloader, post_process_class, eval_class, logger,
+         print_batch_step):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
         total_time = 0.0
-        pbar = tqdm(total=len(valid_dataloader), desc='eval model: ')
+        # pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
         for idx, batch in enumerate(valid_dataloader):
             if idx >= len(valid_dataloader):
                 break
-            images = paddle.to_variable(batch[0])
+            images = paddle.to_tensor(batch[0])
             start = time.time()
             preds = model(images)
...
@@ -310,11 +315,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
             total_time += time.time() - start
             # Evaluate the results of the current batch
             eval_class(post_result, batch)
-            pbar.update(1)
+            # pbar.update(1)
             total_frame += len(images)
+            if idx % print_batch_step == 0 and dist.get_rank() == 0:
+                logger.info('tackling images for eval: {}/{}'.format(
+                    idx, len(valid_dataloader)))
         # Get final metirc, eg. acc or hmean
         metirc = eval_class.get_metric()
-        pbar.close()
+        # pbar.close()
     model.train()
     metirc['fps'] = total_frame / total_time
     return metirc
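Call sites must change in step with the signature: the tqdm bar gives way to periodic rank-0 log lines, and callers now supply the logger and print interval. A sketch of an updated call:

    # sketch: invoking the extended eval(); the interval value is an example,
    # and model, valid_dataloader, post_process_class, eval_class, logger
    # are assumed from the surrounding training code
    metirc = eval(model, valid_dataloader, post_process_class, eval_class,
                  logger, print_batch_step=10)
    logger.info('eval done, fps: {}'.format(metirc['fps']))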
@@ -336,4 +345,24 @@ def preprocess():
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)
-    return device, config
+
+    config['Global']['distributed'] = dist.get_world_size() != 1
+
+    # save_config
+    save_model_dir = config['Global']['save_model_dir']
+    os.makedirs(save_model_dir, exist_ok=True)
+    with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+        yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
+
+    logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
+    if config['Global']['use_visualdl']:
+        from visualdl import LogWriter
+        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
+        os.makedirs(vdl_writer_path, exist_ok=True)
+        vdl_writer = LogWriter(logdir=vdl_writer_path)
+    else:
+        vdl_writer = None
+    print_dict(config, logger)
+    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
+                                                            device))
+    return config, device, logger, vdl_writer
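The distributed flag is now derived once here from the world size; a sketch of how that flag behaves, with the launch command and script path given as assumptions (the diff itself only inspects the world size):

    import paddle.distributed as dist

    # world size exceeds 1 only under a multi-trainer launch, e.g.
    #   python -m paddle.distributed.launch --gpus '0,1' tools/train.py -c <config>
    distributed = dist.get_world_size() != 1
    if distributed:
        dist.init_parallel_env()  # as main() does below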
@@ -27,11 +27,11 @@ import yaml
 import paddle
 import paddle.distributed as dist

-paddle.manual_seed(2)
+paddle.seed(2)

+from ppocr.utils.logging import get_logger
 from ppocr.data import build_dataloader
-from ppocr.modeling import build_model, build_loss
+from ppocr.modeling.architectures import build_model
+from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
@@ -48,95 +48,69 @@ def main(config, device, logger, vdl_writer):
     dist.init_parallel_env()

     global_config = config['Global']

     # build dataloader
-    train_loader, train_info_dict = build_dataloader(
-        config['TRAIN'], device, global_config['distributed'], global_config)
-    if config['EVAL']:
-        eval_loader, _ = build_dataloader(config['EVAL'], device, False,
-                                          global_config)
+    train_dataloader = build_dataloader(config, 'Train', device, logger)
+    if config['Eval']:
+        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
     else:
-        eval_loader = None
+        valid_dataloader = None

     # build post process
     post_process_class = build_post_process(config['PostProcess'],
                                             global_config)

     # build model
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
-        config['Architecture']["Head"]['out_channels'] = len(
-            getattr(post_process_class, 'character'))
+        char_num = len(getattr(post_process_class, 'character'))
+        config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)

-    # build loss
-    loss_class = build_loss(config['Loss'])
-
     # build optim
     optimizer, lr_scheduler = build_optimizer(
         config['Optimizer'],
         epochs=config['Global']['epoch_num'],
-        step_each_epoch=len(train_loader),
+        step_each_epoch=len(train_dataloader),
         parameters=model.parameters())

-    best_model_dict = init_model(config, model, logger, optimizer)
+    # build loss
+    loss_class = build_loss(config['Loss'])

     # build metric
     eval_class = build_metric(config['Metric'])

+    # load pretrain model
+    pre_best_model_dict = init_model(config, model, logger, optimizer)
+
     # start train
-    program.train(config, model, loss_class, optimizer, lr_scheduler,
-                  train_loader, eval_loader, post_process_class, eval_class,
-                  best_model_dict, logger, vdl_writer)
+    program.train(config, train_dataloader, valid_dataloader, device, model,
+                  loss_class, optimizer, lr_scheduler, post_process_class,
+                  eval_class, pre_best_model_dict, logger, vdl_writer)
-def test_reader(config, place, logger, global_config):
-    train_loader, _ = build_dataloader(
-        config['TRAIN'], place, global_config=global_config)
+def test_reader(config, device, logger):
+    loader = build_dataloader(config, 'Train', device)
+    # loader = build_dataloader(config, 'Eval', device)
     import time
     starttime = time.time()
     count = 0
     try:
-        for data in train_loader:
+        for data in loader():
             count += 1
             if count % 1 == 0:
                 batch_time = time.time() - starttime
                 starttime = time.time()
-                logger.info("reader: {}, {}, {}".format(
-                    count, len(data[0]), batch_time))
+                logger.info("reader: {}, {}, {}".format(count,
+                                                        len(data), batch_time))
     except Exception as e:
+        import traceback
+        traceback.print_exc()
         logger.info(e)
     logger.info("finish reader: {}, Success!".format(count))
-def dis_main():
-    device, config = program.preprocess()
-    config['Global']['distributed'] = dist.get_world_size() != 1
-    paddle.disable_static(device)
-
-    # save_config
-    os.makedirs(config['Global']['save_model_dir'], exist_ok=True)
-    with open(
-            os.path.join(config['Global']['save_model_dir'], 'config.yml'),
-            'w') as f:
-        yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
-
-    logger = get_logger(
-        log_file='{}/train.log'.format(config['Global']['save_model_dir']))
-    if config['Global']['use_visualdl']:
-        from visualdl import LogWriter
-        vdl_writer = LogWriter(logdir=config['Global']['save_model_dir'])
-    else:
-        vdl_writer = None
-    print_dict(config, logger)
-    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
-                                                            device))
-    main(config, device, logger, vdl_writer)
-    # test_reader(config, device, logger, config['Global'])
 if __name__ == '__main__':
-    # main()
-    # dist.spawn(dis_main, nprocs=2, selelcted_gpus='6,7')
-    dis_main()
+    config, device, logger, vdl_writer = program.preprocess()
+    main(config, device, logger, vdl_writer)
+    # test_reader(config, device, logger)