Commit 97243508 authored by sunxx1's avatar sunxx1
Browse files

添加DBnet代码

parents
---
name: DBNet
dataset:
  train:
    dataset:
      type: DetDataset # dataset class name
      args:
        data_path: # files that store one "img_path \t gt_path" pair per line
          - ''
        pre_processes: # preprocessing pipeline: augmentation plus label map generation
          - type: IaaAugment # imgaug-based augmentations
            args:
              - { 'type': Fliplr, 'args': { 'p': 0.5 } }
              - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
              - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
          - type: EastRandomCropData
            args:
              size: [640, 640]
              max_tries: 50
              keep_ratio: true
          - type: MakeBorderMap
            args:
              shrink_ratio: 0.4
              thresh_min: 0.3
              thresh_max: 0.7
          - type: MakeShrinkMap
            args:
              shrink_ratio: 0.4
              min_text_size: 8
        transforms: # per-image tensor transforms
          - type: ToTensor
            args: {}
          - type: Normalize
            args:
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
        img_mode: RGB
        load_char_annotation: false
        expand_one_char: false
        # keys removed from the sample dict before it is returned
        filter_keys: [img_path, img_name, text_polys, texts, ignore_tags, shape]
        ignore_tags: ['*', '###']
    loader:
      batch_size: 1
      shuffle: true
      pin_memory: false
      num_workers: 0
      collate_fn: ''
  validate:
    dataset:
      type: DetDataset
      args:
        data_path:
          - ''
        pre_processes:
          - type: ResizeShortSize
            args:
              short_size: 736
              resize_text_polys: false
        transforms:
          - type: ToTensor
            args: {}
          - type: Normalize
            args:
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
        img_mode: RGB
        load_char_annotation: false # whether to load character-level annotations
        expand_one_char: false # whether to widen single-character boxes (w = w + h after expansion)
        filter_keys: []
        ignore_tags: ['*', '###']
    loader:
      batch_size: 1
      shuffle: true
      pin_memory: false
      num_workers: 0
      collate_fn: ICDARCollectFN
---
name: DBNet
base: ['config/open_dataset.yaml']
arch:
  type: Model
  backbone:
    type: deformable_resnet18
    pretrained: true
  neck:
    type: FPN
    inner_channels: 256
  head:
    type: DBHead
    out_channels: 2
    k: 50
post_processing:
  type: SegDetectorRepresenter
  args:
    thresh: 0.3
    box_thresh: 0.7
    max_candidates: 1000
    unclip_ratio: 1.5 # from paper
metric:
  type: QuadMetric
  args:
    is_output_polygon: false
loss:
  type: DBLoss
  alpha: 1
  beta: 10
  ohem_ratio: 3
optimizer:
  type: Adam
  args:
    lr: 0.001
    weight_decay: 0
    amsgrad: true
lr_scheduler:
  type: WarmupPolyLR
  args:
    warmup_epoch: 3
trainer:
  seed: 2
  epochs: 1200
  log_iter: 1
  show_images_iter: 1
  resume_checkpoint: ''
  finetune_checkpoint: ''
  output_dir: output
  tensorboard: true
dataset:
  train:
    dataset:
      args:
        data_path:
          - ./datasets/train.json
        img_mode: RGB
        load_char_annotation: false
        expand_one_char: false
    loader:
      batch_size: 2
      shuffle: true
      pin_memory: true
      num_workers: 6
      collate_fn: ''
  validate:
    dataset:
      args:
        data_path:
          - ./datasets/test.json
        pre_processes:
          - type: ResizeShortSize
            args:
              short_size: 736
              resize_text_polys: false
        img_mode: RGB
        load_char_annotation: false
        expand_one_char: false
    loader:
      batch_size: 1
      shuffle: true
      pin_memory: false
      num_workers: 6
      collate_fn: ICDARCollectFN
---
name: DBNet
base: ['config/open_dataset.yaml']
arch:
  type: Model
  backbone:
    type: resnest50
    pretrained: true
  neck:
    type: FPN
    inner_channels: 256
  head:
    type: DBHead
    out_channels: 2
    k: 50
post_processing:
  type: SegDetectorRepresenter
  args:
    thresh: 0.3
    box_thresh: 0.7
    max_candidates: 1000
    unclip_ratio: 1.5 # from paper
metric:
  type: QuadMetric
  args:
    is_output_polygon: false
loss:
  type: DBLoss
  alpha: 1
  beta: 10
  ohem_ratio: 3
optimizer:
  type: Adam
  args:
    lr: 0.001
    weight_decay: 0
    amsgrad: true
lr_scheduler:
  type: WarmupPolyLR
  args:
    warmup_epoch: 3
trainer:
  seed: 2
  epochs: 1200
  log_iter: 1
  show_images_iter: 1
  resume_checkpoint: ''
  finetune_checkpoint: ''
  output_dir: output
  tensorboard: true
dataset:
  train:
    dataset:
      args:
        data_path:
          - ./datasets/train.json
        img_mode: RGB
        load_char_annotation: false
        expand_one_char: false
    loader:
      batch_size: 2
      shuffle: true
      pin_memory: true
      num_workers: 6
      collate_fn: ''
  validate:
    dataset:
      args:
        data_path:
          - ./datasets/test.json
        pre_processes:
          - type: ResizeShortSize
            args:
              short_size: 736
              resize_text_polys: false
        img_mode: RGB
        load_char_annotation: false
        expand_one_char: false
    loader:
      batch_size: 1
      shuffle: true
      pin_memory: false
      num_workers: 6
      collate_fn: ICDARCollectFN
---
name: DBNet
base: ['config/open_dataset.yaml']
arch:
  type: Model
  backbone:
    type: resnet18
    pretrained: true
  neck:
    type: FPN
    inner_channels: 256
  head:
    type: DBHead
    out_channels: 2
    k: 50
post_processing:
  type: SegDetectorRepresenter
  args:
    thresh: 0.3
    box_thresh: 0.7
    max_candidates: 1000
    unclip_ratio: 1.5 # from paper
metric:
  type: QuadMetric
  args:
    is_output_polygon: false
loss:
  type: DBLoss
  alpha: 1
  beta: 10
  ohem_ratio: 3
optimizer:
  type: Adam
  args:
    lr: 0.001
    weight_decay: 0
    amsgrad: true
lr_scheduler:
  type: WarmupPolyLR
  args:
    warmup_epoch: 3
trainer:
  seed: 2
  epochs: 1200
  log_iter: 1
  show_images_iter: 1
  resume_checkpoint: ''
  finetune_checkpoint: ''
  output_dir: output
  tensorboard: true
dataset:
  train:
    dataset:
      args:
        data_path:
          - ./datasets/train.json
        transforms: # per-image tensor transforms
          - type: ToTensor
            args: {}
          - type: Normalize
            args:
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
        img_mode: RGB
        load_char_annotation: false
        expand_one_char: false
    loader:
      batch_size: 2
      shuffle: true
      pin_memory: true
      num_workers: 6
      collate_fn: ''
  validate:
    dataset:
      args:
        data_path:
          - ./datasets/test.json
        pre_processes:
          - type: ResizeShortSize
            args:
              short_size: 736
              resize_text_polys: false
        img_mode: RGB
        load_char_annotation: false
        expand_one_char: false
    loader:
      batch_size: 1
      shuffle: true
      pin_memory: false
      num_workers: 6
      collate_fn: ICDARCollectFN
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:52
# @Author : zhoujun
import copy
import PIL
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
def get_dataset(data_path, module_name, transform, dataset_args):
    """Instantiate a dataset class from the local ``dataset`` module.

    :param data_path: list of dataset files, each storing lines of 'path/to/img\\tlabel'
    :param module_name: name of the dataset class to look up in the ``dataset`` module
    :param transform: transforms applied to samples of the dataset
    :param dataset_args: extra keyword arguments forwarded to the dataset constructor
    :return: the constructed dataset instance
    """
    from . import dataset
    dataset_cls = getattr(dataset, module_name)
    return dataset_cls(transform=transform, data_path=data_path, **dataset_args)
def get_transforms(transforms_config):
    """Build a ``transforms.Compose`` from a list of config entries.

    Each entry is a dict with a 'type' key naming a class in
    ``torchvision.transforms`` and an optional 'args' dict of constructor
    keyword arguments.

    :param transforms_config: list of {'type': str, 'args': dict} entries
    :return: a ``transforms.Compose`` chaining the constructed transforms
    """
    tr_list = []
    for item in transforms_config:
        # 'args' is optional in the config; default to no constructor arguments.
        args = item.get('args', {})
        cls = getattr(transforms, item['type'])(**args)
        tr_list.append(cls)
    return transforms.Compose(tr_list)
class ICDARCollectFN:
    """Collate function for ICDAR-style detection batches.

    Groups each sample's dict entries into per-key lists, then stacks the
    keys whose values are arrays / tensors / PIL images into a single tensor
    along a new batch dimension; all other values stay as plain lists.
    """

    def __init__(self, *args, **kwargs):
        # Accepts (and ignores) arbitrary arguments so it can be built from config.
        pass

    def __call__(self, batch):
        collected = {}
        stack_keys = []
        for sample in batch:
            for key, value in sample.items():
                collected.setdefault(key, [])
                if isinstance(value, (np.ndarray, torch.Tensor, PIL.Image.Image)):
                    if key not in stack_keys:
                        stack_keys.append(key)
                collected[key].append(value)
        for key in stack_keys:
            collected[key] = torch.stack(collected[key], 0)
        return collected
def get_dataloader(module_config, distributed=False):
    """Build a DataLoader from a dataset/loader config dict.

    :param module_config: dict with 'dataset' ({'type', 'args'}) and 'loader'
        (DataLoader keyword arguments, plus an optional 'collate_fn' class name)
    :param distributed: when True, wrap the dataset in a DistributedSampler and
        force shuffle=False / pin_memory=True on the loader
    :return: a DataLoader, or None when no config or no usable data_path is given
    """
    if module_config is None:
        return None
    # Work on a copy so the pops below do not mutate the caller's config.
    config = copy.deepcopy(module_config)
    dataset_args = config['dataset']['args']
    if 'transforms' in dataset_args:
        img_transforms = get_transforms(dataset_args.pop('transforms'))
    else:
        img_transforms = None
    # Build the dataset.
    dataset_name = config['dataset']['type']
    data_path = dataset_args.pop('data_path')
    if data_path is None:
        return None
    data_path = [x for x in data_path if x is not None]
    if len(data_path) == 0:
        return None
    collate_name = config['loader'].get('collate_fn')
    if not collate_name:
        # Missing key, None, or empty string all mean "use the default collate".
        config['loader']['collate_fn'] = None
    else:
        # Resolve the collate class by name from this module's globals instead of
        # eval(), so an arbitrary config string cannot execute code.
        config['loader']['collate_fn'] = globals()[collate_name]()
    _dataset = get_dataset(data_path=data_path, module_name=dataset_name,
                           transform=img_transforms, dataset_args=dataset_args)
    sampler = None
    if distributed:
        from torch.utils.data.distributed import DistributedSampler
        # Shard the dataset across processes; the sampler replaces shuffle.
        sampler = DistributedSampler(_dataset)
        config['loader']['shuffle'] = False
        config['loader']['pin_memory'] = True
    return DataLoader(dataset=_dataset, sampler=sampler, **config['loader'])
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:54
# @Author : zhoujun
import pathlib
import os
import cv2
import numpy as np
import scipy.io as sio
from tqdm.auto import tqdm
from base import BaseDataSet
from utils import order_points_clockwise, get_datalist, load,expand_polygon
class ICDAR2015Dataset(BaseDataSet):
    """ICDAR2015 detection dataset; data files list one '<img_path>\\t<gt_path>' pair per line."""

    def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None, **kwargs):
        super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform)

    def load_data(self, data_path: str) -> list:
        """Read image/label pairs and keep only images that have at least one usable box.

        :param data_path: list files mapping images to their gt files
        :return: list of sample dicts (img_path, img_name, text_polys, texts, ignore_tags)
        """
        data_list = get_datalist(data_path)
        t_data_list = []
        for img_path, label_path in data_list:
            data = self._get_annotation(label_path)
            if len(data['text_polys']) > 0:
                item = {'img_path': img_path, 'img_name': pathlib.Path(img_path).stem}
                item.update(data)
                t_data_list.append(item)
            else:
                print('there is no suit bbox in {}'.format(label_path))
        return t_data_list

    def _get_annotation(self, label_path: str) -> dict:
        """Parse one ICDAR gt file: 8 polygon coords then the transcription per line."""
        boxes = []
        texts = []
        ignores = []
        with open(label_path, encoding='utf-8', mode='r') as f:
            # Iterate the file directly instead of materializing it with readlines().
            for line in f:
                # Strip BOM markers some gt files carry at the start of a line.
                params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',')
                try:
                    box = order_points_clockwise(np.array(list(map(float, params[:8]))).reshape(-1, 2))
                    # Skip degenerate (zero-area) polygons.
                    if cv2.contourArea(box) > 0:
                        boxes.append(box)
                        label = params[8]
                        texts.append(label)
                        ignores.append(label in self.ignore_tags)
                except Exception as e:
                    # Best effort: keep loading remaining lines, but report why this one failed
                    # (a bare except here also swallowed KeyboardInterrupt/SystemExit).
                    print('load label failed on {}: {}'.format(label_path, e))
        data = {
            'text_polys': np.array(boxes),
            'texts': texts,
            'ignore_tags': ignores,
        }
        return data
class DetDataset(BaseDataSet):
    """Detection dataset backed by json annotation files (text lines plus optional characters)."""

    def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, ignore_tags, transform=None, **kwargs):
        # Required config switches; a KeyError here means the config is incomplete.
        self.load_char_annotation = kwargs['load_char_annotation']
        self.expand_one_char = kwargs['expand_one_char']
        super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform)

    def load_data(self, data_path: str) -> list:
        """Read line-level (and optionally character-level) polygons and texts from json files.

        :param data_path: list of json annotation file paths
        :return: list of per-image sample dicts
        """
        samples = []
        for json_path in data_path:
            content = load(json_path)
            for gt in tqdm(content['data_list'], desc='read file {}'.format(json_path)):
                img_path = os.path.join(content['data_root'], gt['img_name'])
                polygons, texts = [], []
                illegibility_list, language_list = [], []
                for annotation in gt['annotations']:
                    if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
                        continue
                    # NOTE(review): this expands boxes whose text is LONGER than one
                    # character even though the flag is named expand_one_char —
                    # confirm the intended condition against expand_polygon's contract.
                    if len(annotation['text']) > 1 and self.expand_one_char:
                        annotation['polygon'] = expand_polygon(annotation['polygon'])
                    polygons.append(annotation['polygon'])
                    texts.append(annotation['text'])
                    illegibility_list.append(annotation['illegibility'])
                    language_list.append(annotation['language'])
                    if self.load_char_annotation:
                        for char_annotation in annotation['chars']:
                            if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
                                continue
                            polygons.append(char_annotation['polygon'])
                            texts.append(char_annotation['char'])
                            illegibility_list.append(char_annotation['illegibility'])
                            language_list.append(char_annotation['language'])
                samples.append({'img_path': img_path, 'img_name': gt['img_name'],
                                'text_polys': np.array(polygons),
                                'texts': texts, 'ignore_tags': illegibility_list})
        return samples
class SynthTextDataset(BaseDataSet):
    """SynthText detection dataset loaded from the gt.mat shipped alongside the images."""

    def __init__(self, data_path: str, img_mode, pre_processes, filter_keys, transform=None, **kwargs):
        self.transform = transform
        self.dataRoot = pathlib.Path(data_path)
        if not self.dataRoot.exists():
            raise FileNotFoundError('Dataset folder is not exist.')
        self.targetFilePath = self.dataRoot / 'gt.mat'
        if not self.targetFilePath.exists():
            # Bug fix: a *missing* file is FileNotFoundError; FileExistsError means the opposite.
            raise FileNotFoundError('Target file is not exist.')
        targets = {}
        sio.loadmat(self.targetFilePath, targets, squeeze_me=True, struct_as_record=False,
                    variable_names=['imnames', 'wordBB', 'txt'])
        self.imageNames = targets['imnames']
        self.wordBBoxes = targets['wordBB']
        self.transcripts = targets['txt']
        # Bug fix: the sibling dataset classes call BaseDataSet with ignore_tags before
        # transform; previously `transform` was passed in the ignore_tags position even
        # though load_data reads self.ignore_tags. Take ignore_tags from kwargs
        # (default: none) to keep the constructor signature backward-compatible.
        ignore_tags = kwargs.get('ignore_tags', [])
        super().__init__(data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform)

    def load_data(self, data_path: str) -> list:
        """Convert per-image word boxes and transcripts from gt.mat into sample dicts."""
        t_data_list = []
        for imageName, wordBBoxes, texts in zip(self.imageNames, self.wordBBoxes, self.transcripts):
            item = {}
            # A single-word image arrives as (2, 4); add the word axis for uniformity.
            wordBBoxes = np.expand_dims(wordBBoxes, axis=2) if (wordBBoxes.ndim == 2) else wordBBoxes
            _, _, numOfWords = wordBBoxes.shape
            text_polys = wordBBoxes.reshape([8, numOfWords], order='F').T  # num_words * 8
            text_polys = text_polys.reshape(numOfWords, 4, 2)  # num_of_words * 4 * 2
            transcripts = [word for line in texts for word in line.split()]
            # Skip images where box count and transcript count disagree.
            if numOfWords != len(transcripts):
                continue
            item['img_path'] = str(self.dataRoot / imageName)
            item['img_name'] = (self.dataRoot / imageName).stem
            item['text_polys'] = text_polys
            item['texts'] = transcripts
            item['ignore_tags'] = [x in self.ignore_tags for x in transcripts]
            t_data_list.append(item)
        return t_data_list
if __name__ == '__main__':
    import torch
    import anyconfig
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from utils import parse_config, show_img, plt, draw_bbox

    # Smoke-test: build the ICDAR2015 dataset from the icdar2015 config and
    # iterate the loader once. Visualisation code is intentionally disabled.
    config = anyconfig.load('config/icdar2015_resnet18_FPN_DBhead_polyLR.yaml')
    config = parse_config(config)
    dataset_args = config['dataset']['train']['dataset']['args']
    train_data = ICDAR2015Dataset(data_path=dataset_args.pop('data_path'),
                                  transform=transforms.ToTensor(), **dataset_args)
    train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=True, num_workers=0)
    for i, data in enumerate(tqdm(train_loader)):
        pass
# -*- coding: utf-8 -*-
# @Time : 2019/12/4 10:53
# @Author : zhoujun
from .iaa_augment import IaaAugment
from .augment import *
from .random_crop_data import EastRandomCropData,PSERandomCrop
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment