# Runs out of the box on NVIDIA GPUs
```
conda create --name dbnet python=3.6.9
conda init bash  # restart the shell afterwards
conda activate dbnet
pip install -r requirement.txt -i https://mirrors.ustc.edu.cn/pypi/web/simple/
pip install natsort addict
bash multi_gpu_train.sh
```
# Real-time Scene Text Detection with Differentiable Binarization
**note**: some code is inherited from [MhLiao/DB](https://github.com/MhLiao/DB)
[Chinese explanation](https://zhuanlan.zhihu.com/p/94677957)
![network](imgs/paper/db.jpg)
## update
2020-06-07: Added grayscale training. When training on grayscale images, remove `dataset.args.transforms.Normalize` from the config.
## Install Using Conda
```
git clone https://github.com/WenmuZhou/DBNet.pytorch.git
cd DBNet.pytorch/
conda env create -f environment.yml
```
or
## Install Manually
```bash
conda create -n dbnet python=3.6
conda activate dbnet
conda install ipython pip
# python dependencies
pip install -r requirement.txt
# install PyTorch with cuda-10.1
# Note that you can change the cudatoolkit version to the version you want.
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
# clone repo
git clone https://github.com/WenmuZhou/DBNet.pytorch.git
cd DBNet.pytorch/
```
## Requirements
* pytorch 1.4+
* torchvision 0.5+
* gcc 4.9+
## Download
TBD
## Data Preparation
Training data: prepare a text file `train.txt` in the following format, using `\t` as the separator:
```
./datasets/train/img/001.jpg ./datasets/train/gt/001.txt
```
Validation data: prepare a text file `test.txt` in the following format, using `\t` as the separator:
```
./datasets/test/img/001.jpg ./datasets/test/gt/001.txt
```
- Store images in the `img` folder
- Store groundtruth in the `gt` folder
Ground truth is stored in `.txt` files with the following format:
```
x1, y1, x2, y2, x3, y3, x4, y4, annotation
```
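For convenience, here is a minimal sketch that builds these list files by pairing images in `img/` with same-named `.txt` files in `gt/`. The match-by-stem rule and folder layout are assumptions; adjust them to your ground-truth naming scheme (ICDAR 2015, for instance, prefixes ground-truth files with `gt_`):
```python
import os

def build_list(img_dir: str, gt_dir: str, out_file: str) -> None:
    """Write 'img_path<TAB>gt_path' lines for every image with a matching gt file."""
    with open(out_file, 'w', encoding='utf-8') as f:
        for name in sorted(os.listdir(img_dir)):
            stem = os.path.splitext(name)[0]
            gt_path = os.path.join(gt_dir, stem + '.txt')
            if os.path.exists(gt_path):
                f.write('{}\t{}\n'.format(os.path.join(img_dir, name), gt_path))

build_list('./datasets/train/img', './datasets/train/gt', './datasets/train.txt')
build_list('./datasets/test/img', './datasets/test/gt', './datasets/test.txt')
```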
## Train
1. Configure `dataset['train']['dataset']['args']['data_path']` and `dataset['validate']['dataset']['args']['data_path']` in [config/icdar2015_resnet18_fpn_DBhead_polyLR.yaml](config/icdar2015_resnet18_fpn_DBhead_polyLR.yaml)
* Single-GPU training
```bash
bash singlel_gpu_train.sh
```
* Multi-GPU training
```bash
bash multi_gpu_train.sh
```
## Test
[eval.py](tools/eval.py) is used to evaluate the model on the test dataset
1. config `model_path` in [eval.sh](eval.sh)
2. use the following script to test
```bash
bash eval.sh
```
## Predict
[predict.py](tools/predict.py) can be used to run inference on all images in a folder
1. config `model_path`, `input_folder`, and `output_folder` in [predict.sh](predict.sh)
2. use the following script to predict
```bash
bash predict.sh
```
You can change `model_path` in the `predict.sh` file to your model location.
Tip: if the results are poor, try adjusting `thre` in [predict.sh](predict.sh)
The project is still under development.
<h2 id="Performance">Performance</h2>
### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4)
Trained only on the ICDAR2015 dataset.
| Method | image size (short size) |learning rate | Precision (%) | Recall (%) | F-measure (%) | FPS |
|:--------------------------:|:-------:|:--------:|:--------:|:------------:|:---------------:|:-----:|
| SynthText-Deform-ResNet-18 (paper) | 736 | 0.007 | 86.8 | 78.4 | 82.3 | 48 |
| ImageNet-resnet18-FPN-DBHead |736 |1e-3| 87.03 | 75.06 | 80.6 | 43 |
| ImageNet-Deform-Resnet18-FPN-DBHead |736 |1e-3| 88.61 | 73.84 | 80.56 | 36 |
| ImageNet-resnet50-FPN-DBHead |736 |1e-3| 88.06 | 77.14 | 82.24 | 27 |
| ImageNet-resnest50-FPN-DBHead |736 |1e-3| 88.18 | 76.27 | 81.78 | 27 |
### examples
TBD
### todo
- [x] multi-GPU training
### reference
1. https://arxiv.org/pdf/1911.08947.pdf
2. https://github.com/WenmuZhou/PANet.pytorch
3. https://github.com/MhLiao/DB
**If this repository helps you, please star it. Thanks.**
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# DBnet
## Model Introduction
DBNet inserts the binarization step into the segmentation network so the two are optimized jointly: the network adaptively predicts a binarization threshold for every pixel in the image (unlike the fixed threshold of traditional methods), allowing foreground and background pixels to be fully separated.
Because the threshold is learned and the binarization step is trained end-to-end with the rest of the network, the final output map is highly robust to the threshold, which simplifies post-processing while improving text detection quality.
## Model Structure
The DBNet network consists of 3 modules:
- Module (1): an FPN structure with a bottom-up convolutional path and top-down upsampling, used to extract multi-scale features. The bottom-up path (lower part of the figure) applies 3x3 convolutions to produce feature maps at `1/2, 1/4, 1/8, 1/16, 1/32` of the original image size; the top-down path upsamples each map by 2x and fuses it with the bottom-up feature map of the same size; each fused map then passes through a 3x3 convolution to remove upsampling aliasing; finally, every level is upsampled to a common 1/4-scale feature map.
- Module (2): the 1/4-scale feature map passes through a series of convolutions and transposed convolutions (similar to an FCN) to produce the probability map **P** and the threshold map **T** at the original image resolution.
- Module (3): the feature maps P and T are combined by the DB function (described below) to produce the approximate binary map.
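The DB function referenced in module (3) is the differentiable approximation of binarization from the paper: instead of a hard step at a fixed threshold, the approximate binary map is computed as `B̂ = 1 / (1 + exp(-k(P - T)))`, a sigmoid whose amplification factor `k` (set to `k: 50` in the configs below) sharpens the transition while keeping the step differentiable. A minimal PyTorch sketch:
```python
import torch

def db_binarize(prob_map: torch.Tensor, thresh_map: torch.Tensor, k: float = 50.0) -> torch.Tensor:
    """Differentiable binarization: B_hat = 1 / (1 + exp(-k * (P - T))).

    The per-pixel threshold map T is predicted by the network, and the
    large slope k keeps gradients informative near the decision boundary.
    """
    return torch.sigmoid(k * (prob_map - thresh_map))

# Example on dummy maps at 640x640 (the training crop size in the configs)
P = torch.rand(1, 1, 640, 640)  # probability map
T = torch.rand(1, 1, 640, 640)  # threshold map
B_hat = db_binarize(P, T)       # approximate binary map in (0, 1)
```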
## Dataset
This test can use the ICDAR 2015 dataset.
## Training DBnet
### Environment Setup
A training docker image can be pulled from [SourceFind](https://www.sourcefind.cn/#/service-details):
* Training image: docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10.1-py37-latest
* pip install -r requirements.txt
### Training
Extract the training data into the `datasets` directory.
Training command:
./run.sh
## Performance and Accuracy
Tests use the ICDAR 2015 dataset on a DCU Z100L accelerator.
| Cards | Throughput | Accuracy |
| :--: | :--------------: | :-----------------------------------: |
| 1 | 22.9 samples/sec | recall: 0.767070, precision: 0.894410 |
### Reference
https://github.com/WenmuZhou/DBNet.pytorch
from .base_trainer import BaseTrainer
from .base_dataset import BaseDataSet
# -*- coding: utf-8 -*-
# @Time : 2019/12/4 13:12
# @Author : zhoujun
import copy

import cv2
import numpy as np
from torch.utils.data import Dataset

from data_loader.modules import *
class BaseDataSet(Dataset):
def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None,
target_transform=None):
        assert img_mode in ['RGB', 'BGR', 'GRAY']
self.ignore_tags = ignore_tags
self.data_list = self.load_data(data_path)
item_keys = ['img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags']
for item in item_keys:
            assert item in self.data_list[0], 'data_list from load_data must contain {}'.format(item)
self.img_mode = img_mode
self.filter_keys = filter_keys
self.transform = transform
self.target_transform = target_transform
self._init_pre_processes(pre_processes)
def _init_pre_processes(self, pre_processes):
self.aug = []
if pre_processes is not None:
for aug in pre_processes:
if 'args' not in aug:
args = {}
else:
args = aug['args']
if isinstance(args, dict):
cls = eval(aug['type'])(**args)
else:
cls = eval(aug['type'])(args)
self.aug.append(cls)
    def load_data(self, data_path: str) -> list:
        """
        Load the data into a list.
        :param data_path: folder or file that stores the data
        :return: a list of dicts, each containing 'img_path', 'img_name', 'text_polys', 'texts', 'ignore_tags'
        """
        raise NotImplementedError
def apply_pre_processes(self, data):
for aug in self.aug:
data = aug(data)
return data
def __getitem__(self, index):
try:
data = copy.deepcopy(self.data_list[index])
im = cv2.imread(data['img_path'], 1 if self.img_mode != 'GRAY' else 0)
if self.img_mode == 'RGB':
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
data['img'] = im
data['shape'] = [im.shape[0], im.shape[1]]
data = self.apply_pre_processes(data)
if self.transform:
data['img'] = self.transform(data['img'])
data['text_polys'] = data['text_polys'].tolist()
if len(self.filter_keys):
data_dict = {}
for k, v in data.items():
if k not in self.filter_keys:
data_dict[k] = v
return data_dict
else:
return data
        except Exception:
            # Loading or augmentation failed; fall back to a random sample.
            return self.__getitem__(np.random.randint(self.__len__()))
def __len__(self):
return len(self.data_list)
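As an illustration of the `load_data` contract above, here is a hypothetical minimal subclass that reads the `img_path\tgt_path` list format from the README. The quad-parsing details (8 comma-separated coordinates followed by the annotation, annotations containing no commas) are assumptions, and the repo's own `ICDAR2015Dataset` may differ:
```python
import os

import numpy as np

class TxtListDataset(BaseDataSet):
    """Hypothetical dataset reading tab-separated img/gt path lists (a sketch)."""

    def load_data(self, data_path: str) -> list:
        data_list = []
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                img_path, gt_path = line.strip().split('\t')
                polys, texts, ignores = [], [], []
                with open(gt_path, encoding='utf-8') as gt_file:
                    for gt_line in gt_file:
                        parts = gt_line.strip().split(',')
                        # x1, y1, x2, y2, x3, y3, x4, y4, annotation
                        polys.append(np.array(parts[:8], dtype=np.float32).reshape(4, 2))
                        texts.append(parts[8])
                        ignores.append(parts[8] in self.ignore_tags)
                data_list.append({'img_path': img_path,
                                  'img_name': os.path.basename(img_path),
                                  'text_polys': np.array(polys),
                                  'texts': texts,
                                  'ignore_tags': ignores})
        return data_list
```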
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:50
# @Author : zhoujun
import os
import pathlib
import shutil
from pprint import pformat
import anyconfig
import torch
from utils import setup_logger
class BaseTrainer:
def __init__(self, config, model, criterion):
config['trainer']['output_dir'] = os.path.join(str(pathlib.Path(os.path.abspath(__name__)).parent),
config['trainer']['output_dir'])
config['name'] = config['name'] + '_' + model.name
self.save_dir = os.path.join(config['trainer']['output_dir'], config['name'])
self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
if config['trainer']['resume_checkpoint'] == '' and config['trainer']['finetune_checkpoint'] == '':
shutil.rmtree(self.save_dir, ignore_errors=True)
if not os.path.exists(self.checkpoint_dir):
os.makedirs(self.checkpoint_dir)
self.global_step = 0
self.start_epoch = 0
self.config = config
self.model = model
self.criterion = criterion
# logger and tensorboard
self.tensorboard_enable = self.config['trainer']['tensorboard']
self.epochs = self.config['trainer']['epochs']
self.log_iter = self.config['trainer']['log_iter']
if config['local_rank'] == 0:
anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
self.logger_info(pformat(self.config))
# device
        torch.manual_seed(self.config['trainer']['seed'])  # seed the CPU RNG
if torch.cuda.device_count() > 0 and torch.cuda.is_available():
self.with_cuda = True
torch.backends.cudnn.benchmark = True
self.device = torch.device("cuda")
            torch.cuda.manual_seed(self.config['trainer']['seed'])  # seed the current GPU
            torch.cuda.manual_seed_all(self.config['trainer']['seed'])  # seed all GPUs
else:
self.with_cuda = False
self.device = torch.device("cpu")
self.logger_info('train with device {} and pytorch {}'.format(self.device, torch.__version__))
# metrics
        self.metrics = {'recall': 0, 'precision': 0, 'hmean': 0, 'train_loss': float('inf'), 'best_model_epoch': 0}
self.optimizer = self._initialize('optimizer', torch.optim, model.parameters())
# resume or finetune
if self.config['trainer']['resume_checkpoint'] != '':
self._load_checkpoint(self.config['trainer']['resume_checkpoint'], resume=True)
elif self.config['trainer']['finetune_checkpoint'] != '':
self._load_checkpoint(self.config['trainer']['finetune_checkpoint'], resume=False)
if self.config['lr_scheduler']['type'] != 'WarmupPolyLR':
self.scheduler = self._initialize('lr_scheduler', torch.optim.lr_scheduler, self.optimizer)
self.model.to(self.device)
if self.tensorboard_enable and config['local_rank'] == 0:
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(self.save_dir)
try:
# add graph
in_channels = 3 if config['dataset']['train']['dataset']['args']['img_mode'] != 'GRAY' else 1
dummy_input = torch.zeros(1, in_channels, 640, 640).to(self.device)
self.writer.add_graph(self.model, dummy_input)
torch.cuda.empty_cache()
            except Exception:
                import traceback
                self.logger.error(traceback.format_exc())
                self.logger.warning('add graph to tensorboard failed')
        # distributed training
if torch.cuda.device_count() > 1:
local_rank = config['local_rank']
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False,
find_unused_parameters=True)
# make inverse Normalize
self.UN_Normalize = False
for t in self.config['dataset']['train']['dataset']['args']['transforms']:
if t['type'] == 'Normalize':
self.normalize_mean = t['args']['mean']
self.normalize_std = t['args']['std']
self.UN_Normalize = True
def train(self):
"""
Full training logic
"""
for epoch in range(self.start_epoch + 1, self.epochs + 1):
if self.config['distributed']:
self.train_loader.sampler.set_epoch(epoch)
self.epoch_result = self._train_epoch(epoch)
if self.config['lr_scheduler']['type'] != 'WarmupPolyLR':
self.scheduler.step()
self._on_epoch_finish()
if self.config['local_rank'] == 0 and self.tensorboard_enable:
self.writer.close()
self._on_train_finish()
def _train_epoch(self, epoch):
"""
Training logic for an epoch
:param epoch: Current epoch number
"""
raise NotImplementedError
def _eval(self, epoch):
"""
eval logic for an epoch
:param epoch: Current epoch number
"""
raise NotImplementedError
def _on_epoch_finish(self):
raise NotImplementedError
def _on_train_finish(self):
raise NotImplementedError
    def _save_checkpoint(self, epoch, file_name):
        """
        Save a checkpoint
        :param epoch: current epoch number
        :param file_name: name of the checkpoint file
        """
state_dict = self.model.module.state_dict() if self.config['distributed'] else self.model.state_dict()
state = {
'epoch': epoch,
'global_step': self.global_step,
'state_dict': state_dict,
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict(),
'config': self.config,
'metrics': self.metrics
}
filename = os.path.join(self.checkpoint_dir, file_name)
torch.save(state, filename)
    def _load_checkpoint(self, checkpoint_path, resume):
        """
        Load a saved checkpoint
        :param checkpoint_path: checkpoint path to load
        :param resume: if True, restore optimizer and step state to resume training; otherwise load weights only (finetune)
        """
self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.model.load_state_dict(checkpoint['state_dict'], strict=resume)
if resume:
self.global_step = checkpoint['global_step']
self.start_epoch = checkpoint['epoch']
self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
# self.scheduler.load_state_dict(checkpoint['scheduler'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
if 'metrics' in checkpoint:
self.metrics = checkpoint['metrics']
if self.with_cuda:
for state in self.optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.device)
self.logger_info("resume from checkpoint {} (epoch {})".format(checkpoint_path, self.start_epoch))
else:
self.logger_info("finetune from checkpoint {}".format(checkpoint_path))
def _initialize(self, name, module, *args, **kwargs):
module_name = self.config[name]['type']
module_args = self.config[name]['args']
assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
module_args.update(kwargs)
return getattr(module, module_name)(*args, **module_args)
def inverse_normalize(self, batch_img):
if self.UN_Normalize:
batch_img[:, 0, :, :] = batch_img[:, 0, :, :] * self.normalize_std[0] + self.normalize_mean[0]
batch_img[:, 1, :, :] = batch_img[:, 1, :, :] * self.normalize_std[1] + self.normalize_mean[1]
batch_img[:, 2, :, :] = batch_img[:, 2, :, :] * self.normalize_std[2] + self.normalize_mean[2]
def logger_info(self, s):
if self.config['local_rank'] == 0:
self.logger.info(s)
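A usage note on `_initialize`: it resolves a config entry via reflection, looking up `config[name]['type']` as an attribute of the given module and passing `config[name]['args']` as keyword arguments. With the optimizer fragment from the configs below, it behaves roughly like this sketch (the stand-in model is hypothetical):
```python
import torch

# Config fragment, as in the YAML files below:
#   optimizer:
#     type: Adam
#     args: {lr: 0.001, weight_decay: 0, amsgrad: true}
opt_cfg = {'type': 'Adam', 'args': {'lr': 0.001, 'weight_decay': 0, 'amsgrad': True}}

model = torch.nn.Linear(4, 2)  # hypothetical stand-in model
# Equivalent of self._initialize('optimizer', torch.optim, model.parameters()):
optimizer = getattr(torch.optim, opt_cfg['type'])(model.parameters(), **opt_cfg['args'])
print(type(optimizer).__name__)  # Adam
```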
name: DBNet
dataset:
train:
dataset:
      type: ICDAR2015Dataset # dataset class
args:
        data_path: # a file listing img_path \t gt_path pairs
- ''
        pre_processes: # data pre-processing: augmentation and label generation
          - type: IaaAugment # augmentations via imgaug
args:
- {'type':Fliplr, 'args':{'p':0.5}}
- {'type': Affine, 'args':{'rotate':[-10,10]}}
- {'type':Resize,'args':{'size':[0.5,3]}}
- type: EastRandomCropData
args:
size: [640,640]
max_tries: 50
keep_ratio: true
- type: MakeBorderMap
args:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- type: MakeShrinkMap
args:
shrink_ratio: 0.4
min_text_size: 8
        transforms: # transforms applied to the image
- type: ToTensor
args: {}
- type: Normalize
args:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
img_mode: RGB
        filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # keys removed from the data dict before it is returned
ignore_tags: ['*', '###']
loader:
batch_size: 1
shuffle: true
pin_memory: false
num_workers: 0
collate_fn: ''
validate:
dataset:
type: ICDAR2015Dataset
args:
data_path:
- ''
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
transforms:
- type: ToTensor
args: {}
- type: Normalize
args:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
img_mode: RGB
filter_keys: []
ignore_tags: ['*', '###']
loader:
batch_size: 1
shuffle: true
pin_memory: false
num_workers: 0
collate_fn: ICDARCollectFN
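Each entry under `pre_processes` maps one-to-one onto a class in `data_loader.modules`: `BaseDataSet._init_pre_processes` evaluates `type` to a class and passes `args` as keyword arguments when it is a dict (a single positional argument otherwise). For example, the `EastRandomCropData` entry above is instantiated roughly like this sketch:
```python
from data_loader.modules import *  # brings EastRandomCropData et al. into scope

# YAML:
#   - type: EastRandomCropData
#     args: {size: [640, 640], max_tries: 50, keep_ratio: true}
aug = {'type': 'EastRandomCropData',
       'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}}
module = eval(aug['type'])(**aug['args'])  # == EastRandomCropData(size=[640, 640], ...)
```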
name: DBNet
base: ['config/icdar2015.yaml']
arch:
type: Model
backbone:
type: deformable_resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: WarmupPolyLR
args:
warmup_epoch: 3
trainer:
seed: 2
epochs: 50
log_iter: 10
show_images_iter: 50
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
tensorboard: true
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.txt
img_mode: RGB
loader:
batch_size: 1
shuffle: true
pin_memory: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.txt
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
loader:
batch_size: 1
shuffle: true
pin_memory: false
num_workers: 6
collate_fn: ICDARCollectFN
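A note on the `base` key: a derived config such as this one is merged on top of the file(s) it lists in `base`, with derived values overriding base values. Conceptually, the loader performs a recursive dict merge along these lines (a sketch under that assumption; the repo's actual config parser may differ):
```python
import anyconfig  # already used by BaseTrainer for config dumping

def merge_dict(base: dict, override: dict) -> dict:
    """Recursively merge `override` on top of `base`; override wins on conflicts."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge_dict(out[key], value)
        else:
            out[key] = value
    return out

def load_config(path: str) -> dict:
    config = anyconfig.load(path)
    merged = {}
    for base_path in config.pop('base', []):
        merged = merge_dict(merged, load_config(base_path))
    return merge_dict(merged, config)

# e.g. load_config('config/icdar2015_resnet18_fpn_DBhead_polyLR.yaml')
```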
name: DBNet
dataset:
train:
dataset:
      type: SynthTextDataset # dataset class
args:
        data_path: '' # SynthText root directory
        pre_processes: # data pre-processing: augmentation and label generation
          - type: IaaAugment # augmentations via imgaug
args:
- {'type':Fliplr, 'args':{'p':0.5}}
- {'type': Affine, 'args':{'rotate':[-10,10]}}
- {'type':Resize,'args':{'size':[0.5,3]}}
- type: EastRandomCropData
args:
size: [640,640]
max_tries: 50
keep_ratio: true
- type: MakeBorderMap
args:
shrink_ratio: 0.4
- type: MakeShrinkMap
args:
shrink_ratio: 0.4
min_text_size: 8
        transforms: # transforms applied to the image
- type: ToTensor
args: {}
- type: Normalize
args:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
img_mode: RGB
        filter_keys: ['img_path','img_name','text_polys','texts','ignore_tags','shape'] # keys removed from the data dict before it is returned
ignore_tags: ['*', '###']
loader:
batch_size: 1
shuffle: true
pin_memory: false
num_workers: 0
collate_fn: ''
name: DBNet
base: ['config/SynthText.yaml']
arch:
type: Model
backbone:
type: resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: WarmupPolyLR
args:
warmup_epoch: 3
trainer:
seed: 2
epochs: 1200
log_iter: 10
show_images_iter: 50
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
tensorboard: true
dataset:
train:
dataset:
args:
data_path: ./datasets/SynthText
img_mode: RGB
loader:
batch_size: 2
shuffle: true
pin_memory: true
num_workers: 6
collate_fn: ''
name: DBNet
base: ['config/icdar2015.yaml']
arch:
type: Model
backbone:
type: resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: WarmupPolyLR
args:
warmup_epoch: 3
trainer:
seed: 2
epochs: 600
log_iter: 10
show_images_iter: 50
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
tensorboard: true
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.txt
img_mode: RGB
loader:
batch_size: 32
shuffle: true
pin_memory: true
num_workers: 8
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.txt
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
loader:
batch_size: 1
shuffle: true
pin_memory: false
num_workers: 8
collate_fn: ICDARCollectFN
name: DBNet
base: ['config/icdar2015.yaml']
arch:
type: Model
backbone:
type: resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: StepLR
args:
step_size: 10
    gamma: 0.8
trainer:
seed: 2
epochs: 500
log_iter: 10
show_images_iter: 50
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
tensorboard: true
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.txt
img_mode: RGB
loader:
batch_size: 1
shuffle: true
pin_memory: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.txt
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
loader:
batch_size: 1
shuffle: true
pin_memory: false
num_workers: 6
collate_fn: ICDARCollectFN