Commit ed43fc11 authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #2703 canceled with stages
name: DBNet
base: ['config/icdar2015.yaml']
arch:
type: Model
backbone:
type: resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: WarmupPolyLR
args:
warmup_epoch: 3
trainer:
seed: 2
epochs: 1200
log_iter: 10
show_images_iter: 50
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
visual_dl: false
amp:
scale_loss: 1024
amp_level: O2
custom_white_list: []
custom_black_list: ['exp', 'sigmoid', 'concat']
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.txt
img_mode: RGB
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.txt
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ICDARCollectFN
name: DBNet
base: ['config/icdar2015.yaml']
arch:
type: Model
backbone:
type: resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: StepLR
args:
step_size: 10
gama: 0.8
trainer:
seed: 2
epochs: 500
log_iter: 10
show_images_iter: 50
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
visual_dl: false
amp:
scale_loss: 1024
amp_level: O2
custom_white_list: []
custom_black_list: ['exp', 'sigmoid', 'concat']
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.txt
img_mode: RGB
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.txt
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ICDARCollectFN
name: DBNet
base: ['config/icdar2015.yaml']
arch:
type: Model
backbone:
type: resnet50
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
lr_scheduler:
type: Polynomial
args:
learning_rate: 0.001
warmup_epoch: 3
trainer:
seed: 2
epochs: 1200
log_iter: 10
show_images_iter: 50
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output/fp16_o2
visual_dl: false
amp:
scale_loss: 1024
amp_level: O2
custom_white_list: []
custom_black_list: ['exp', 'sigmoid', 'concat']
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.txt
img_mode: RGB
loader:
batch_size: 16
shuffle: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.txt
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ICDARCollectFN
name: DBNet
dataset:
train:
dataset:
type: DetDataset # 数据集类型
args:
data_path: # 一个存放 img_path \t gt_path的文件
- ''
pre_processes: # 数据的预处理过程,包含augment和标签制作
- type: IaaAugment # 使用imgaug进行变换
args:
- {'type':Fliplr, 'args':{'p':0.5}}
- {'type': Affine, 'args':{'rotate':[-10,10]}}
- {'type':Resize,'args':{'size':[0.5,3]}}
- type: EastRandomCropData
args:
size: [640,640]
max_tries: 50
keep_ratio: true
- type: MakeBorderMap
args:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- type: MakeShrinkMap
args:
shrink_ratio: 0.4
min_text_size: 8
transforms: # 对图片进行的变换方式
- type: ToTensor
args: {}
- type: Normalize
args:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
img_mode: RGB
load_char_annotation: false
expand_one_char: false
filter_keys: [img_path,img_name,text_polys,texts,ignore_tags,shape] # 返回数据之前,从数据字典里删除的key
ignore_tags: ['*', '###']
loader:
batch_size: 1
shuffle: true
num_workers: 0
collate_fn: ''
validate:
dataset:
type: DetDataset
args:
data_path:
- ''
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
transforms:
- type: ToTensor
args: {}
- type: Normalize
args:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
img_mode: RGB
load_char_annotation: false # 是否加载字符级标注
expand_one_char: false # 是否对只有一个字符的框进行宽度扩充,扩充后w = w+h
filter_keys: []
ignore_tags: ['*', '###']
loader:
batch_size: 1
shuffle: true
num_workers: 0
collate_fn: ICDARCollectFN
name: DBNet
base: ['config/open_dataset.yaml']
arch:
type: Model
backbone:
type: deformable_resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: WarmupPolyLR
args:
warmup_epoch: 3
trainer:
seed: 2
epochs: 1200
log_iter: 1
show_images_iter: 1
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
visual_dl: false
amp:
scale_loss: 1024
amp_level: O2
custom_white_list: []
custom_black_list: ['exp', 'sigmoid', 'concat']
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.json
img_mode: RGB
load_char_annotation: false
expand_one_char: false
loader:
batch_size: 2
shuffle: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.json
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
load_char_annotation: false
expand_one_char: false
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ICDARCollectFN
name: DBNet
base: ['config/open_dataset.yaml']
arch:
type: Model
backbone:
type: resnest50
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: WarmupPolyLR
args:
warmup_epoch: 3
trainer:
seed: 2
epochs: 1200
log_iter: 1
show_images_iter: 1
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
visual_dl: false
amp:
scale_loss: 1024
amp_level: O2
custom_white_list: []
custom_black_list: ['exp', 'sigmoid', 'concat']
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.json
img_mode: RGB
load_char_annotation: false
expand_one_char: false
loader:
batch_size: 2
shuffle: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.json
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
load_char_annotation: false
expand_one_char: false
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ICDARCollectFN
name: DBNet
base: ['config/open_dataset.yaml']
arch:
type: Model
backbone:
type: resnet18
pretrained: true
neck:
type: FPN
inner_channels: 256
head:
type: DBHead
out_channels: 2
k: 50
post_processing:
type: SegDetectorRepresenter
args:
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5 # from paper
metric:
type: QuadMetric
args:
is_output_polygon: false
loss:
type: DBLoss
alpha: 1
beta: 10
ohem_ratio: 3
optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0
amsgrad: true
lr_scheduler:
type: WarmupPolyLR
args:
warmup_epoch: 3
trainer:
seed: 2
epochs: 1200
log_iter: 1
show_images_iter: 1
resume_checkpoint: ''
finetune_checkpoint: ''
output_dir: output
visual_dl: false
amp:
scale_loss: 1024
amp_level: O2
custom_white_list: []
custom_black_list: ['exp', 'sigmoid', 'concat']
dataset:
train:
dataset:
args:
data_path:
- ./datasets/train.json
transforms: # 对图片进行的变换方式
- type: ToTensor
args: {}
- type: Normalize
args:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
img_mode: RGB
load_char_annotation: false
expand_one_char: false
loader:
batch_size: 2
shuffle: true
num_workers: 6
collate_fn: ''
validate:
dataset:
args:
data_path:
- ./datasets/test.json
pre_processes:
- type: ResizeShortSize
args:
short_size: 736
resize_text_polys: false
img_mode: RGB
load_char_annotation: false
expand_one_char: false
loader:
batch_size: 1
shuffle: true
num_workers: 6
collate_fn: ICDARCollectFN
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:52
# @Author : zhoujun
import copy
import PIL
import numpy as np
import paddle
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
from paddle.vision import transforms
def get_dataset(data_path, module_name, transform, dataset_args):
"""
获取训练dataset
:param data_path: dataset文件列表,每个文件内以如下格式存储 ‘path/to/img\tlabel’
:param module_name: 所使用的自定义dataset名称,目前只支持data_loaders.ImageDataset
:param transform: 该数据集使用的transforms
:param dataset_args: module_name的参数
:return: 如果data_path列表不为空,返回对于的ConcatDataset对象,否则None
"""
from . import dataset
s_dataset = getattr(dataset, module_name)(
transform=transform, data_path=data_path, **dataset_args
)
return s_dataset
def get_transforms(transforms_config):
tr_list = []
for item in transforms_config:
if "args" not in item:
args = {}
else:
args = item["args"]
cls = getattr(transforms, item["type"])(**args)
tr_list.append(cls)
tr_list = transforms.Compose(tr_list)
return tr_list
class ICDARCollectFN:
def __init__(self, *args, **kwargs):
pass
def __call__(self, batch):
data_dict = {}
to_tensor_keys = []
for sample in batch:
for k, v in sample.items():
if k not in data_dict:
data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, PIL.Image.Image)):
if k not in to_tensor_keys:
to_tensor_keys.append(k)
data_dict[k].append(v)
for k in to_tensor_keys:
data_dict[k] = paddle.stack(data_dict[k], 0)
return data_dict
def get_dataloader(module_config, distributed=False):
if module_config is None:
return None
config = copy.deepcopy(module_config)
dataset_args = config["dataset"]["args"]
if "transforms" in dataset_args:
img_transforms = get_transforms(dataset_args.pop("transforms"))
else:
img_transforms = None
# 创建数据集
dataset_name = config["dataset"]["type"]
data_path = dataset_args.pop("data_path")
if data_path == None:
return None
data_path = [x for x in data_path if x is not None]
if len(data_path) == 0:
return None
if (
"collate_fn" not in config["loader"]
or config["loader"]["collate_fn"] is None
or len(config["loader"]["collate_fn"]) == 0
):
config["loader"]["collate_fn"] = None
else:
config["loader"]["collate_fn"] = eval(config["loader"]["collate_fn"])()
_dataset = get_dataset(
data_path=data_path,
module_name=dataset_name,
transform=img_transforms,
dataset_args=dataset_args,
)
sampler = None
if distributed:
# 3)使用DistributedSampler
batch_sampler = DistributedBatchSampler(
dataset=_dataset,
batch_size=config["loader"].pop("batch_size"),
shuffle=config["loader"].pop("shuffle"),
)
else:
batch_sampler = BatchSampler(
dataset=_dataset,
batch_size=config["loader"].pop("batch_size"),
shuffle=config["loader"].pop("shuffle"),
)
loader = DataLoader(
dataset=_dataset, batch_sampler=batch_sampler, **config["loader"]
)
return loader
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:54
# @Author : zhoujun
import pathlib
import os
import cv2
import numpy as np
import scipy.io as sio
from tqdm.auto import tqdm
from base import BaseDataSet
from utils import order_points_clockwise, get_datalist, load, expand_polygon
class ICDAR2015Dataset(BaseDataSet):
def __init__(
self,
data_path: str,
img_mode,
pre_processes,
filter_keys,
ignore_tags,
transform=None,
**kwargs,
):
super().__init__(
data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform
)
def load_data(self, data_path: str) -> list:
data_list = get_datalist(data_path)
t_data_list = []
for img_path, label_path in data_list:
data = self._get_annotation(label_path)
if len(data["text_polys"]) > 0:
item = {"img_path": img_path, "img_name": pathlib.Path(img_path).stem}
item.update(data)
t_data_list.append(item)
else:
print("there is no suit bbox in {}".format(label_path))
return t_data_list
def _get_annotation(self, label_path: str) -> dict:
boxes = []
texts = []
ignores = []
with open(label_path, encoding="utf-8", mode="r") as f:
for line in f.readlines():
params = line.strip().strip("\ufeff").strip("\xef\xbb\xbf").split(",")
try:
box = order_points_clockwise(
np.array(list(map(float, params[:8]))).reshape(-1, 2)
)
if cv2.contourArea(box) > 0:
boxes.append(box)
label = params[8]
texts.append(label)
ignores.append(label in self.ignore_tags)
except:
print("load label failed on {}".format(label_path))
data = {
"text_polys": np.array(boxes),
"texts": texts,
"ignore_tags": ignores,
}
return data
class DetDataset(BaseDataSet):
def __init__(
self,
data_path: str,
img_mode,
pre_processes,
filter_keys,
ignore_tags,
transform=None,
**kwargs,
):
self.load_char_annotation = kwargs["load_char_annotation"]
self.expand_one_char = kwargs["expand_one_char"]
super().__init__(
data_path, img_mode, pre_processes, filter_keys, ignore_tags, transform
)
def load_data(self, data_path: str) -> list:
"""
从json文件中读取出 文本行的坐标和gt,字符的坐标和gt
:param data_path:
:return:
"""
data_list = []
for path in data_path:
content = load(path)
for gt in tqdm(content["data_list"], desc="read file {}".format(path)):
img_path = os.path.join(content["data_root"], gt["img_name"])
polygons = []
texts = []
illegibility_list = []
language_list = []
for annotation in gt["annotations"]:
if len(annotation["polygon"]) == 0 or len(annotation["text"]) == 0:
continue
if len(annotation["text"]) > 1 and self.expand_one_char:
annotation["polygon"] = expand_polygon(annotation["polygon"])
polygons.append(annotation["polygon"])
texts.append(annotation["text"])
illegibility_list.append(annotation["illegibility"])
language_list.append(annotation["language"])
if self.load_char_annotation:
for char_annotation in annotation["chars"]:
if (
len(char_annotation["polygon"]) == 0
or len(char_annotation["char"]) == 0
):
continue
polygons.append(char_annotation["polygon"])
texts.append(char_annotation["char"])
illegibility_list.append(char_annotation["illegibility"])
language_list.append(char_annotation["language"])
data_list.append(
{
"img_path": img_path,
"img_name": gt["img_name"],
"text_polys": np.array(polygons),
"texts": texts,
"ignore_tags": illegibility_list,
}
)
return data_list
class SynthTextDataset(BaseDataSet):
def __init__(
self,
data_path: str,
img_mode,
pre_processes,
filter_keys,
transform=None,
**kwargs,
):
self.transform = transform
self.dataRoot = pathlib.Path(data_path)
if not self.dataRoot.exists():
raise FileNotFoundError("Dataset folder is not exist.")
self.targetFilePath = self.dataRoot / "gt.mat"
if not self.targetFilePath.exists():
raise FileExistsError("Target file is not exist.")
targets = {}
sio.loadmat(
self.targetFilePath,
targets,
squeeze_me=True,
struct_as_record=False,
variable_names=["imnames", "wordBB", "txt"],
)
self.imageNames = targets["imnames"]
self.wordBBoxes = targets["wordBB"]
self.transcripts = targets["txt"]
super().__init__(data_path, img_mode, pre_processes, filter_keys, transform)
def load_data(self, data_path: str) -> list:
t_data_list = []
for imageName, wordBBoxes, texts in zip(
self.imageNames, self.wordBBoxes, self.transcripts
):
item = {}
wordBBoxes = (
np.expand_dims(wordBBoxes, axis=2)
if (wordBBoxes.ndim == 2)
else wordBBoxes
)
_, _, numOfWords = wordBBoxes.shape
text_polys = wordBBoxes.reshape(
[8, numOfWords], order="F"
).T # num_words * 8
text_polys = text_polys.reshape(numOfWords, 4, 2) # num_of_words * 4 * 2
transcripts = [word for line in texts for word in line.split()]
if numOfWords != len(transcripts):
continue
item["img_path"] = str(self.dataRoot / imageName)
item["img_name"] = (self.dataRoot / imageName).stem
item["text_polys"] = text_polys
item["texts"] = transcripts
item["ignore_tags"] = [x in self.ignore_tags for x in transcripts]
t_data_list.append(item)
return t_data_list
# -*- coding: utf-8 -*-
# @Time : 2019/12/4 10:53
# @Author : zhoujun
from .iaa_augment import IaaAugment
from .augment import *
from .random_crop_data import EastRandomCropData, PSERandomCrop
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:52
# @Author : zhoujun
import math
import numbers
import random
import cv2
import numpy as np
from skimage.util import random_noise
class RandomNoise:
def __init__(self, random_rate):
self.random_rate = random_rate
def __call__(self, data: dict):
"""
对图片加噪声
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
if random.random() > self.random_rate:
return data
data["img"] = (
random_noise(data["img"], mode="gaussian", clip=True) * 255
).astype(data["img"].dtype)
return data
class RandomScale:
def __init__(self, scales, random_rate):
"""
:param scales: 尺度
:param random_rate: 随机系数
:return:
"""
self.random_rate = random_rate
self.scales = scales
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
if random.random() > self.random_rate:
return data
im = data["img"]
text_polys = data["text_polys"]
tmp_text_polys = text_polys.copy()
rd_scale = float(np.random.choice(self.scales))
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
tmp_text_polys *= rd_scale
data["img"] = im
data["text_polys"] = tmp_text_polys
return data
class RandomRotateImgBox:
def __init__(self, degrees, random_rate, same_size=False):
"""
:param degrees: 角度,可以是一个数值或者list
:param random_rate: 随机系数
:param same_size: 是否保持和原图一样大
:return:
"""
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
degrees = (-degrees, degrees)
elif (
isinstance(degrees, list)
or isinstance(degrees, tuple)
or isinstance(degrees, np.ndarray)
):
if len(degrees) != 2:
raise ValueError("If degrees is a sequence, it must be of len 2.")
degrees = degrees
else:
raise Exception("degrees must in Number or list or tuple or np.ndarray")
self.degrees = degrees
self.same_size = same_size
self.random_rate = random_rate
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
if random.random() > self.random_rate:
return data
im = data["img"]
text_polys = data["text_polys"]
# ---------------------- 旋转图像 ----------------------
w = im.shape[1]
h = im.shape[0]
angle = np.random.uniform(self.degrees[0], self.degrees[1])
if self.same_size:
nw = w
nh = h
else:
# 角度变弧度
rangle = np.deg2rad(angle)
# 计算旋转之后图像的w, h
nw = abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)
nh = abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)
# 构造仿射矩阵
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
# 计算原图中心点到新图中心点的偏移量
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
# 更新仿射矩阵
rot_mat[0, 2] += rot_move[0]
rot_mat[1, 2] += rot_move[1]
# 仿射变换
rot_img = cv2.warpAffine(
im,
rot_mat,
(int(math.ceil(nw)), int(math.ceil(nh))),
flags=cv2.INTER_LANCZOS4,
)
# ---------------------- 矫正bbox坐标 ----------------------
# rot_mat是最终的旋转矩阵
# 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下
rot_text_polys = list()
for bbox in text_polys:
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
rot_text_polys.append([point1, point2, point3, point4])
data["img"] = rot_img
data["text_polys"] = np.array(rot_text_polys)
return data
class RandomResize:
def __init__(self, size, random_rate, keep_ratio=False):
"""
:param input_size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h]
:param random_rate: 随机系数
:param keep_ratio: 是否保持长宽比
:return:
"""
if isinstance(size, numbers.Number):
if size < 0:
raise ValueError(
"If input_size is a single number, it must be positive."
)
size = (size, size)
elif (
isinstance(size, list)
or isinstance(size, tuple)
or isinstance(size, np.ndarray)
):
if len(size) != 2:
raise ValueError("If input_size is a sequence, it must be of len 2.")
size = (size[0], size[1])
else:
raise Exception("input_size must in Number or list or tuple or np.ndarray")
self.size = size
self.keep_ratio = keep_ratio
self.random_rate = random_rate
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
if random.random() > self.random_rate:
return data
im = data["img"]
text_polys = data["text_polys"]
if self.keep_ratio:
# 将图片短边pad到和长边一样
h, w, c = im.shape
max_h = max(h, self.size[0])
max_w = max(w, self.size[1])
im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
im_padded[:h, :w] = im.copy()
im = im_padded
text_polys = text_polys.astype(np.float32)
h, w, _ = im.shape
im = cv2.resize(im, self.size)
w_scale = self.size[0] / float(w)
h_scale = self.size[1] / float(h)
text_polys[:, :, 0] *= w_scale
text_polys[:, :, 1] *= h_scale
data["img"] = im
data["text_polys"] = text_polys
return data
def resize_image(img, short_size):
height, width, _ = img.shape
if height < width:
new_height = short_size
new_width = new_height / height * width
else:
new_width = short_size
new_height = new_width / width * height
new_height = int(round(new_height / 32) * 32)
new_width = int(round(new_width / 32) * 32)
resized_img = cv2.resize(img, (new_width, new_height))
return resized_img, (new_width / width, new_height / height)
class ResizeShortSize:
def __init__(self, short_size, resize_text_polys=True):
"""
:param size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h]
:return:
"""
self.short_size = short_size
self.resize_text_polys = resize_text_polys
def __call__(self, data: dict) -> dict:
"""
对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
im = data["img"]
text_polys = data["text_polys"]
h, w, _ = im.shape
short_edge = min(h, w)
if short_edge < self.short_size:
# 保证短边 >= short_size
scale = self.short_size / short_edge
im = cv2.resize(im, dsize=None, fx=scale, fy=scale)
scale = (scale, scale)
# im, scale = resize_image(im, self.short_size)
if self.resize_text_polys:
# text_polys *= scale
text_polys[:, 0] *= scale[0]
text_polys[:, 1] *= scale[1]
data["img"] = im
data["text_polys"] = text_polys
return data
class HorizontalFlip:
def __init__(self, random_rate):
"""
:param random_rate: 随机系数
"""
self.random_rate = random_rate
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
if random.random() > self.random_rate:
return data
im = data["img"]
text_polys = data["text_polys"]
flip_text_polys = text_polys.copy()
flip_im = cv2.flip(im, 1)
h, w, _ = flip_im.shape
flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
data["img"] = flip_im
data["text_polys"] = flip_text_polys
return data
class VerticalFlip:
def __init__(self, random_rate):
"""
:param random_rate: 随机系数
"""
self.random_rate = random_rate
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
if random.random() > self.random_rate:
return data
im = data["img"]
text_polys = data["text_polys"]
flip_text_polys = text_polys.copy()
flip_im = cv2.flip(im, 0)
h, w, _ = flip_im.shape
flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
data["img"] = flip_im
data["text_polys"] = flip_text_polys
return data
# -*- coding: utf-8 -*-
# @Time : 2019/12/4 18:06
# @Author : zhoujun
import numpy as np
import imgaug
import imgaug.augmenters as iaa
class AugmenterBuilder(object):
def __init__(self):
pass
def build(self, args, root=True):
if args is None or len(args) == 0:
return None
elif isinstance(args, list):
if root:
sequence = [self.build(value, root=False) for value in args]
return iaa.Sequential(sequence)
else:
return getattr(iaa, args[0])(
*[self.to_tuple_if_list(a) for a in args[1:]]
)
elif isinstance(args, dict):
cls = getattr(iaa, args["type"])
return cls(**{k: self.to_tuple_if_list(v) for k, v in args["args"].items()})
else:
raise RuntimeError("unknown augmenter arg: " + str(args))
def to_tuple_if_list(self, obj):
if isinstance(obj, list):
return tuple(obj)
return obj
class IaaAugment:
def __init__(self, augmenter_args):
self.augmenter_args = augmenter_args
self.augmenter = AugmenterBuilder().build(self.augmenter_args)
def __call__(self, data):
image = data["img"]
shape = image.shape
if self.augmenter:
aug = self.augmenter.to_deterministic()
data["img"] = aug.augment_image(image)
data = self.may_augment_annotation(aug, data, shape)
return data
def may_augment_annotation(self, aug, data, shape):
if aug is None:
return data
line_polys = []
for poly in data["text_polys"]:
new_poly = self.may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
data["text_polys"] = np.array(line_polys)
return data
def may_augment_poly(self, aug, img_shape, poly):
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
[imgaug.KeypointsOnImage(keypoints, shape=img_shape)]
)[0].keypoints
poly = [(p.x, p.y) for p in keypoints]
return poly
import cv2
import numpy as np
np.seterr(divide="ignore", invalid="ignore")
import pyclipper
from shapely.geometry import Polygon
class MakeBorderMap:
def __init__(self, shrink_ratio=0.4, thresh_min=0.3, thresh_max=0.7):
self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min
self.thresh_max = thresh_max
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
im = data["img"]
text_polys = data["text_polys"]
ignore_tags = data["ignore_tags"]
canvas = np.zeros(im.shape[:2], dtype=np.float32)
mask = np.zeros(im.shape[:2], dtype=np.float32)
for i in range(len(text_polys)):
if ignore_tags[i]:
continue
self.draw_border_map(text_polys[i], canvas, mask=mask)
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
data["threshold_map"] = canvas
data["threshold_mask"] = mask
return data
def draw_border_map(self, polygon, canvas, mask):
polygon = np.array(polygon)
assert polygon.ndim == 2
assert polygon.shape[1] == 2
polygon_shape = Polygon(polygon)
if polygon_shape.area <= 0:
return
distance = (
polygon_shape.area
* (1 - np.power(self.shrink_ratio, 2))
/ polygon_shape.length
)
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
padded_polygon = np.array(padding.Execute(distance)[0])
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
xmin = padded_polygon[:, 0].min()
xmax = padded_polygon[:, 0].max()
ymin = padded_polygon[:, 1].min()
ymax = padded_polygon[:, 1].max()
width = xmax - xmin + 1
height = ymax - ymin + 1
polygon[:, 0] = polygon[:, 0] - xmin
polygon[:, 1] = polygon[:, 1] - ymin
xs = np.broadcast_to(
np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)
)
ys = np.broadcast_to(
np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)
)
distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
distance_map = distance_map.min(axis=0)
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1] = np.fmax(
1
- distance_map[
ymin_valid - ymin : ymax_valid - ymax + height,
xmin_valid - xmin : xmax_valid - xmax + width,
],
canvas[ymin_valid : ymax_valid + 1, xmin_valid : xmax_valid + 1],
)
def distance(self, xs, ys, point_1, point_2):
"""
compute the distance from point to a line
ys: coordinates in the first axis
xs: coordinates in the second axis
point_1, point_2: (x, y), the end of the line
"""
height, width = xs.shape[:2]
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
point_1[1] - point_2[1]
)
cosin = (square_distance - square_distance_1 - square_distance_2) / (
2 * np.sqrt(square_distance_1 * square_distance_2)
)
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
result = np.sqrt(
square_distance_1 * square_distance_2 * square_sin / square_distance
)
result[cosin < 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[
cosin < 0
]
# self.extend_line(point_1, point_2, result)
return result
def extend_line(self, point_1, point_2, result):
ex_point_1 = (
int(
round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))
),
int(
round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio))
),
)
cv2.line(
result,
tuple(ex_point_1),
tuple(point_1),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0,
)
ex_point_2 = (
int(
round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))
),
int(
round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio))
),
)
cv2.line(
result,
tuple(ex_point_2),
tuple(point_2),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0,
)
return ex_point_1, ex_point_2
import numpy as np
import cv2
def shrink_polygon_py(polygon, shrink_ratio):
"""
对框进行缩放,返回去的比例为1/shrink_ratio 即可
"""
cx = polygon[:, 0].mean()
cy = polygon[:, 1].mean()
polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
return polygon
def shrink_polygon_pyclipper(polygon, shrink_ratio):
from shapely.geometry import Polygon
import pyclipper
polygon_shape = Polygon(polygon)
distance = (
polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
)
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunk = padding.Execute(-distance)
if shrunk == []:
shrunk = np.array(shrunk)
else:
shrunk = np.array(shrunk[0]).reshape(-1, 2)
return shrunk
class MakeShrinkMap:
r"""
Making binary mask from detection data with ICDAR format.
Typically following the process of class `MakeICDARData`.
"""
def __init__(self, min_text_size=8, shrink_ratio=0.4, shrink_type="pyclipper"):
shrink_func_dict = {
"py": shrink_polygon_py,
"pyclipper": shrink_polygon_pyclipper,
}
self.shrink_func = shrink_func_dict[shrink_type]
self.min_text_size = min_text_size
self.shrink_ratio = shrink_ratio
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
image = data["img"]
text_polys = data["text_polys"]
ignore_tags = data["ignore_tags"]
h, w = image.shape[:2]
text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)):
polygon = text_polys[i]
height = max(polygon[:, 1]) - min(polygon[:, 1])
width = max(polygon[:, 0]) - min(polygon[:, 0])
if ignore_tags[i] or min(height, width) < self.min_text_size:
cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
else:
shrunk = self.shrink_func(polygon, self.shrink_ratio)
if shrunk.size == 0:
cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
continue
cv2.fillPoly(gt, [shrunk.astype(np.int32)], 1)
data["shrink_map"] = gt
data["shrink_mask"] = mask
return data
def validate_polygons(self, polygons, ignore_tags, h, w):
"""
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
"""
if len(polygons) == 0:
return polygons, ignore_tags
assert len(polygons) == len(ignore_tags)
for polygon in polygons:
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
for i in range(len(polygons)):
area = self.polygon_area(polygons[i])
if abs(area) < 1:
ignore_tags[i] = True
if area > 0:
polygons[i] = polygons[i][::-1, :]
return polygons, ignore_tags
def polygon_area(self, polygon):
return cv2.contourArea(polygon)
# edge = 0
# for i in range(polygon.shape[0]):
# next_index = (i + 1) % polygon.shape[0]
# edge += (polygon[next_index, 0] - polygon[i, 0]) * (polygon[next_index, 1] - polygon[i, 1])
#
# return edge / 2.
if __name__ == "__main__":
from shapely.geometry import Polygon
import pyclipper
polygon = np.array([[0, 0], [100, 10], [100, 100], [10, 90]])
a = shrink_polygon_py(polygon, 0.4)
print(a)
print(shrink_polygon_py(a, 1 / 0.4))
b = shrink_polygon_pyclipper(polygon, 0.4)
print(b)
poly = Polygon(b)
distance = poly.area * 1.5 / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(b, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
bounding_box = cv2.minAreaRect(expanded)
points = cv2.boxPoints(bounding_box)
print(points)
import random
import cv2
import numpy as np
# random crop algorithm similar to https://github.com/argman/EAST
class EastRandomCropData:
def __init__(
self,
size=(640, 640),
max_tries=50,
min_crop_side_ratio=0.1,
require_original_image=False,
keep_ratio=True,
):
self.size = size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
self.require_original_image = require_original_image
self.keep_ratio = keep_ratio
def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'text_polys':,'texts':,'ignore_tags':}
:return:
"""
im = data["img"]
text_polys = data["text_polys"]
ignore_tags = data["ignore_tags"]
texts = data["texts"]
all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
# 计算crop区域
crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
# crop 图片 保持比例填充
scale_w = self.size[0] / crop_w
scale_h = self.size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
if self.keep_ratio:
if len(im.shape) == 3:
padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
else:
padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
padimg[:h, :w] = cv2.resize(
im[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], (w, h)
)
img = padimg
else:
img = cv2.resize(
im[crop_y : crop_y + crop_h, crop_x : crop_x + crop_w], tuple(self.size)
)
# crop 文本框
text_polys_crop = []
ignore_tags_crop = []
texts_crop = []
for poly, text, tag in zip(text_polys, texts, ignore_tags):
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
if not self.is_poly_outside_rect(poly, 0, 0, w, h):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
data["img"] = img
data["text_polys"] = np.float32(text_polys_crop)
data["ignore_tags"] = ignore_tags_crop
data["texts"] = texts_crop
return data
def is_poly_in_rect(self, poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
return False
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
return False
return True
def is_poly_outside_rect(self, poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
return True
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
return True
return False
def split_regions(self, axis):
regions = []
min_axis = 0
for i in range(1, axis.shape[0]):
if axis[i] != axis[i - 1] + 1:
region = axis[min_axis:i]
min_axis = i
regions.append(region)
return regions
def random_select(self, axis, max_size):
xx = np.random.choice(axis, size=2)
xmin = np.min(xx)
xmax = np.max(xx)
xmin = np.clip(xmin, 0, max_size - 1)
xmax = np.clip(xmax, 0, max_size - 1)
return xmin, xmax
def region_wise_random_select(self, regions, max_size):
selected_index = list(np.random.choice(len(regions), 2))
selected_values = []
for index in selected_index:
axis = regions[index]
xx = int(np.random.choice(axis, size=1))
selected_values.append(xx)
xmin = min(selected_values)
xmax = max(selected_values)
return xmin, xmax
def crop_area(self, im, text_polys):
h, w = im.shape[:2]
h_array = np.zeros(h, dtype=np.int32)
w_array = np.zeros(w, dtype=np.int32)
for points in text_polys:
points = np.round(points, decimals=0).astype(np.int32)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return 0, 0, w, h
h_regions = self.split_regions(h_axis)
w_regions = self.split_regions(w_axis)
for i in range(self.max_tries):
if len(w_regions) > 1:
xmin, xmax = self.region_wise_random_select(w_regions, w)
else:
xmin, xmax = self.random_select(w_axis, w)
if len(h_regions) > 1:
ymin, ymax = self.region_wise_random_select(h_regions, h)
else:
ymin, ymax = self.random_select(h_axis, h)
if (
xmax - xmin < self.min_crop_side_ratio * w
or ymax - ymin < self.min_crop_side_ratio * h
):
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
if not self.is_poly_outside_rect(
poly, xmin, ymin, xmax - xmin, ymax - ymin
):
num_poly_in_rect += 1
break
if num_poly_in_rect > 0:
return xmin, ymin, xmax - xmin, ymax - ymin
return 0, 0, w, h
class PSERandomCrop:
def __init__(self, size):
self.size = size
def __call__(self, data):
imgs = data["imgs"]
h, w = imgs[0].shape[0:2]
th, tw = self.size
if w == tw and h == th:
return imgs
# label中存在文本实例,并且按照概率进行裁剪,使用threshold_label_map控制
if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
# 文本实例的左上角点
tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
tl[tl < 0] = 0
# 文本实例的右下角点
br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
br[br < 0] = 0
# 保证选到右下角点时,有足够的距离进行crop
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)
for _ in range(50000):
i = random.randint(tl[0], br[0])
j = random.randint(tl[1], br[1])
# 保证shrink_label_map有文本
if imgs[1][i : i + th, j : j + tw].sum() <= 0:
continue
else:
break
else:
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
# return i, j, th, tw
for idx in range(len(imgs)):
if len(imgs[idx].shape) == 3:
imgs[idx] = imgs[idx][i : i + th, j : j + tw, :]
else:
imgs[idx] = imgs[idx][i : i + th, j : j + tw]
data["imgs"] = imgs
return data
name: dbnet
channels:
- conda-forge
- defaults
dependencies:
- anyconfig==0.9.10
- future==0.18.2
- imgaug==0.4.0
- matplotlib==3.1.2
- numpy==1.17.4
- opencv
- pyclipper
- PyYAML==5.2
- scikit-image==0.16.2
- Shapely==1.6.4
- tensorboard=2
- tqdm==4.40.1
- ipython
- pip
- pip:
- polygon3
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py --model_path ''
#Only use if your file names of the images and txts are identical
rm ./datasets/train_img.txt
rm ./datasets/train_gt.txt
rm ./datasets/test_img.txt
rm ./datasets/test_gt.txt
rm ./datasets/train.txt
rm ./datasets/test.txt
ls ./datasets/train/img/*.jpg > ./datasets/train_img.txt
ls ./datasets/train/gt/*.txt > ./datasets/train_gt.txt
ls ./datasets/test/img/*.jpg > ./datasets/test_img.txt
ls ./datasets/test/gt/*.txt > ./datasets/test_gt.txt
paste ./datasets/train_img.txt ./datasets/train_gt.txt > ./datasets/train.txt
paste ./datasets/test_img.txt ./datasets/test_gt.txt > ./datasets/test.txt
rm ./datasets/train_img.txt
rm ./datasets/train_gt.txt
rm ./datasets/test_img.txt
rm ./datasets/test_gt.txt
# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:55
# @Author : zhoujun
import copy
from .model import Model
from .losses import build_loss
__all__ = ["build_loss", "build_model"]
support_model = ["Model"]
def build_model(config):
"""
get architecture model class
"""
copy_config = copy.deepcopy(config)
arch_type = copy_config.pop("type")
assert (
arch_type in support_model
), f"{arch_type} is not developed yet!, only {support_model} are support now"
arch_model = eval(arch_type)(copy_config)
return arch_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment