Commit f1506916 authored by sugon_cxj's avatar sugon_cxj
Browse files

first commit

parent 55c28ed5
Pipeline #266 canceled with stages
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
import numpy as np
import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
from scipy.spatial import distance as dist
from ppocr.utils.logging import get_logger
class ClsLabelEncode(object):
def __init__(self, label_list, **kwargs):
self.label_list = label_list
def __call__(self, data):
label = data['label']
if label not in self.label_list:
return None
label = self.label_list.index(label)
data['label'] = label
return data
class DetLabelEncode(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
if len(boxes) == 0:
return None
boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['texts'] = txts
data['ignore_tags'] = txt_tags
return data
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def expand_points_num(self, boxes):
max_points_num = 0
for box in boxes:
if len(box) > max_points_num:
max_points_num = len(box)
ex_boxes = []
for box in boxes:
ex_box = box + [box[-1]] * (max_points_num - len(box))
ex_boxes.append(ex_box)
return ex_boxes
class BaseRecLabelEncode(object):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
self.lower = False
if character_dict_path is None:
logger = get_logger()
logger.warning(
"The character_dict_path is None, model can only recognize number and lower letters"
)
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
self.lower = True
else:
self.character_str = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def add_special_char(self, dict_character):
return dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if len(text) == 0 or len(text) > self.max_text_len:
return None
if self.lower:
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
# logger = get_logger()
# logger.warning('{} is not in dict'.format(char))
continue
text_list.append(self.dict[char])
if len(text_list) == 0:
return None
return text_list
class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(CTCLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
label = [0] * len(self.character)
for x in text:
label[x] += 1
data['label_ace'] = np.array(label)
return data
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
class SARLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(SARLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
beg_end_str = "<BOS/EOS>"
unknown_str = "<UKN>"
padding_str = "<PAD>"
dict_character = dict_character + [unknown_str]
self.unknown_idx = len(dict_character) - 1
dict_character = dict_character + [beg_end_str]
self.start_idx = len(dict_character) - 1
self.end_idx = len(dict_character) - 1
dict_character = dict_character + [padding_str]
self.padding_idx = len(dict_character) - 1
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len - 1:
return None
data['length'] = np.array(len(text))
target = [self.start_idx] + text + [self.end_idx]
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data
def get_ignored_tokens(self):
return [self.padding_idx]
class MultiLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
super(MultiLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
use_space_char, **kwargs)
def __call__(self, data):
data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data)
data_out = dict()
data_out['img_path'] = data.get('img_path', None)
data_out['image'] = data['image']
ctc = self.ctc_encode.__call__(data_ctc)
sar = self.sar_encode.__call__(data_sar)
if ctc is None or sar is None:
return None
data_out['label_ctc'] = ctc['label']
data_out['label_sar'] = sar['label']
data_out['length'] = ctc['length']
return data_out
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
np.seterr(divide='ignore', invalid='ignore')
import pyclipper
from shapely.geometry import Polygon
import sys
import warnings
warnings.simplefilter("ignore")
__all__ = ['MakeBorderMap']
class MakeBorderMap(object):
def __init__(self,
shrink_ratio=0.4,
thresh_min=0.3,
thresh_max=0.7,
**kwargs):
self.shrink_ratio = shrink_ratio
self.thresh_min = thresh_min
self.thresh_max = thresh_max
def __call__(self, data):
img = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
canvas = np.zeros(img.shape[:2], dtype=np.float32)
mask = np.zeros(img.shape[:2], dtype=np.float32)
for i in range(len(text_polys)):
if ignore_tags[i]:
continue
self.draw_border_map(text_polys[i], canvas, mask=mask)
canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
data['threshold_map'] = canvas
data['threshold_mask'] = mask
return data
def draw_border_map(self, polygon, canvas, mask):
polygon = np.array(polygon)
assert polygon.ndim == 2
assert polygon.shape[1] == 2
polygon_shape = Polygon(polygon)
if polygon_shape.area <= 0:
return
distance = polygon_shape.area * (
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
padded_polygon = np.array(padding.Execute(distance)[0])
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
xmin = padded_polygon[:, 0].min()
xmax = padded_polygon[:, 0].max()
ymin = padded_polygon[:, 1].min()
ymax = padded_polygon[:, 1].max()
width = xmax - xmin + 1
height = ymax - ymin + 1
polygon[:, 0] = polygon[:, 0] - xmin
polygon[:, 1] = polygon[:, 1] - ymin
xs = np.broadcast_to(
np.linspace(
0, width - 1, num=width).reshape(1, width), (height, width))
ys = np.broadcast_to(
np.linspace(
0, height - 1, num=height).reshape(height, 1), (height, width))
distance_map = np.zeros(
(polygon.shape[0], height, width), dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
distance_map = distance_map.min(axis=0)
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
xmin_valid - xmin:xmax_valid - xmax + width],
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
def _distance(self, xs, ys, point_1, point_2):
'''
compute the distance from point to a line
ys: coordinates in the first axis
xs: coordinates in the second axis
point_1, point_2: (x, y), the end of the line
'''
height, width = xs.shape[:2]
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
1])
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(
point_1[1] - point_2[1])
cosin = (square_distance - square_distance_1 - square_distance_2) / (
2 * np.sqrt(square_distance_1 * square_distance_2))
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
square_distance)
result[cosin <
0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
< 0]
# self.extend_line(point_1, point_2, result)
return result
def extend_line(self, point_1, point_2, result, shrink_ratio):
ex_point_1 = (int(
round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
int(
round(point_1[1] + (point_1[1] - point_2[1]) * (
1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_1),
tuple(point_1),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
ex_point_2 = (int(
round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
int(
round(point_2[1] + (point_2[1] - point_1[1]) * (
1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_2),
tuple(point_2),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
return ex_point_1, ex_point_2
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon
__all__ = ['MakePseGt']
class MakePseGt(object):
def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
self.kernel_num = kernel_num
self.min_shrink_ratio = min_shrink_ratio
self.size = size
def __call__(self, data):
image = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
h, w, _ = image.shape
short_edge = min(h, w)
if short_edge < self.size:
# keep short_size >= self.size
scale = self.size / short_edge
image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
text_polys *= scale
gt_kernels = []
for i in range(1, self.kernel_num + 1):
# s1->sn, from big to small
rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
) * i
text_kernel, ignore_tags = self.generate_kernel(
image.shape[0:2], rate, text_polys, ignore_tags)
gt_kernels.append(text_kernel)
training_mask = np.ones(image.shape[0:2], dtype='uint8')
for i in range(text_polys.shape[0]):
if ignore_tags[i]:
cv2.fillPoly(training_mask,
text_polys[i].astype(np.int32)[np.newaxis, :, :],
0)
gt_kernels = np.array(gt_kernels)
gt_kernels[gt_kernels > 0] = 1
data['image'] = image
data['polys'] = text_polys
data['gt_kernels'] = gt_kernels[0:]
data['gt_text'] = gt_kernels[0]
data['mask'] = training_mask.astype('float32')
return data
def generate_kernel(self,
img_size,
shrink_ratio,
text_polys,
ignore_tags=None):
"""
Refer to part of the code:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/base_textdet_targets.py
"""
h, w = img_size
text_kernel = np.zeros((h, w), dtype=np.float32)
for i, poly in enumerate(text_polys):
polygon = Polygon(poly)
distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
polygon.length + 1e-6)
subject = [tuple(l) for l in poly]
pco = pyclipper.PyclipperOffset()
pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrinked = np.array(pco.Execute(-distance))
if len(shrinked) == 0 or shrinked.size == 0:
if ignore_tags is not None:
ignore_tags[i] = True
continue
try:
shrinked = np.array(shrinked[0]).reshape(-1, 2)
except:
if ignore_tags is not None:
ignore_tags[i] = True
continue
cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
return text_kernel, ignore_tags
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
__all__ = ['MakeShrinkMap']
class MakeShrinkMap(object):
r'''
Making binary mask from detection data with ICDAR format.
Typically following the process of class `MakeICDARData`.
'''
def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
self.min_text_size = min_text_size
self.shrink_ratio = shrink_ratio
def __call__(self, data):
image = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
h, w = image.shape[:2]
text_polys, ignore_tags = self.validate_polygons(text_polys,
ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)):
polygon = text_polys[i]
height = max(polygon[:, 1]) - min(polygon[:, 1])
width = max(polygon[:, 0]) - min(polygon[:, 0])
if ignore_tags[i] or min(height, width) < self.min_text_size:
cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
else:
polygon_shape = Polygon(polygon)
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON)
shrinked = []
# Increase the shrink ratio every time we get multiple polygon returned back
possible_ratios = np.arange(self.shrink_ratio, 1,
self.shrink_ratio)
np.append(possible_ratios, 1)
# print(possible_ratios)
for ratio in possible_ratios:
# print(f"Change shrink ratio to {ratio}")
distance = polygon_shape.area * (
1 - np.power(ratio, 2)) / polygon_shape.length
shrinked = padding.Execute(-distance)
if len(shrinked) == 1:
break
if shrinked == []:
cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
continue
for each_shirnk in shrinked:
shirnk = np.array(each_shirnk).reshape(-1, 2)
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
data['shrink_map'] = gt
data['shrink_mask'] = mask
return data
def validate_polygons(self, polygons, ignore_tags, h, w):
'''
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
'''
if len(polygons) == 0:
return polygons, ignore_tags
assert len(polygons) == len(ignore_tags)
for polygon in polygons:
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
for i in range(len(polygons)):
area = self.polygon_area(polygons[i])
if abs(area) < 1:
ignore_tags[i] = True
if area > 0:
polygons[i] = polygons[i][::-1, :]
return polygons, ignore_tags
def polygon_area(self, polygon):
"""
compute polygon area
"""
area = 0
q = polygon[-1]
for p in polygon:
area += p[0] * q[1] - p[1] * q[0]
q = p
return area / 2.0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
import math
class DecodeImage(object):
""" decode image """
def __init__(self,
img_mode='RGB',
channel_first=False,
ignore_orientation=False,
**kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
if self.ignore_orientation:
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class DetResizeForSingle(object):
def __init__(self, **kwargs):
super(DetResizeForSingle, self).__init__()
def __call__(self, data):
# print("DetResizeForSingle")
img = data['image']
src_h, src_w, _ = img.shape
img, [ratio_h, ratio_w], [resize_h, resize_w] = self.resize_image(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w, resize_h, resize_w])
return data
def resize_image(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
h, w, c = img.shape
if h > w:
resize_h = 1280
ratio = float(1280) / h
resize_w = int(w * ratio)
else:
resize_w = 1280
ratio = float(1280) / w
resize_h = int(h * ratio)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
img_pd_h = 1280
img_pd_w = 1280
padding_im = np.zeros((img_pd_h, img_pd_w, 3), dtype=np.uint8)
top = int((img_pd_h - resize_h) / 2)
left = int((img_pd_w -resize_w) / 2)
padding_im[top:top + int(resize_h), left:left + int(resize_w), :] = img
ratio_h = img_pd_h / float(h)
ratio_w = img_pd_w / float(w)
return padding_im, [ratio_h, ratio_w], [resize_h, resize_w]
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
import six
class RawRandAugment(object):
def __init__(self,
num_layers=2,
magnitude=5,
fillcolor=(128, 128, 128),
**kwargs):
self.num_layers = num_layers
self.magnitude = magnitude
self.max_level = 10
abso_level = self.magnitude / self.max_level
self.level_map = {
"shearX": 0.3 * abso_level,
"shearY": 0.3 * abso_level,
"translateX": 150.0 / 331 * abso_level,
"translateY": 150.0 / 331 * abso_level,
"rotate": 30 * abso_level,
"color": 0.9 * abso_level,
"posterize": int(4.0 * abso_level),
"solarize": 256.0 * abso_level,
"contrast": 0.9 * abso_level,
"sharpness": 0.9 * abso_level,
"brightness": 0.9 * abso_level,
"autocontrast": 0,
"equalize": 0,
"invert": 0
}
# from https://stackoverflow.com/questions/5252170/
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
Image.new("RGBA", rot.size, (128, ) * 4),
rot).convert(img.mode)
rnd_ch_op = random.choice
self.func = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"posterize": lambda img, magnitude:
ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude:
ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude:
ImageEnhance.Contrast(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"sharpness": lambda img, magnitude:
ImageEnhance.Sharpness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"brightness": lambda img, magnitude:
ImageEnhance.Brightness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"autocontrast": lambda img, magnitude:
ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
def __call__(self, img):
avaiable_op_names = list(self.level_map.keys())
for layer_num in range(self.num_layers):
op_name = np.random.choice(avaiable_op_names)
img = self.func[op_name](img, self.level_map[op_name])
return img
class RandAugment(RawRandAugment):
""" RandAugment wrapper to auto fit different img types """
def __init__(self, prob=0.5, *args, **kwargs):
self.prob = prob
if six.PY2:
super(RandAugment, self).__init__(*args, **kwargs)
else:
super().__init__(*args, **kwargs)
def __call__(self, data):
if np.random.rand() > self.prob:
return data
img = data['image']
if not isinstance(img, Image.Image):
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
if six.PY2:
img = super(RandAugment, self).__call__(img)
else:
img = super().__call__(img)
if isinstance(img, Image.Image):
img = np.asarray(img)
data['image'] = img
return data
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
import random
def is_poly_in_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
return False
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
return False
return True
def is_poly_outside_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
return True
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
return True
return False
def split_regions(axis):
regions = []
min_axis = 0
for i in range(1, axis.shape[0]):
if axis[i] != axis[i - 1] + 1:
region = axis[min_axis:i]
min_axis = i
regions.append(region)
return regions
def random_select(axis, max_size):
xx = np.random.choice(axis, size=2)
xmin = np.min(xx)
xmax = np.max(xx)
xmin = np.clip(xmin, 0, max_size - 1)
xmax = np.clip(xmax, 0, max_size - 1)
return xmin, xmax
def region_wise_random_select(regions, max_size):
selected_index = list(np.random.choice(len(regions), 2))
selected_values = []
for index in selected_index:
axis = regions[index]
xx = int(np.random.choice(axis, size=1))
selected_values.append(xx)
xmin = min(selected_values)
xmax = max(selected_values)
return xmin, xmax
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
h, w, _ = im.shape
h_array = np.zeros(h, dtype=np.int32)
w_array = np.zeros(w, dtype=np.int32)
for points in text_polys:
points = np.round(points, decimals=0).astype(np.int32)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return 0, 0, w, h
h_regions = split_regions(h_axis)
w_regions = split_regions(w_axis)
for i in range(max_tries):
if len(w_regions) > 1:
xmin, xmax = region_wise_random_select(w_regions, w)
else:
xmin, xmax = random_select(w_axis, w)
if len(h_regions) > 1:
ymin, ymax = region_wise_random_select(h_regions, h)
else:
ymin, ymax = random_select(h_axis, h)
if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
ymax - ymin):
num_poly_in_rect += 1
break
if num_poly_in_rect > 0:
return xmin, ymin, xmax - xmin, ymax - ymin
return 0, 0, w, h
class EastRandomCropData(object):
def __init__(self,
size=(640, 640),
max_tries=10,
min_crop_side_ratio=0.1,
keep_ratio=True,
**kwargs):
self.size = size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
self.keep_ratio = keep_ratio
def __call__(self, data):
img = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
texts = data['texts']
all_care_polys = [
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
]
# 计算crop区域
crop_x, crop_y, crop_w, crop_h = crop_area(
img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
# crop 图片 保持比例填充
scale_w = self.size[0] / crop_w
scale_h = self.size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
if self.keep_ratio:
padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
img.dtype)
padimg[:h, :w] = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg
else:
img = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
tuple(self.size))
# crop 文本框
text_polys_crop = []
ignore_tags_crop = []
texts_crop = []
for poly, text, tag in zip(text_polys, texts, ignore_tags):
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
if not is_poly_outside_rect(poly, 0, 0, w, h):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
data['image'] = img
data['polys'] = np.array(text_polys_crop)
data['ignore_tags'] = ignore_tags_crop
data['texts'] = texts_crop
return data
class RandomCropImgMask(object):
def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
self.size = size
self.main_key = main_key
self.crop_keys = crop_keys
self.p = p
def __call__(self, data):
image = data['image']
h, w = image.shape[0:2]
th, tw = self.size
if w == tw and h == th:
return data
mask = data[self.main_key]
if np.max(mask) > 0 and random.random() > self.p:
# make sure to crop the text region
tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
tl[tl < 0] = 0
br = np.max(np.where(mask > 0), axis=1) - (th, tw)
br[br < 0] = 0
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)
i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
else:
i = random.randint(0, h - th) if h - th > 0 else 0
j = random.randint(0, w - tw) if w - tw > 0 else 0
# return i, j, th, tw
for k in data:
if k in self.crop_keys:
if len(data[k].shape) == 3:
if np.argmin(data[k].shape) == 0:
img = data[k][:, i:i + th, j:j + tw]
if img.shape[1] != img.shape[2]:
a = 1
elif np.argmin(data[k].shape) == 2:
img = data[k][i:i + th, j:j + tw, :]
if img.shape[1] != img.shape[0]:
a = 1
else:
img = data[k]
else:
img = data[k][i:i + th, j:j + tw]
if img.shape[0] != img.shape[1]:
a = 1
data[k] = img
return data
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import random
import copy
from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
class RecAug(object):
def __init__(self, use_tia=True, aug_prob=0.4, **kwargs):
self.use_tia = use_tia
self.aug_prob = aug_prob
def __call__(self, data):
img = data['image']
img = warp(img, 10, self.use_tia, self.aug_prob)
data['image'] = img
return data
class RecConAug(object):
def __init__(self,
prob=0.5,
image_shape=(32, 320, 3),
max_text_length=25,
ext_data_num=1,
**kwargs):
self.ext_data_num = ext_data_num
self.prob = prob
self.max_text_length = max_text_length
self.image_shape = image_shape
self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
def merge_ext_data(self, data, ext_data):
ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
self.image_shape[0])
ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
self.image_shape[0])
data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
ext_data['image'] = cv2.resize(ext_data['image'],
(ext_w, self.image_shape[0]))
data['image'] = np.concatenate(
[data['image'], ext_data['image']], axis=1)
data["label"] += ext_data["label"]
return data
def __call__(self, data):
rnd_num = random.random()
if rnd_num > self.prob:
return data
for idx, ext_data in enumerate(data["ext_data"]):
if len(data["label"]) + len(ext_data[
"label"]) > self.max_text_length:
break
concat_ratio = data['image'].shape[1] / data['image'].shape[
0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
if concat_ratio > self.max_wh_ratio:
break
data = self.merge_ext_data(data, ext_data)
data.pop("ext_data")
return data
class ClsResizeImg(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape
def __call__(self, data):
img = data['image']
norm_img, _ = resize_norm_img(img, self.image_shape)
data['image'] = norm_img
return data
class NRTRRecResizeImg(object):
def __init__(self, image_shape, resize_type, padding=False, **kwargs):
self.image_shape = image_shape
self.resize_type = resize_type
self.padding = padding
def __call__(self, data):
img = data['image']
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
image_shape = self.image_shape
if self.padding:
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
norm_img = np.expand_dims(resized_image, -1)
norm_img = norm_img.transpose((2, 0, 1))
resized_image = norm_img.astype(np.float32) / 128. - 1.
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
data['image'] = padding_im
return data
if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
img = np.array(img)
if self.resize_type == 'OpenCV':
img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data
class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
if self.infer_mode and self.character_dict_path is not None:
norm_img, valid_ratio = resize_norm_img_chinese(img,
self.image_shape)
else:
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
self.padding)
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
self.num_heads = num_heads
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_srn(img, self.image_shape)
data['image'] = norm_img
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
data['encoder_word_pos'] = encoder_word_pos
data['gsrm_word_pos'] = gsrm_word_pos
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
return data
class SARRecResizeImg(object):
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
def __call__(self, data):
img = data['image']
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
img, self.image_shape, self.width_downsample_ratio)
data['image'] = norm_img
data['resized_shape'] = resize_shape
data['pad_shape'] = pad_shape
data['valid_ratio'] = valid_ratio
return data
class PRENResizeImg(object):
def __init__(self, image_shape, **kwargs):
"""
Accroding to original paper's realization, it's a hard resize method here.
So maybe you should optimize it to fit for your task better.
"""
self.dst_h, self.dst_w = image_shape
def __call__(self, data):
img = data['image']
resized_img = cv2.resize(
img, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
resized_img = resized_img.transpose((2, 0, 1)) / 255
resized_img -= 0.5
resized_img /= 0.5
data['image'] = resized_img.astype(np.float32)
return data
class SVTRRecResizeImg(object):
def __init__(self, image_shape, padding=True, **kwargs):
self.image_shape = image_shape
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
self.padding)
data['image'] = norm_img
data['valid_ratio'] = valid_ratio
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img(img, image_shape, padding=True):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
return padding_im, valid_ratio
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(imgH * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
valid_ratio = min(1.0, float(resized_w / imgW))
return padding_im, valid_ratio
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
[num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
[num_heads, 1, 1]) * [-1e9]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def flag():
"""
flag
"""
return 1 if random.random() > 0.5000001 else -1
def cvtColor(img):
"""
cvtColor
"""
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
delta = 0.001 * random.random() * flag()
hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta)
new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
return new_img
def blur(img):
"""
blur
"""
h, w, _ = img.shape
if h > 10 and w > 10:
return cv2.GaussianBlur(img, (5, 5), 1)
else:
return img
def jitter(img):
"""
jitter
"""
w, h, _ = img.shape
if h > 10 and w > 10:
thres = min(w, h)
s = int(random.random() * thres * 0.01)
src_img = img.copy()
for i in range(s):
img[i:, i:, :] = src_img[:w - i, :h - i, :]
return img
else:
return img
def add_gasuss_noise(image, mean=0, var=0.1):
"""
Gasuss noise
"""
noise = np.random.normal(mean, var**0.5, image.shape)
out = image + 0.5 * noise
out = np.clip(out, 0, 255)
out = np.uint8(out)
return out
def get_crop(image):
"""
random crop
"""
h, w, _ = image.shape
top_min = 1
top_max = 8
top_crop = int(random.randint(top_min, top_max))
top_crop = min(top_crop, h - 1)
crop_img = image.copy()
ratio = random.randint(0, 1)
if ratio:
crop_img = crop_img[top_crop:h, :, :]
else:
crop_img = crop_img[0:h - top_crop, :, :]
return crop_img
class Config:
"""
Config
"""
def __init__(self, use_tia):
self.anglex = random.random() * 30
self.angley = random.random() * 15
self.anglez = random.random() * 10
self.fov = 42
self.r = 0
self.shearx = random.random() * 0.3
self.sheary = random.random() * 0.05
self.borderMode = cv2.BORDER_REPLICATE
self.use_tia = use_tia
def make(self, w, h, ang):
"""
make
"""
self.anglex = random.random() * 5 * flag()
self.angley = random.random() * 5 * flag()
self.anglez = -1 * random.random() * int(ang) * flag()
self.fov = 42
self.r = 0
self.shearx = 0
self.sheary = 0
self.borderMode = cv2.BORDER_REPLICATE
self.w = w
self.h = h
self.perspective = self.use_tia
self.stretch = self.use_tia
self.distort = self.use_tia
self.crop = True
self.affine = False
self.reverse = True
self.noise = True
self.jitter = True
self.blur = True
self.color = True
def rad(x):
"""
rad
"""
return x * np.pi / 180
def get_warpR(config):
"""
get_warpR
"""
anglex, angley, anglez, fov, w, h, r = \
config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r
if w > 69 and w < 112:
anglex = anglex * 1.5
z = np.sqrt(w**2 + h**2) / 2 / np.tan(rad(fov / 2))
# Homogeneous coordinate transformation matrix
rx = np.array([[1, 0, 0, 0],
[0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [
0,
-np.sin(rad(anglex)),
np.cos(rad(anglex)),
0,
], [0, 0, 0, 1]], np.float32)
ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0],
[0, 1, 0, 0], [
-np.sin(rad(angley)),
0,
np.cos(rad(angley)),
0,
], [0, 0, 0, 1]], np.float32)
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0],
[0, 0, 1, 0], [0, 0, 0, 1]], np.float32)
r = rx.dot(ry).dot(rz)
# generate 4 points
pcenter = np.array([h / 2, w / 2, 0, 0], np.float32)
p1 = np.array([0, 0, 0, 0], np.float32) - pcenter
p2 = np.array([w, 0, 0, 0], np.float32) - pcenter
p3 = np.array([0, h, 0, 0], np.float32) - pcenter
p4 = np.array([w, h, 0, 0], np.float32) - pcenter
dst1 = r.dot(p1)
dst2 = r.dot(p2)
dst3 = r.dot(p3)
dst4 = r.dot(p4)
list_dst = np.array([dst1, dst2, dst3, dst4])
org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
dst = np.zeros((4, 2), np.float32)
# Project onto the image plane
dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0]
dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1]
warpR = cv2.getPerspectiveTransform(org, dst)
dst1, dst2, dst3, dst4 = dst
r1 = int(min(dst1[1], dst2[1]))
r2 = int(max(dst3[1], dst4[1]))
c1 = int(min(dst1[0], dst3[0]))
c2 = int(max(dst2[0], dst4[0]))
try:
ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1))
dx = -c1
dy = -r1
T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]])
ret = T1.dot(warpR)
except:
ratio = 1.0
T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
ret = T1
return ret, (-r1, -c1), ratio, dst
def get_warpAffine(config):
"""
get_warpAffine
"""
anglez = config.anglez
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
return rz
def warp(img, ang, use_tia=True, prob=0.4):
"""
warp
"""
h, w, _ = img.shape
config = Config(use_tia=use_tia)
config.make(w, h, ang)
new_img = img
if config.distort:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_distort(new_img, random.randint(3, 6))
if config.stretch:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_stretch(new_img, random.randint(3, 6))
if config.perspective:
if random.random() <= prob:
new_img = tia_perspective(new_img)
if config.crop:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = get_crop(new_img)
if config.blur:
if random.random() <= prob:
new_img = blur(new_img)
if config.color:
if random.random() <= prob:
new_img = cvtColor(new_img)
if config.jitter:
new_img = jitter(new_img)
if config.noise:
if random.random() <= prob:
new_img = add_gasuss_noise(new_img)
if config.reverse:
if random.random() <= prob:
new_img = 255 - new_img
return new_img
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import random
from PIL import Image
from .rec_img_aug import resize_norm_img
class SSLRotateResize(object):
def __init__(self,
image_shape,
padding=False,
select_all=True,
mode="train",
**kwargs):
self.image_shape = image_shape
self.padding = padding
self.select_all = select_all
self.mode = mode
def __call__(self, data):
img = data["image"]
data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
data["image_r180"] = cv2.rotate(data["image_r90"],
cv2.ROTATE_90_CLOCKWISE)
data["image_r270"] = cv2.rotate(data["image_r180"],
cv2.ROTATE_90_CLOCKWISE)
images = []
for key in ["image", "image_r90", "image_r180", "image_r270"]:
images.append(
resize_norm_img(
data.pop(key),
image_shape=self.image_shape,
padding=self.padding)[0])
data["image"] = np.stack(images, axis=0)
data["label"] = np.array(list(range(4)))
if not self.select_all:
data["image"] = data["image"][0::2] # just choose 0 and 180
data["label"] = data["label"][0:2] # label needs to be continuous
if self.mode == "test":
data["image"] = data["image"][0]
data["label"] = data["label"][0]
return data
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .augment import tia_perspective, tia_distort, tia_stretch
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
"""
import numpy as np
from .warp_mls import WarpMLS
def tia_distort(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut // 3
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
dst_pts.append(
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
np.random.randint(thresh) - half_thresh
])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
img_h + np.random.randint(thresh) - half_thresh
])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_stretch(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut * 4 // 5
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, 0])
dst_pts.append([img_w, 0])
dst_pts.append([img_w, img_h])
dst_pts.append([0, img_h])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
move = np.random.randint(thresh) - half_thresh
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([cut * cut_idx + move, 0])
dst_pts.append([cut * cut_idx + move, img_h])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_perspective(src):
img_h, img_w = src.shape[:2]
thresh = img_h // 2
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, np.random.randint(thresh)])
dst_pts.append([img_w, np.random.randint(thresh)])
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
dst_pts.append([0, img_h - np.random.randint(thresh)])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
"""
import numpy as np
class WarpMLS:
def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
self.src = src
self.src_pts = src_pts
self.dst_pts = dst_pts
self.pt_count = len(self.dst_pts)
self.dst_w = dst_w
self.dst_h = dst_h
self.trans_ratio = trans_ratio
self.grid_size = 100
self.rdx = np.zeros((self.dst_h, self.dst_w))
self.rdy = np.zeros((self.dst_h, self.dst_w))
@staticmethod
def __bilinear_interp(x, y, v11, v12, v21, v22):
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
(1 - y) + v22 * y) * x
def generate(self):
self.calc_delta()
return self.gen_img()
def calc_delta(self):
w = np.zeros(self.pt_count, dtype=np.float32)
if self.pt_count < 2:
return
i = 0
while 1:
if self.dst_w <= i < self.dst_w + self.grid_size - 1:
i = self.dst_w - 1
elif i >= self.dst_w:
break
j = 0
while 1:
if self.dst_h <= j < self.dst_h + self.grid_size - 1:
j = self.dst_h - 1
elif j >= self.dst_h:
break
sw = 0
swp = np.zeros(2, dtype=np.float32)
swq = np.zeros(2, dtype=np.float32)
new_pt = np.zeros(2, dtype=np.float32)
cur_pt = np.array([i, j], dtype=np.float32)
k = 0
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
break
w[k] = 1. / (
(i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
(j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
sw += w[k]
swp = swp + w[k] * np.array(self.dst_pts[k])
swq = swq + w[k] * np.array(self.src_pts[k])
if k == self.pt_count - 1:
pstar = 1 / sw * swp
qstar = 1 / sw * swq
miu_s = 0
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
continue
pt_i = self.dst_pts[k] - pstar
miu_s += w[k] * np.sum(pt_i * pt_i)
cur_pt -= pstar
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
for k in range(self.pt_count):
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
continue
pt_i = self.dst_pts[k] - pstar
pt_j = np.array([-pt_i[1], pt_i[0]])
tmp_pt = np.zeros(2, dtype=np.float32)
tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
np.sum(pt_j * cur_pt) * self.src_pts[k][1]
tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
tmp_pt *= (w[k] / miu_s)
new_pt += tmp_pt
new_pt += qstar
else:
new_pt = self.src_pts[k]
self.rdx[j, i] = new_pt[0] - i
self.rdy[j, i] = new_pt[1] - j
j += self.grid_size
i += self.grid_size
def gen_img(self):
src_h, src_w = self.src.shape[:2]
dst = np.zeros_like(self.src, dtype=np.float32)
for i in np.arange(0, self.dst_h, self.grid_size):
for j in np.arange(0, self.dst_w, self.grid_size):
ni = i + self.grid_size
nj = j + self.grid_size
w = h = self.grid_size
if ni >= self.dst_h:
ni = self.dst_h - 1
h = ni - i + 1
if nj >= self.dst_w:
nj = self.dst_w - 1
w = nj - j + 1
di = np.reshape(np.arange(h), (-1, 1))
dj = np.reshape(np.arange(w), (1, -1))
delta_x = self.__bilinear_interp(
di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
self.rdx[ni, j], self.rdx[ni, nj])
delta_y = self.__bilinear_interp(
di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
self.rdy[ni, j], self.rdy[ni, nj])
nx = j + dj + delta_x * self.trans_ratio
ny = i + di + delta_y * self.trans_ratio
nx = np.clip(nx, 0, src_w - 1)
ny = np.clip(ny, 0, src_h - 1)
nxi = np.array(np.floor(nx), dtype=np.int32)
nyi = np.array(np.floor(ny), dtype=np.int32)
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
if len(self.src.shape) == 3:
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
else:
x = ny - nyi
y = nx - nxi
dst[i:i + h, j:j + w] = self.__bilinear_interp(
x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
self.src[nyi1, nxi], self.src[nyi1, nxi1])
dst = np.clip(dst, 0, 255)
dst = np.array(dst, dtype=np.uint8)
return dst
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
import lmdb
import cv2
from .imaug import transform, create_operators
class LMDBDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(LMDBDataSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
logger.info("Initialize indexs of datasets:%s" % data_dir)
self.data_idx_order_list = self.dataset_traversal()
if self.do_shuffle:
np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config)
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
2)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
if not dirnames:
env = lmdb.open(
dirpath,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
txn = env.begin(write=False)
num_samples = int(txn.get('num-samples'.encode()))
lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
"txn":txn, "num_samples":num_samples}
dataset_idx += 1
return lmdb_sets
def dataset_traversal(self):
lmdb_num = len(self.lmdb_sets)
total_sample_num = 0
for lno in range(lmdb_num):
total_sample_num += self.lmdb_sets[lno]['num_samples']
data_idx_order_list = np.zeros((total_sample_num, 2))
beg_idx = 0
for lno in range(lmdb_num):
tmp_sample_num = self.lmdb_sets[lno]['num_samples']
end_idx = beg_idx + tmp_sample_num
data_idx_order_list[beg_idx:end_idx, 0] = lno
data_idx_order_list[beg_idx:end_idx, 1] \
= list(range(tmp_sample_num))
data_idx_order_list[beg_idx:end_idx, 1] += 1
beg_idx = beg_idx + tmp_sample_num
return data_idx_order_list
def get_img_data(self, value):
"""get_img_data"""
if not value:
return None
imgdata = np.frombuffer(value, dtype='uint8')
if imgdata is None:
return None
imgori = cv2.imdecode(imgdata, 1)
if imgori is None:
return None
return imgori
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
if hasattr(op, 'ext_data_num'):
ext_data_num = getattr(op, 'ext_data_num')
break
load_data_ops = self.ops[:self.ext_op_transform_idx]
ext_data = []
while len(ext_data) < ext_data_num:
lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(self.__len__())]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
file_idx)
if sample_info is None:
continue
img, label = sample_info
data = {'image': img, 'label': label}
outs = transform(data, load_data_ops)
ext_data.append(data)
return ext_data
def get_lmdb_sample_info(self, txn, index):
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key)
if label is None:
return None
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
return imgbuf, label
def __getitem__(self, idx):
lmdb_idx, file_idx = self.data_idx_order_list[idx]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
file_idx)
if sample_info is None:
return self.__getitem__(np.random.randint(self.__len__()))
img, label = sample_info
data = {'image': img, 'label': label}
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops)
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return self.data_idx_order_list.shape[0]
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
from .imaug import transform, create_operators
import random
class PGDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PGDataSet, self).__init__()
self.logger = logger
self.seed = seed
self.mode = mode
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
img_id = 0
try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
if self.mode.lower() == 'eval':
try:
img_id = int(data_line.split(".")[0][7:])
except:
img_id = 0
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
self.data_idx_order_list[idx], e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import random
from paddle.io import Dataset
import json
from .imaug import transform, create_operators
class PubTabDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__()
self.logger = logger
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path')
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path)
with open(label_file_path, "rb") as f:
self.data_lines = f.readlines()
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def __getitem__(self, idx):
try:
data_line = self.data_lines[idx]
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
select_flag = True
if self.do_hard_select:
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import json
import random
import traceback
from paddle.io import Dataset
from .imaug import transform, create_operators
class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__()
self.logger = logger
self.mode = mode.lower()
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", 1.0)
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
2)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def shuffle_data_random(self):
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def _try_parse_filename_list(self, file_name):
# multiple images -> one gt label
if len(file_name) > 0 and file_name[0] == "[":
try:
info = json.loads(file_name)
file_name = random.choice(info)
except:
pass
return file_name
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
if hasattr(op, 'ext_data_num'):
ext_data_num = getattr(op, 'ext_data_num')
break
load_data_ops = self.ops[:self.ext_op_transform_idx]
ext_data = []
while len(ext_data) < ext_data_num:
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
))]
data_line = self.data_lines[file_idx]
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
continue
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data = transform(data, load_data_ops)
if data is None:
continue
if 'polys' in data.keys():
if data['polys'].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops)
except:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, traceback.format_exc()))
outs = None
if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs
def __len__(self):
return len(self.data_idx_order_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
import paddle.nn as nn
# basic_loss
from .basic_loss import LossFromOutput
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss
from .det_fce_loss import FCELoss
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
from .kie_sdmgr_loss import SDMGRLoss
# basic loss function
from .basic_loss import DistanceLoss
# combined loss function
from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
support_dict))
module_class = eval(module_name)(**config)
return module_class
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment