"examples/vscode:/vscode.git/clone" did not exist on "c399de396dbb464be0935f910703eff9f11667ad"
Unverified Commit 8ad4e307 authored by MissPenguin's avatar MissPenguin Committed by GitHub
Browse files

Merge branch 'dygraph' into dygraph

parents b1d31f8a 2062b509
...@@ -66,6 +66,7 @@ class StdTextDrawer(object): ...@@ -66,6 +66,7 @@ class StdTextDrawer(object):
corpus_list.append(corpus[0:i]) corpus_list.append(corpus[0:i])
text_input_list.append(text_input) text_input_list.append(text_input)
corpus = corpus[i:] corpus = corpus[i:]
i = 0
break break
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font) draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
char_x += char_size char_x += char_size
...@@ -78,7 +79,6 @@ class StdTextDrawer(object): ...@@ -78,7 +79,6 @@ class StdTextDrawer(object):
corpus_list.append(corpus[0:i]) corpus_list.append(corpus[0:i])
text_input_list.append(text_input) text_input_list.append(text_input)
corpus = corpus[i:]
break break
return corpus_list, text_input_list return corpus_list, text_input_list
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2", "Teacher"]
# key: maps
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
# key: maps
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
act: "softmax"
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
...@@ -17,7 +17,7 @@ Global: ...@@ -17,7 +17,7 @@ Global:
character_type: ch character_type: ch
max_text_length: 25 max_text_length: 25
infer_mode: false infer_mode: false
use_space_char: false use_space_char: true
distributed: true distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
...@@ -27,28 +27,29 @@ Optimizer: ...@@ -27,28 +27,29 @@ Optimizer:
beta1: 0.9 beta1: 0.9
beta2: 0.999 beta2: 0.999
lr: lr:
name: Cosine name: Piecewise
learning_rate: 0.0005 decay_epochs : [700, 800]
values : [0.001, 0.0001]
warmup_epoch: 5 warmup_epoch: 5
regularizer: regularizer:
name: L2 name: L2
factor: 1.0e-05 factor: 2.0e-05
Architecture: Architecture:
model_type: &model_type "rec"
name: DistillationModel name: DistillationModel
algorithm: Distillation algorithm: Distillation
Models: Models:
Student: Teacher:
pretrained: pretrained:
freeze_params: false freeze_params: false
return_all_feats: true return_all_feats: true
model_type: rec model_type: *model_type
algorithm: CRNN algorithm: CRNN
Transform: Transform:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV1Enhance
scale: 0.5 scale: 0.5
model_name: small
small_stride: [1, 2, 2, 2]
Neck: Neck:
name: SequenceEncoder name: SequenceEncoder
encoder_type: rnn encoder_type: rnn
...@@ -56,19 +57,17 @@ Architecture: ...@@ -56,19 +57,17 @@ Architecture:
Head: Head:
name: CTCHead name: CTCHead
mid_channels: 96 mid_channels: 96
fc_decay: 0.00001 fc_decay: 0.00002
Teacher: Student:
pretrained: pretrained:
freeze_params: false freeze_params: false
return_all_feats: true return_all_feats: true
model_type: rec model_type: *model_type
algorithm: CRNN algorithm: CRNN
Transform: Transform:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV1Enhance
scale: 0.5 scale: 0.5
model_name: small
small_stride: [1, 2, 2, 2]
Neck: Neck:
name: SequenceEncoder name: SequenceEncoder
encoder_type: rnn encoder_type: rnn
...@@ -76,7 +75,7 @@ Architecture: ...@@ -76,7 +75,7 @@ Architecture:
Head: Head:
name: CTCHead name: CTCHead
mid_channels: 96 mid_channels: 96
fc_decay: 0.00001 fc_decay: 0.00002
Loss: Loss:
......
...@@ -29,7 +29,7 @@ deploy/hubserving/ocr_system/ ...@@ -29,7 +29,7 @@ deploy/hubserving/ocr_system/
### 1. 准备环境 ### 1. 准备环境
```shell ```shell
# 安装paddlehub # 安装paddlehub
pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
``` ```
### 2. 下载推理模型 ### 2. 下载推理模型
......
...@@ -30,7 +30,7 @@ The following steps take the 2-stage series service as an example. If only the d ...@@ -30,7 +30,7 @@ The following steps take the 2-stage series service as an example. If only the d
### 1. Prepare the environment ### 1. Prepare the environment
```shell ```shell
# Install paddlehub # Install paddlehub
pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
``` ```
### 2. Download inference model ### 2. Download inference model
......
...@@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT ...@@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
def export_single_model(quanter, model, infer_shape, save_path, logger):
quanter.save_quantized_model(
model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
logger.info('inference QAT model is saved to {}'.format(save_path))
def main(): def main():
############################################################################################################ ############################################################################################################
# 1. quantization configs # 1. quantization configs
...@@ -76,14 +87,21 @@ def main(): ...@@ -76,14 +87,21 @@ def main():
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
# get QAT model # get QAT model
quanter = QAT(config=quant_config) quanter = QAT(config=quant_config)
quanter.quantize(model) quanter.quantize(model)
init_model(config, model, logger) init_model(config, model)
model.eval() model.eval()
# build metric # build metric
...@@ -92,25 +110,30 @@ def main(): ...@@ -92,25 +110,30 @@ def main():
# build dataloader # build dataloader
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type']
# start eval # start eval
metirc = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class) eval_class, model_type, use_srn)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metirc.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
infer_shape = [3, 32, 100] if config['Architecture'][ infer_shape = [3, 32, 100] if config['Architecture'][
'model_type'] != "det" else [3, 640, 640] 'model_type'] != "det" else [3, 640, 640]
quanter.save_quantized_model( save_path = config["Global"]["save_inference_dir"]
model,
save_path, arch_config = config["Architecture"]
input_spec=[ if arch_config["algorithm"] in ["Distillation", ]: # distillation model
paddle.static.InputSpec( for idx, name in enumerate(model.model_name_list):
shape=[None] + infer_shape, dtype='float32') sub_model_save_path = os.path.join(save_path, name, "inference")
]) export_single_model(quanter, model.model_list[idx], infer_shape,
logger.info('inference QAT model is saved to {}'.format(save_path)) sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(quanter, model, infer_shape, save_path, logger)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer): ...@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
...@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer): ...@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
......
# 知识蒸馏
## 1. 简介
### 1.1 知识蒸馏介绍
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显著改善模型性能,但是这又带来了模型复杂度急剧提升的问题。大模型在实际场景中使用的成本较高。
深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升。
此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。
### 1.2 PaddleOCR知识蒸馏简介
无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上是都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
对于大模型蒸馏小模型的情况,大模型一般需要加载预训练模型并固定参数;对于小模型之间互相蒸馏的情况,小模型一般都不加载预训练模型,参数也都是可学习的状态。
在知识蒸馏任务中,不只有2个模型之间进行蒸馏的情况,多个模型之间互相学习的情况也非常普遍。因此在知识蒸馏代码框架中,也有必要支持该种类别的蒸馏方法。
PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要的特点:
- 支持任意网络的互相学习,不要求子网络结构完全一致或者具有预训练模型;同时子网络数量也没有任何限制,只需要在配置文件中添加即可。
- 支持loss函数通过配置文件任意配置,不仅可以使用某种loss,也可以使用多种loss的组合
- 支持知识蒸馏训练、预测、评估与导出等所有模型相关的环境,方便使用与部署。
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升提升超过5%。
## 2. 配置文件解析
在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。
下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。
### 2.1 识别配置文件解析
配置文件在[rec_chinese_lite_train_distillation_v2.1.yml](../../configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml)
#### 2.1.1 模型结构
知识蒸馏任务中,模型结构配置如下所示。
```yaml
Architecture:
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的的模型类别都与
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
pretrained: # 该子网络是否需要加载预训练模型
freeze_params: false # 是否需要固定参数
return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: *model_type # 模型类别
algorithm: CRNN # 子网络的算法名称,该子网络剩余参与均为构造参数,与普通的模型训练配置一致
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student: # 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
pretrained: # 下面的组网参数同上
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
```
当然,这里如果希望添加更多的子网络进行训练,也可以按照`Student``Teacher`的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么`Architecture`可以写为如下格式。
```yaml
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
```
最终该模型训练时,包含3个子网络:`Teacher`, `Student`, `Student2`
蒸馏模型`DistillationModel`类的具体实现代码可以参考[distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py)
最终模型`forward`输出为一个字典,key为所有的子网络名称,例如这里为`Student``Teacher`,value为对应子网络的输出,可以为`Tensor`(只返回该网络的最后一层)和`dict`(也返回了中间的特征信息)。
在识别任务中,为了添加更多损失函数,保证蒸馏方法的可扩展性,将每个子网络的输出保存为`dict`,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果均为`dict`,key包含`backbone_out`,`neck_out`, `head_out``value`为对应模块的tensor,最终对于上述配置文件,`DistillationModel`的输出格式如下。
```json
{
"Teacher": {
"backbone_out": tensor,
"neck_out": tensor,
"head_out": tensor,
},
"Student": {
"backbone_out": tensor,
"neck_out": tensor,
"head_out": tensor,
}
}
```
#### 2.1.2 损失函数
知识蒸馏任务中,损失函数配置如下所示。
```yaml
Loss:
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
weight: 1.0 # 权重
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
- ["Student", "Teacher"]
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDistanceLoss: # 蒸馏的距离损失函数
weight: 1.0 # 权重
mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
model_name_pairs: # 用于计算distance loss的子网络名称对
- ["Student", "Teacher"]
key: backbone_out # 取子网络输出dict中,该key对应的tensor
```
上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
- `Student``Teacher`的最终输出(`head_out`)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
- `Student``Teacher`的最终输出(`head_out`)之间的DML loss,权重为1。
- `Student``Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss,权重为1。
关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)
#### 2.1.3 后处理
知识蒸馏任务中,后处理配置如下所示。
```yaml
PostProcess:
name: DistillationCTCLabelDecode # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
model_name: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码
key: head_out # 取子网络输出dict中,该key对应的tensor
```
以上述配置为例,最终会同时计算`Student``Teahcer` 2个子网络的CTC解码输出,返回一个`dict``key`为用于处理的子网络名称,`value`为用于处理的子网络列表。
关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)
#### 2.1.4 指标计算
知识蒸馏任务中,指标计算配置如下所示。
```yaml
Metric:
name: DistillationMetric # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
base_metric_name: RecMetric # 指标计算的基类,对于模型的输出,会基于该类,计算指标
main_indicator: acc # 指标的名称
key: "Student" # 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准
```
以上述配置为例,最终会使用`Student`子网络的acc指标作为保存best model的判断指标,同时,日志中也会打印出所有子网络的acc指标。
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)
### 2.2 检测配置文件解析
* coming soon!
...@@ -375,7 +375,9 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi ...@@ -375,7 +375,9 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99) 更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。 多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以通过下面两种方式下载。
* [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA)。提取码:frgi。
* [google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
如您希望在现有模型效果的基础上调优,请参考下列说明修改配置文件: 如您希望在现有模型效果的基础上调优,请参考下列说明修改配置文件:
......
...@@ -375,7 +375,9 @@ Currently, the multi-language algorithms supported by PaddleOCR are: ...@@ -375,7 +375,9 @@ Currently, the multi-language algorithms supported by PaddleOCR are:
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations) For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded on [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi. The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded using the following two methods.
* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
If you want to finetune on the basis of the existing model effect, please refer to the following instructions to modify the configuration file: If you want to finetune on the basis of the existing model effect, please refer to the following instructions to modify the configuration file:
......
doc/joinus.PNG

205 KB | W: | H:

doc/joinus.PNG

188 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -73,11 +73,14 @@ class CopyPaste(object): ...@@ -73,11 +73,14 @@ class CopyPaste(object):
box_img_pil = Image.fromarray(box_img).convert('RGBA') box_img_pil = Image.fromarray(box_img).convert('RGBA')
src_w, src_h = src_img.size src_w, src_h = src_img.size
box_w, box_h = box_img_pil.size box_w, box_h = box_img_pil.size
if box_w > src_w or box_h > src_h:
return src_img, None
angle = np.random.randint(0, 360) angle = np.random.randint(0, 360)
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]]) box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
box = rotate_bbox(box_img, box, angle)[0] box = rotate_bbox(box_img, box, angle)[0]
box_img_pil = box_img_pil.rotate(angle, expand=1)
box_w, box_h = box_img_pil.width, box_img_pil.height
if src_w - box_w < 0 or src_h - box_h < 0:
return src_img, None
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w, paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
src_h - box_h) src_h - box_h)
...@@ -85,7 +88,6 @@ class CopyPaste(object): ...@@ -85,7 +88,6 @@ class CopyPaste(object):
return src_img, None return src_img, None
box[:, 0] += paste_x box[:, 0] += paste_x
box[:, 1] += paste_y box[:, 1] += paste_y
box_img_pil = box_img_pil.rotate(angle, expand=1)
r, g, b, A = box_img_pil.split() r, g, b, A = box_img_pil.split()
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A) src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
...@@ -105,7 +107,7 @@ class CopyPaste(object): ...@@ -105,7 +107,7 @@ class CopyPaste(object):
num_poly_in_rect = 0 num_poly_in_rect = 0
for poly in src_polys: for poly in src_polys:
if not is_poly_outside_rect(poly, xmax1, ymin1, if not is_poly_outside_rect(poly, xmin1, ymin1,
xmax1 - xmin1, ymax1 - ymin1): xmax1 - xmin1, ymax1 - ymin1):
num_poly_in_rect += 1 num_poly_in_rect += 1
break break
......
...@@ -54,6 +54,27 @@ class CELoss(nn.Layer): ...@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
return loss return loss
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
if self.mode.lower() == "js":
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss, axis=[1,2])
elif reduction=="none" or reduction is None:
return loss
else:
loss = paddle.sum(loss, axis=[1,2])
return loss
class DMLLoss(nn.Layer): class DMLLoss(nn.Layer):
""" """
DMLLoss DMLLoss
...@@ -70,16 +91,20 @@ class DMLLoss(nn.Layer): ...@@ -70,16 +91,20 @@ class DMLLoss(nn.Layer):
else: else:
self.act = None self.act = None
self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2): def forward(self, out1, out2):
if self.act is not None: if self.act is not None:
out1 = self.act(out1) out1 = self.act(out1)
out2 = self.act(out2) out2 = self.act(out2)
if len(out1.shape) < 2:
log_out1 = paddle.log(out1) log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2) log_out2 = paddle.log(out2)
loss = (F.kl_div( loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0 log_out2, out1, reduction='batchmean')) / 2.0
else:
loss = self.jskl_loss(out1, out2)
return loss return loss
......
...@@ -17,7 +17,7 @@ import paddle.nn as nn ...@@ -17,7 +17,7 @@ import paddle.nn as nn
from .distillation_loss import DistillationCTCLoss from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
class CombinedLoss(nn.Layer): class CombinedLoss(nn.Layer):
...@@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer): ...@@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
def forward(self, input, batch, **kargs): def forward(self, input, batch, **kargs):
loss_dict = {} loss_dict = {}
loss_all = 0.
for idx, loss_func in enumerate(self.loss_func): for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs) loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor): if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss} loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx] weight = self.loss_weight[idx]
loss = { for key in loss.keys():
"{}_{}".format(key, idx): loss[key] * weight if key == "loss":
for key in loss loss_all += loss[key] * weight
} else:
loss_dict.update(loss) loss_dict["{}_{}".format(key, idx)] = loss[key]
loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) loss_dict["loss"] = loss_all
return loss_dict return loss_dict
...@@ -14,23 +14,76 @@ ...@@ -14,23 +14,76 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss from .basic_loss import DistanceLoss
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
def _sum_loss(loss_dict):
if "loss" in loss_dict.keys():
return loss_dict
else:
loss_dict["loss"] = 0.
for k, value in loss_dict.items():
if k == "loss":
continue
else:
loss_dict["loss"] += value
return loss_dict
class DistillationDMLLoss(DMLLoss): class DistillationDMLLoss(DMLLoss):
""" """
""" """
def __init__(self, model_name_pairs=[], act=None, key=None, def __init__(self,
name="loss_dml"): model_name_pairs=[],
act=None,
key=None,
maps_name=None,
name="dml"):
super().__init__(act=act) super().__init__(act=act)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = model_name_pairs self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
...@@ -40,6 +93,8 @@ class DistillationDMLLoss(DMLLoss): ...@@ -40,6 +93,8 @@ class DistillationDMLLoss(DMLLoss):
if self.key is not None: if self.key is not None:
out1 = out1[self.key] out1 = out1[self.key]
out2 = out2[self.key] out2 = out2[self.key]
if self.maps_name is None:
loss = super().forward(out1, out2) loss = super().forward(out1, out2)
if isinstance(loss, dict): if isinstance(loss, dict):
for key in loss: for key in loss:
...@@ -47,6 +102,21 @@ class DistillationDMLLoss(DMLLoss): ...@@ -47,6 +102,21 @@ class DistillationDMLLoss(DMLLoss):
idx)] = loss[key] idx)] = loss[key]
else: else:
loss_dict["{}_{}".format(self.name, idx)] = loss loss_dict["{}_{}".format(self.name, idx)] = loss
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict return loss_dict
...@@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss): ...@@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
return loss_dict return loss_dict
class DistillationDBLoss(DBLoss):
def __init__(self,
model_name_list=[],
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="db",
**kwargs):
super().__init__()
self.model_name_list = model_name_list
self.name = name
self.key = None
def forward(self, predicts, batch):
loss_dict = {}
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss.keys():
if key == "loss":
continue
name = "{}_{}_{}".format(self.name, model_name, key)
loss_dict[name] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDilaDBLoss(DBLoss):
def __init__(self,
model_name_pairs=[],
key=None,
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="dila_dbloss"):
super().__init__()
self.model_name_pairs = model_name_pairs
self.name = name
self.key = key
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
stu_outs = predicts[pair[0]]
tch_outs = predicts[pair[1]]
if self.key is not None:
stu_preds = stu_outs[self.key]
tch_preds = tch_outs[self.key]
stu_shrink_maps = stu_preds[:, 0, :, :]
stu_binary_maps = stu_preds[:, 2, :, :]
# dilation to teacher prediction
dilation_w = np.array([[1, 1], [1, 1]])
th_shrink_maps = tch_preds[:, 0, :, :]
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
for i in range(th_shrink_maps.shape[0]):
dilate_maps[i] = cv2.dilate(
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
th_shrink_maps = paddle.to_tensor(dilate_maps)
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
1:]
# calculate the shrink map loss
bce_loss = self.alpha * self.bce_loss(
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
label_shrink_mask)
# k = f"{self.name}_{pair[0]}_{pair[1]}"
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
loss_dict[k] = bce_loss + loss_binary_maps
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDistanceLoss(DistanceLoss): class DistillationDistanceLoss(DistanceLoss):
""" """
""" """
......
...@@ -55,6 +55,7 @@ class DetMetric(object): ...@@ -55,6 +55,7 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list) result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result) self.results.append(result)
def get_metric(self): def get_metric(self):
""" """
return metrics { return metrics {
......
...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric ...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class DistillationMetric(object): class DistillationMetric(object):
def __init__(self, def __init__(self,
key=None, key=None,
base_metric_name="RecMetric", base_metric_name=None,
main_indicator='acc', main_indicator=None,
**kwargs): **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.key = key self.key = key
...@@ -42,16 +42,13 @@ class DistillationMetric(object): ...@@ -42,16 +42,13 @@ class DistillationMetric(object):
main_indicator=self.main_indicator, **self.kwargs) main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset() self.metrics[key].reset()
def __call__(self, preds, *args, **kwargs): def __call__(self, preds, batch, **kwargs):
assert isinstance(preds, dict) assert isinstance(preds, dict)
if self.metrics is None: if self.metrics is None:
self._init_metrcis(preds) self._init_metrcis(preds)
output = dict() output = dict()
for key in preds: for key in preds:
metric = self.metrics[key].__call__(preds[key], *args, **kwargs) self.metrics[key].__call__(preds[key], batch, **kwargs)
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
def get_metric(self): def get_metric(self):
""" """
......
...@@ -79,6 +79,9 @@ class BaseModel(nn.Layer): ...@@ -79,6 +79,9 @@ class BaseModel(nn.Layer):
x = self.neck(x) x = self.neck(x)
y["neck_out"] = x y["neck_out"] = x
x = self.head(x, targets=data) x = self.head(x, targets=data)
if isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x y["head_out"] = x
if self.return_all_feats: if self.return_all_feats:
return y return y
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment