Commit fa675f89 authored by dyning's avatar dyning
Browse files

update structure of dygraph

parent 7d09cd19
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from .losses import build_loss
__all__ = ['build_model', 'build_loss']
def build_model(config):
    """Construct a ``Model`` from an architecture config.

    The config is deep-copied before it is handed to ``Model``, so the
    object passed in by the caller is never modified by this function
    or by the model constructor.

    Args:
        config: architecture configuration forwarded to ``Model``.

    Returns:
        The newly constructed ``Model`` instance.
    """
    # Imported at call time rather than at module import time.
    from .architectures import Model

    private_config = copy.deepcopy(config)
    return Model(private_config)
......@@ -12,5 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .model import Model
__all__ = ['Model']
\ No newline at end of file
import copy
__all__ = ['build_model']
def build_model(config):
    """Construct a ``BaseModel`` from an architecture config.

    The config is deep-copied before it is handed to ``BaseModel``, so
    the object passed in by the caller is never modified by this
    function or by the model constructor.

    Args:
        config: architecture configuration forwarded to ``BaseModel``.

    Returns:
        The newly constructed ``BaseModel`` instance.
    """
    # Imported at call time rather than at module import time.
    from .base_model import BaseModel

    private_config = copy.deepcopy(config)
    return BaseModel(private_config)
\ No newline at end of file
......@@ -15,38 +15,29 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append('/home/zhoujun20/PaddleOCR')
from paddle import nn
from ppocr.modeling.transform import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
__all__ = ['Model']
__all__ = ['BaseModel']
class Model(nn.Layer):
class BaseModel(nn.Layer):
def __init__(self, config):
"""
Detection module for OCR.
the module for OCR.
args:
config (dict): the super parameters for module.
"""
super(Model, self).__init__()
algorithm = config['algorithm']
self.type = config['type']
self.model_name = '{}_{}'.format(self.type, algorithm)
super(BaseModel, self).__init__()
in_channels = config.get('in_channels', 3)
model_type = config['model_type']
# build transform,
# for rec, transform can be TPS or None
# for det and cls, transform should be None,
# if you build the model differently, you can use transform in det and cls
# if you build the model differently, you can use transform in det and cls
if 'Transform' not in config or config['Transform'] is None:
self.use_transform = False
else:
......@@ -57,9 +48,9 @@ class Model(nn.Layer):
# build backbone, backbone is needed for det, rec and cls
config["Backbone"]['in_channels'] = in_channels
self.backbone = build_backbone(config["Backbone"], self.type)
self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels
# build neck
# for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on.
......@@ -71,6 +62,7 @@ class Model(nn.Layer):
config['Neck']['in_channels'] = in_channels
self.neck = build_neck(config['Neck'])
in_channels = self.neck.out_channels
# build head, head is needed for det, rec and cls
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
......
......@@ -19,7 +19,6 @@ def build_backbone(config, model_type):
if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec':
from .rec_mobilenet_v3 import MobileNetV3
......
......@@ -130,7 +130,6 @@ class MobileNetV3(nn.Layer):
if_act=True,
act='hard_swish',
name='conv_last'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
......@@ -275,4 +274,4 @@ class SEModule(nn.Layer):
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = F.hard_sigmoid(outputs)
return inputs * outputs
return inputs * outputs
\ No newline at end of file
......@@ -20,8 +20,8 @@ def build_head(config):
from .det_db_head import DBHead
# rec head
from .rec_ctc_head import CTC
support_dict = ['DBHead', 'CTC']
from .rec_ctc_head import CTCHead
support_dict = ['DBHead', 'CTCHead']
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
......
......@@ -33,10 +33,9 @@ def get_para_bias_attr(l2_decay, k, name):
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
return [weight_attr, bias_attr]
class CTC(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=1e-5, **kwargs):
super(CTC, self).__init__()
class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels, name='ctc_fc')
self.fc = nn.Linear(
......
......@@ -14,11 +14,10 @@
__all__ = ['build_neck']
def build_neck(config):
from .fpn import FPN
from .db_fpn import DBFPN
from .rnn import SequenceEncoder
support_dict = ['FPN', 'SequenceEncoder']
support_dict = ['DBFPN', 'SequenceEncoder']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
......
......@@ -22,9 +22,9 @@ import paddle.nn.functional as F
from paddle import ParamAttr
class FPN(nn.Layer):
class DBFPN(nn.Layer):
def __init__(self, in_channels, out_channels, **kwargs):
super(FPN, self).__init__()
super(DBFPN, self).__init__()
self.out_channels = out_channels
weight_attr = paddle.nn.initializer.MSRA(uniform=False)
......
......@@ -76,8 +76,7 @@ class SequenceEncoder(nn.Layer):
'fc': EncoderWithFC,
'rnn': EncoderWithRNN
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys())
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size)
......
......@@ -50,6 +50,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step3 build optimizer
optim_name = config.pop('name')
# Regularization is invalid. The bug will be fixed in paddle-rc. The param is
# weight_decay.
optim = getattr(optimizer, optim_name)(learning_rate=lr,
regularization=reg,
**config)
......
......@@ -40,8 +40,8 @@ class Momentum(object):
opt = optim.Momentum(
learning_rate=self.learning_rate,
momentum=self.momentum,
parameters=self.weight_decay,
weight_decay=parameters)
parameters=parameters,
weight_decay=self.weight_decay)
return opt
......
......@@ -24,8 +24,8 @@ __all__ = ['build_post_process']
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
support_dict = ['DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode']
config = copy.deepcopy(config)
......
......@@ -46,7 +46,7 @@ def load_dygraph_pretrain(
model,
logger,
path=None,
load_static_weights=False, ):
load_static_weights=False):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
......@@ -110,21 +110,20 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_static_weights = gloabl_config.get('load_static_weights', False)
if pretrained_model:
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights] * len(
pretrained_model)
for idx, pretrained in enumerate(pretrained_model):
load_static = load_static_weights[idx]
load_dygraph_pretrain(
model,
logger,
path=pretrained,
load_static_weights=load_static)
logger.info("load pretrained model from {}".format(
pretrained_model))
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights] * len(
pretrained_model)
for idx, pretrained in enumerate(pretrained_model):
load_static = load_static_weights[idx]
load_dygraph_pretrain(
model,
logger,
path=pretrained,
load_static_weights=load_static)
logger.info("load pretrained model from {}".format(
pretrained_model))
else:
logger.info('train from scratch')
return best_model_dict
......
......@@ -28,7 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np
class ArgsParser(ArgumentParser):
def __init__(self):
......@@ -136,18 +139,18 @@ def check_gpu(use_gpu):
def train(config,
train_dataloader,
valid_dataloader,
device,
model,
loss_class,
optimizer,
lr_scheduler,
train_dataloader,
valid_dataloader,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
vdl_writer=None):
global_step = 0
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
......@@ -156,6 +159,7 @@ def train(config,
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
global_step = 0
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
......@@ -179,14 +183,15 @@ def train(config,
start_epoch = 0
for epoch in range(start_epoch, epoch_num):
if epoch > 0:
train_loader = build_dataloader(config, 'Train', device)
for idx, batch in enumerate(train_dataloader):
if idx >= len(train_dataloader):
break
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
lr = optimizer.get_lr()
t1 = time.time()
batch = [paddle.to_variable(x) for x in batch]
batch = [paddle.to_tensor(x) for x in batch]
images = batch[0]
preds = model(images)
loss = loss_class(preds, batch)
......@@ -199,6 +204,8 @@ def train(config,
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
# logger and visualdl
stats = {k: v.numpy().mean() for k, v in loss.items()}
......@@ -228,8 +235,8 @@ def train(config,
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class)
cur_metirc = eval(model, valid_dataloader,
post_process_class, eval_class, logger, print_batch_step)
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
logger.info(cur_metirc_str)
......@@ -291,12 +298,14 @@ def train(config,
return
def eval(model, valid_dataloader, post_process_class, eval_class):
def eval(model, valid_dataloader,
post_process_class, eval_class,
logger, print_batch_step):
model.eval()
with paddle.no_grad():
total_frame = 0.0
total_time = 0.0
pbar = tqdm(total=len(valid_dataloader), desc='eval model: ')
# pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
for idx, batch in enumerate(valid_dataloader):
if idx >= len(valid_dataloader):
break
......@@ -310,11 +319,14 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
total_time += time.time() - start
# Evaluate the results of the current batch
eval_class(post_result, batch)
pbar.update(1)
# pbar.update(1)
total_frame += len(images)
if idx % print_batch_step == 0:
logger.info('tackling images for eval: {}/{}'.format(
idx, len(valid_dataloader)))
# Get final metric, e.g. acc or hmean
metirc = eval_class.get_metric()
pbar.close()
# pbar.close()
model.train()
metirc['fps'] = total_frame / total_time
return metirc
......@@ -336,4 +348,25 @@ def preprocess():
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
return device, config
config['Global']['distributed'] = dist.get_world_size() != 1
paddle.disable_static(device)
# save_config
save_model_dir = config['Global']['save_model_dir']
os.makedirs(save_model_dir, exist_ok=True)
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
if config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
os.makedirs(vdl_writer_path, exist_ok=True)
vdl_writer = LogWriter(logdir=vdl_writer_path)
else:
vdl_writer = None
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, device, logger, vdl_writer
......@@ -31,7 +31,8 @@ paddle.manual_seed(2)
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
from ppocr.modeling import build_model, build_loss
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
......@@ -48,95 +49,76 @@ def main(config, device, logger, vdl_writer):
dist.init_parallel_env()
global_config = config['Global']
# build dataloader
train_loader, train_info_dict = build_dataloader(
config['TRAIN'], device, global_config['distributed'], global_config)
if config['EVAL']:
eval_loader, _ = build_dataloader(config['EVAL'], device, False,
global_config)
train_dataloader = build_dataloader(config, 'Train', device)
if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device)
else:
eval_loader = None
valid_dataloader = None
# build post process
post_process_class = build_post_process(config['PostProcess'],
global_config)
post_process_class = build_post_process(
config['PostProcess'], global_config)
# build model
# for rec algorithm
#for rec algorithm
if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len(
getattr(post_process_class, 'character'))
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# build loss
loss_class = build_loss(config['Loss'])
# build optim
optimizer, lr_scheduler = build_optimizer(
config['Optimizer'],
optimizer, lr_scheduler = build_optimizer(config['Optimizer'],
epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_loader),
step_each_epoch=len(train_dataloader),
parameters=model.parameters())
best_model_dict = init_model(config, model, logger, optimizer)
# build loss
loss_class = build_loss(config['Loss'])
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
# start train
program.train(config, model, loss_class, optimizer, lr_scheduler,
train_loader, eval_loader, post_process_class, eval_class,
best_model_dict, logger, vdl_writer)
def test_reader(config, place, logger, global_config):
train_loader, _ = build_dataloader(
config['TRAIN'], place, global_config=global_config)
program.train(config,
train_dataloader,
valid_dataloader,
device,
model,
loss_class,
optimizer,
lr_scheduler,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
vdl_writer)
def test_reader(config, device, logger):
loader = build_dataloader(config, 'Train', device)
# loader = build_dataloader(config, 'Eval', device)
import time
starttime = time.time()
count = 0
try:
for data in train_loader:
for data in loader():
count += 1
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
logger.info("reader: {}, {}, {}".format(
count, len(data[0]), batch_time))
logger.info("reader: {}, {}, {}".format(count, len(data), batch_time))
except Exception as e:
import traceback
traceback.print_exc()
logger.info(e)
logger.info("finish reader: {}, Success!".format(count))
def dis_main():
device, config = program.preprocess()
config['Global']['distributed'] = dist.get_world_size() != 1
paddle.disable_static(device)
# save_config
os.makedirs(config['Global']['save_model_dir'], exist_ok=True)
with open(
os.path.join(config['Global']['save_model_dir'], 'config.yml'),
'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(
log_file='{}/train.log'.format(config['Global']['save_model_dir']))
if config['Global']['use_visualdl']:
from visualdl import LogWriter
vdl_writer = LogWriter(logdir=config['Global']['save_model_dir'])
else:
vdl_writer = None
print_dict(config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger, config['Global'])
if __name__ == '__main__':
# main()
# dist.spawn(dis_main, nprocs=2, selelcted_gpus='6,7')
dis_main()
config, device, logger, vdl_writer = program.preprocess()
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment