Commit aad3093a authored by WenmuZhou's avatar WenmuZhou
Browse files

dygraph first commit

parent 10f7e519
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class DecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class keepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
if 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
if self.resize_type == 0:
img, shape = self.resize_image_type0(img)
else:
img, shape = self.resize_image_type1(img)
data['image'] = img
data['shape'] = shape
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, np.array([ori_h, ori_w])
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, _ = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
else:
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32)
resize_w = int(round(resize_w / 32) * 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
return img, np.array([h, w])
...@@ -108,48 +108,103 @@ def crop_area(im, text_polys, min_crop_side_ratio, max_tries): ...@@ -108,48 +108,103 @@ def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
return 0, 0, w, h return 0, 0, w, h
def RandomCropData(data, size): class EastRandomCropData(object):
max_tries = 10 def __init__(self,
min_crop_side_ratio = 0.1 size=(640, 640),
require_original_image = False max_tries=10,
keep_ratio = True min_crop_side_ratio=0.1,
keep_ratio=True,
im = data['image'] **kwargs):
text_polys = data['polys'] self.size = size
ignore_tags = data['ignore_tags'] self.max_tries = max_tries
texts = data['texts'] self.min_crop_side_ratio = min_crop_side_ratio
all_care_polys = [ self.keep_ratio = keep_ratio
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
] def __call__(self, data):
# 计算crop区域 img = data['image']
crop_x, crop_y, crop_w, crop_h = crop_area(im, all_care_polys, text_polys = data['polys']
min_crop_side_ratio, max_tries) ignore_tags = data['ignore_tags']
# crop 图片 保持比例填充 texts = data['texts']
scale_w = size[0] / crop_w all_care_polys = [
scale_h = size[1] / crop_h text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
scale = min(scale_w, scale_h) ]
h = int(crop_h * scale) # 计算crop区域
w = int(crop_w * scale) crop_x, crop_y, crop_w, crop_h = crop_area(
if keep_ratio: img, all_care_polys, self.min_crop_side_ratio, self.max_tries)
padimg = np.zeros((size[1], size[0], im.shape[2]), im.dtype) # crop 图片 保持比例填充
padimg[:h, :w] = cv2.resize( scale_w = self.size[0] / crop_w
im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) scale_h = self.size[1] / crop_h
img = padimg scale = min(scale_w, scale_h)
else: h = int(crop_h * scale)
img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], w = int(crop_w * scale)
tuple(size)) if self.keep_ratio:
# crop 文本框 padimg = np.zeros((self.size[1], self.size[0], img.shape[2]),
text_polys_crop = [] img.dtype)
ignore_tags_crop = [] padimg[:h, :w] = cv2.resize(
texts_crop = [] img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
for poly, text, tag in zip(text_polys, texts, ignore_tags): img = padimg
poly = ((poly - (crop_x, crop_y)) * scale).tolist() else:
if not is_poly_outside_rect(poly, 0, 0, w, h): img = cv2.resize(
text_polys_crop.append(poly) img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
ignore_tags_crop.append(tag) tuple(self.size))
texts_crop.append(text) # crop 文本框
data['image'] = img text_polys_crop = []
data['polys'] = np.array(text_polys_crop) ignore_tags_crop = []
data['ignore_tags'] = ignore_tags_crop texts_crop = []
data['texts'] = texts_crop for poly, text, tag in zip(text_polys, texts, ignore_tags):
return data poly = ((poly - (crop_x, crop_y)) * scale).tolist()
if not is_poly_outside_rect(poly, 0, 0, w, h):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
data['image'] = img
data['polys'] = np.array(text_polys_crop)
data['ignore_tags'] = ignore_tags_crop
data['texts'] = texts_crop
return data
class PSERandomCrop(object):
def __init__(self, size, **kwargs):
self.size = size
def __call__(self, data):
imgs = data['imgs']
h, w = imgs[0].shape[0:2]
th, tw = self.size
if w == tw and h == th:
return imgs
# label中存在文本实例,并且按照概率进行裁剪,使用threshold_label_map控制
if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
# 文本实例的左上角点
tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
tl[tl < 0] = 0
# 文本实例的右下角点
br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
br[br < 0] = 0
# 保证选到右下角点时,有足够的距离进行crop
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)
for _ in range(50000):
i = random.randint(tl[0], br[0])
j = random.randint(tl[1], br[1])
# 保证shrink_label_map有文本
if imgs[1][i:i + th, j:j + tw].sum() <= 0:
continue
else:
break
else:
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
# return i, j, th, tw
for idx in range(len(imgs)):
if len(imgs[idx].shape) == 3:
imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
else:
imgs[idx] = imgs[idx][i:i + th, j:j + tw]
data['imgs'] = imgs
return data
This diff is collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .augment import tia_perspective, tia_distort, tia_stretch
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from .warp_mls import WarpMLS
def tia_distort(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut // 3
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), np.random.randint(thresh)])
dst_pts.append(
[img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
dst_pts.append(
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
np.random.randint(thresh) - half_thresh
])
dst_pts.append([
cut * cut_idx + np.random.randint(thresh) - half_thresh,
img_h + np.random.randint(thresh) - half_thresh
])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_stretch(src, segment=4):
img_h, img_w = src.shape[:2]
cut = img_w // segment
thresh = cut * 4 // 5
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, 0])
dst_pts.append([img_w, 0])
dst_pts.append([img_w, img_h])
dst_pts.append([0, img_h])
half_thresh = thresh * 0.5
for cut_idx in np.arange(1, segment, 1):
move = np.random.randint(thresh) - half_thresh
src_pts.append([cut * cut_idx, 0])
src_pts.append([cut * cut_idx, img_h])
dst_pts.append([cut * cut_idx + move, 0])
dst_pts.append([cut * cut_idx + move, img_h])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
def tia_perspective(src):
img_h, img_w = src.shape[:2]
thresh = img_h // 2
src_pts = list()
dst_pts = list()
src_pts.append([0, 0])
src_pts.append([img_w, 0])
src_pts.append([img_w, img_h])
src_pts.append([0, img_h])
dst_pts.append([0, np.random.randint(thresh)])
dst_pts.append([img_w, np.random.randint(thresh)])
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
dst_pts.append([0, img_h - np.random.randint(thresh)])
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
dst = trans.generate()
return dst
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment