Commit a7a899f6 authored by myhloli's avatar myhloli
Browse files

feat(model): add OCR model base structure and utilities

- Add base model structure for OCR in pytorch
- Implement data augmentation and transformation modules
- Create utilities for dictionary handling and state dict conversion
- Include post-processing modules for OCR
- Add weight initialization and loading functions
parent 72e66c2d
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
import cv2
import numpy as np
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
def img_decode(content: bytes):
np_arr = np.frombuffer(content, dtype=np.uint8)
return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
def check_img(img):
if isinstance(img, bytes):
img = img_decode(img)
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
return img
def alpha_to_color(img, alpha_color=(255, 255, 255)):
if len(img.shape) == 3 and img.shape[2] == 4:
B, G, R, A = cv2.split(img)
alpha = A / 255
R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
img = cv2.merge((B, G, R))
return img
def preprocess_image(_image):
alpha_color = (255, 255, 255)
_image = alpha_to_color(_image, alpha_color)
return _image
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def bbox_to_points(bbox):
""" 将bbox格式转换为四个顶点的数组 """
x0, y0, x1, y1 = bbox
return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
def points_to_bbox(points):
""" 将四个顶点的数组转换为bbox格式 """
x0, y0 = points[0]
x1, _ = points[1]
_, y1 = points[2]
return [x0, y0, x1, y1]
def merge_intervals(intervals):
# Sort the intervals based on the start value
intervals.sort(key=lambda x: x[0])
merged = []
for interval in intervals:
# If the list of merged intervals is empty or if the current
# interval does not overlap with the previous, simply append it.
if not merged or merged[-1][1] < interval[0]:
merged.append(interval)
else:
# Otherwise, there is overlap, so we merge the current and previous intervals.
merged[-1][1] = max(merged[-1][1], interval[1])
return merged
def remove_intervals(original, masks):
# Merge all mask intervals
merged_masks = merge_intervals(masks)
result = []
original_start, original_end = original
for mask in merged_masks:
mask_start, mask_end = mask
# If the mask starts after the original range, ignore it
if mask_start > original_end:
continue
# If the mask ends before the original range starts, ignore it
if mask_end < original_start:
continue
# Remove the masked part from the original range
if original_start < mask_start:
result.append([original_start, mask_start - 1])
original_start = max(mask_end + 1, original_start)
# Add the remaining part of the original range, if any
if original_start <= original_end:
result.append([original_start, original_end])
return result
def update_det_boxes(dt_boxes, mfd_res):
new_dt_boxes = []
angle_boxes_list = []
for text_box in dt_boxes:
if calculate_is_angle(text_box):
angle_boxes_list.append(text_box)
continue
text_bbox = points_to_bbox(text_box)
masks_list = []
for mf_box in mfd_res:
mf_bbox = mf_box['bbox']
if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
masks_list.append([mf_bbox[0], mf_bbox[2]])
text_x_range = [text_bbox[0], text_bbox[2]]
text_remove_mask_range = remove_intervals(text_x_range, masks_list)
temp_dt_box = []
for text_remove_mask in text_remove_mask_range:
temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
if len(temp_dt_box) > 0:
new_dt_boxes.extend(temp_dt_box)
new_dt_boxes.extend(angle_boxes_list)
return new_dt_boxes
def merge_overlapping_spans(spans):
"""
Merges overlapping spans on the same line.
:param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
:return: A list of merged spans
"""
# Return an empty list if the input spans list is empty
if not spans:
return []
# Sort spans by their starting x-coordinate
spans.sort(key=lambda x: x[0])
# Initialize the list of merged spans
merged = []
for span in spans:
# Unpack span coordinates
x1, y1, x2, y2 = span
# If the merged list is empty or there's no horizontal overlap, add the span directly
if not merged or merged[-1][2] < x1:
merged.append(span)
else:
# If there is horizontal overlap, merge the current span with the previous one
last_span = merged.pop()
# Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
x1 = min(last_span[0], x1)
y1 = min(last_span[1], y1)
x2 = max(last_span[2], x2)
y2 = max(last_span[3], y2)
# Add the merged span back to the list
merged.append((x1, y1, x2, y2))
# Return the list of merged spans
return merged
def merge_det_boxes(dt_boxes):
"""
Merge detection boxes.
This function takes a list of detected bounding boxes, each represented by four corner points.
The goal is to merge these bounding boxes into larger text regions.
Parameters:
dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
Returns:
list: A list containing the merged text regions, where each region is represented by four corner points.
"""
# Convert the detection boxes into a dictionary format with bounding boxes and type
dt_boxes_dict_list = []
angle_boxes_list = []
for text_box in dt_boxes:
text_bbox = points_to_bbox(text_box)
if calculate_is_angle(text_box):
angle_boxes_list.append(text_box)
continue
text_box_dict = {
'bbox': text_bbox,
'type': 'text',
}
dt_boxes_dict_list.append(text_box_dict)
# Merge adjacent text regions into lines
lines = merge_spans_to_line(dt_boxes_dict_list)
# Initialize a new list for storing the merged text regions
new_dt_boxes = []
for line in lines:
line_bbox_list = []
for span in line:
line_bbox_list.append(span['bbox'])
# Merge overlapping text regions within the same line
merged_spans = merge_overlapping_spans(line_bbox_list)
# Convert the merged text regions back to point format and add them to the new detection box list
for span in merged_spans:
new_dt_boxes.append(bbox_to_points(span))
new_dt_boxes.extend(angle_boxes_list)
return new_dt_boxes
def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# Adjust the coordinates of the formula area
adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y
# Filter formula blocks outside the graph
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
continue
else:
adjusted_mfdetrec_res.append({
"bbox": [x0, y0, x1, y1],
})
return adjusted_mfdetrec_res
def get_ocr_result_list(ocr_res, useful_list):
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
ocr_result_list = []
for box_ocr_res in ocr_res:
if len(box_ocr_res) == 2:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
# logger.info(f"text: {text}, score: {score}")
if score < 0.6: # 过滤低置信度的结果
continue
else:
p1, p2, p3, p4 = box_ocr_res
text, score = "", 1
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
# if average_angle_degrees > 0.5:
poly = [p1, p2, p3, p4]
if calculate_is_angle(poly):
# logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
# 与x轴的夹角超过0.5度,对边界做一下矫正
# 计算几何中心
x_center = sum(point[0] for point in poly) / 4
y_center = sum(point[1] for point in poly) / 4
new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
new_width = p3[0] - p1[0]
p1 = [x_center - new_width / 2, y_center - new_height / 2]
p2 = [x_center + new_width / 2, y_center - new_height / 2]
p3 = [x_center + new_width / 2, y_center + new_height / 2]
p4 = [x_center - new_width / 2, y_center + new_height / 2]
# Convert the coordinates back to the original coordinate system
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
ocr_result_list.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': float(round(score, 2)),
'text': text,
})
return ocr_result_list
def calculate_is_angle(poly):
p1, p2, p3, p4 = poly
height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
if 0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height:
return False
else:
# logger.info((p3[1] - p1[1])/height)
return True
def get_rotate_crop_image(img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import copy
import cv2
import numpy as np
from loguru import logger
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import check_img, preprocess_image, sorted_boxes, \
merge_det_boxes, update_det_boxes, get_rotate_crop_image
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.tools.infer.predict_system import TextSystem
import tools.infer.pytorchocr_utility as utility
import argparse
class PytorchPaddleOCR(TextSystem):
def __init__(self, *args, **kwargs):
parser = utility.init_args()
args = parser.parse_args(args)
self.lang = kwargs.get('lang', 'ch')
# kwargs['cls_model_path'] = "/Users/myhloli/Downloads/ch_ptocr_mobile_v2.0_cls_infer.pth"
if self.lang == 'ch':
kwargs['det_model_path'] = "/Users/myhloli/Downloads/ch_ptocr_v4_det_infer.pth"
kwargs['rec_model_path'] = "/Users/myhloli/Downloads/ch_ptocr_v4_rec_infer.pth"
kwargs['det_yaml_path'] = "/Users/myhloli/Downloads/PaddleOCR2Pytorch-main/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml"
kwargs['rec_yaml_path'] = "/Users/myhloli/Downloads/PaddleOCR2Pytorch-main/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml"
kwargs['rec_image_shape'] = '3,48,320'
kwargs['device'] = get_device()
default_args = vars(args)
default_args.update(kwargs)
args = argparse.Namespace(**default_args)
super().__init__(args)
def ocr(self,
img,
det=True,
rec=True,
mfd_res=None,
):
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false')
exit(0)
img = check_img(img)
imgs = [img]
if det and rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and not rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if dt_boxes is None:
ocr_res.append(None)
continue
dt_boxes = sorted_boxes(dt_boxes)
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
elif not det and rec:
ocr_res = []
for img in imgs:
if not isinstance(img, list):
img = preprocess_image(img)
img = [img]
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res)
return ocr_res
def __call__(self, img, mfd_res=None):
if img is None:
logger.debug("no valid image provided")
return None, None
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
if dt_boxes is None:
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
return None, None
else:
pass
logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
return filter_boxes, filter_rec_res
if __name__ == '__main__':
pytorch_paddle_ocr = PytorchPaddleOCR()
img = cv2.imread("/Users/myhloli/Downloads/screenshot-20250326-194348.png")
dt_boxes, rec_res = pytorch_paddle_ocr(img)
ocr_res = []
if not dt_boxes and not rec_res:
ocr_res.append(None)
else:
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
print(ocr_res)
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from collections import OrderedDict
import numpy as np
import cv2
import torch
from pytorchocr.modeling.architectures.base_model import BaseModel
class BaseOCRV20:
def __init__(self, config, **kwargs):
self.config = config
self.build_net(**kwargs)
self.net.eval()
def build_net(self, **kwargs):
self.net = BaseModel(self.config, **kwargs)
def load_paddle_weights(self, weights_path):
raise NotImplementedError('implemented in converter.')
print('paddle weights loading...')
import paddle.fluid as fluid
with fluid.dygraph.guard():
para_state_dict, opti_state_dict = fluid.load_dygraph(weights_path)
for k,v in self.net.state_dict().items():
name = k
if name.endswith('num_batches_tracked'):
continue
if name.endswith('running_mean'):
ppname = name.replace('running_mean', '_mean')
elif name.endswith('running_var'):
ppname = name.replace('running_var', '_variance')
elif name.endswith('bias') or name.endswith('weight'):
ppname = name
elif 'lstm' in name:
ppname = name
else:
print('Redundance:')
print(name)
raise ValueError
try:
if ppname.endswith('fc.weight'):
self.net.state_dict()[k].copy_(torch.Tensor(para_state_dict[ppname].T))
else:
self.net.state_dict()[k].copy_(torch.Tensor(para_state_dict[ppname]))
except Exception as e:
print('pytorch: {}, {}'.format(k, v.size()))
print('paddle: {}, {}'.format(ppname, para_state_dict[ppname].shape))
raise e
print('model is loaded: {}'.format(weights_path))
def read_pytorch_weights(self, weights_path):
if not os.path.exists(weights_path):
raise FileNotFoundError('{} is not existed.'.format(weights_path))
weights = torch.load(weights_path)
return weights
def get_out_channels(self, weights):
if list(weights.keys())[-1].endswith('.weight') and len(list(weights.values())[-1].shape) == 2:
out_channels = list(weights.values())[-1].numpy().shape[1]
else:
out_channels = list(weights.values())[-1].numpy().shape[0]
return out_channels
def load_state_dict(self, weights):
self.net.load_state_dict(weights)
print('weights is loaded.')
def load_pytorch_weights(self, weights_path):
self.net.load_state_dict(torch.load(weights_path))
print('model is loaded: {}'.format(weights_path))
def save_pytorch_weights(self, weights_path):
try:
torch.save(self.net.state_dict(), weights_path, _use_new_zipfile_serialization=False)
except:
torch.save(self.net.state_dict(), weights_path) # _use_new_zipfile_serialization=False for torch>=1.6.0
print('model is saved: {}'.format(weights_path))
def print_pytorch_state_dict(self):
print('pytorch:')
for k,v in self.net.state_dict().items():
print('{}----{}'.format(k,type(v)))
def read_paddle_weights(self, weights_path):
import paddle.fluid as fluid
with fluid.dygraph.guard():
para_state_dict, opti_state_dict = fluid.load_dygraph(weights_path)
return para_state_dict, opti_state_dict
def print_paddle_state_dict(self, weights_path):
import paddle.fluid as fluid
with fluid.dygraph.guard():
para_state_dict, opti_state_dict = fluid.load_dygraph(weights_path)
print('paddle"')
for k,v in para_state_dict.items():
print('{}----{}'.format(k,type(v)))
def inference(self, inputs):
with torch.no_grad():
infer = self.net(inputs)
return infer
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import numpy as np
# import paddle
import signal
import random
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
# from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
# import paddle.distributed as dist
from pytorchocr.data.imaug import transform, create_operators
# from pytorchocr.data.simple_dataset import SimpleDataSet
# from pytorchocr.data.lmdb_dataset import LMDBDateSet
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
# from .iaa_augment import IaaAugment
# from .make_border_map import MakeBorderMap
# from .make_shrink_map import MakeShrinkMap
# from .random_crop_data import EastRandomCropData, PSERandomCrop
# from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
# from .randaugment import RandAugment
from .operators import *
# from .label_ops import *
# from .east_process import *
# from .sast_process import *
from .gen_table_mask import *
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
return ops
\ No newline at end of file
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class GenTableMask(object):
""" gen table mask """
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
return box_list, projection_map
def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
erode = cv2.erode(thresh1, kernel, iterations=1)
else:
erode = thresh1
# 水平膨胀
kernel = np.ones((1, 5), np.uint8)
erosion = cv2.dilate(erode, kernel, iterations=1)
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
split_bbox_list = []
if len(box_list) > 1:
for i, (h_start, h_end) in enumerate(box_list):
if i == 0:
h_start = 0
if i == len(box_list):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
return split_bbox_list
def shrink_bbox(self, bbox):
left, top, right, bottom = bbox
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
left_new = left + sh_w
right_new = right - sh_w
top_new = top + sh_h
bottom_new = bottom - sh_h
if left_new >= right_new:
left_new = left
right_new = right
if top_new >= bottom_new:
top_new = top
bottom_new = bottom
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
img = data['image']
cells = data['cells']
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
else:
mask_img = np.zeros((height, width, 3), dtype=np.float32)
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
for sno in range(len(split_bbox_list)):
split_bbox_list[sno][0] += left
split_bbox_list[sno][1] += top
split_bbox_list[sno][2] += left
split_bbox_list[sno][3] += top
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
super(PaddingTableImage, self).__init__()
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
return data
\ No newline at end of file
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class DecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NRTRDecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class Fasttext(object):
def __init__(self, path="None", **kwargs):
import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class Resize(object):
def __init__(self, size=(640, 640), **kwargs):
self.size = size
def resize_image(self, img):
resize_h, resize_w = self.size
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def __call__(self, data):
img = data['image']
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['image'] = img_resize
data['polys'] = np.array(new_boxes, dtype=np.float32)
return data
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
class KieResize(object):
def __init__(self, **kwargs):
super(KieResize, self).__init__()
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
'img_scale'][1]
def __call__(self, data):
img = data['image']
points = data['points']
src_h, src_w, _ = img.shape
im_resized, scale_factor, [ratio_h, ratio_w
], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img
data['ori_boxes'] = points
data['points'] = resize_points
data['image'] = im_resized
data['shape'] = np.array([new_h, new_w])
return data
def resize_image(self, img):
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
scale_factor) + 0.5)
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor
img_shape = im.shape[:2]
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
__all__ = ['build_model']
def build_model(config, **kwargs):
from .base_model import BaseModel
config = copy.deepcopy(config)
module_class = BaseModel(config, **kwargs)
return module_class
\ No newline at end of file
import os, sys
# import torch
import torch.nn as nn
# import torch.nn.functional as F
# from pytorchocr.modeling.common import Activation
from pytorchocr.modeling.transforms import build_transform
from pytorchocr.modeling.backbones import build_backbone
from pytorchocr.modeling.necks import build_neck
from pytorchocr.modeling.heads import build_head
class BaseModel(nn.Module):
def __init__(self, config, **kwargs):
"""
the module for OCR.
args:
config (dict): the super parameters for module.
"""
super(BaseModel, self).__init__()
in_channels = config.get('in_channels', 3)
model_type = config['model_type']
# build transfrom,
# for rec, transfrom can be TPS,None
# for det and cls, transfrom shoule to be None,
# if you make model differently, you can use transfrom in det and cls
if 'Transform' not in config or config['Transform'] is None:
self.use_transform = False
else:
self.use_transform = True
config['Transform']['in_channels'] = in_channels
self.transform = build_transform(config['Transform'])
in_channels = self.transform.out_channels
# raise NotImplementedError
# build backbone, backbone is need for del, rec and cls
if 'Backbone' not in config or config['Backbone'] is None:
self.use_backbone = False
else:
self.use_backbone = True
config["Backbone"]['in_channels'] = in_channels
self.backbone = build_backbone(config["Backbone"], model_type)
in_channels = self.backbone.out_channels
# build neck
# for rec, neck can be cnn,rnn or reshape(None)
# for det, neck can be FPN, BIFPN and so on.
# for cls, neck should be none
if 'Neck' not in config or config['Neck'] is None:
self.use_neck = False
else:
self.use_neck = True
config['Neck']['in_channels'] = in_channels
self.neck = build_neck(config['Neck'])
in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls
if 'Head' not in config or config['Head'] is None:
self.use_head = False
else:
self.use_head = True
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"], **kwargs)
self.return_all_feats = config.get("return_all_feats", False)
self._initialize_weights()
def _initialize_weights(self):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x):
y = dict()
if self.use_transform:
x = self.transform(x)
if self.use_backbone:
x = self.backbone(x)
if isinstance(x, dict):
y.update(x)
else:
y["backbone_out"] = x
final_name = "backbone_out"
if self.use_neck:
x = self.neck(x)
if isinstance(x, dict):
y.update(x)
else:
y["neck_out"] = x
final_name = "neck_out"
if self.use_head:
x = self.head(x)
# for multi head, save ctc neck out for udml
if isinstance(x, dict) and 'ctc_nect' in x.keys():
y['neck_out'] = x['ctc_neck']
y['head_out'] = x
elif isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x
if self.return_all_feats:
if self.training:
return y
elif isinstance(x, dict):
return x
else:
return {final_name: x}
else:
return x
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['build_backbone']
def build_backbone(config, model_type):
if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet import ResNet
from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_vd', 'ResNet_SAST', 'PPLCNetV3', 'PPHGNet_small']
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_svtrnet import SVTRNet
from .rec_vitstr import ViTSTR
from .rec_densenet import DenseNet
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
support_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'SVTRNet', 'ViTSTR', 'DenseNet', 'PPLCNetV3', 'PPHGNet_small']
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
support_dict = ['ResNet']
elif model_type == "table":
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
else:
raise NotImplementedError
module_name = config.pop('name')
assert module_name in support_dict, Exception(
'when model typs is {}, backbone only support {}'.format(model_type,
support_dict))
module_class = eval(module_name)(**config)
return module_class
\ No newline at end of file
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False)
self.bn = nn.BatchNorm2d(
out_channels,
)
if self.if_act:
self.act = Activation(act_type=act, inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
x = self.act(x)
return x
class SEModule(nn.Module):
def __init__(self, in_channels, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0,
bias=True)
self.relu1 = Activation(act_type='relu', inplace=True)
self.conv2 = nn.Conv2d(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0,
bias=True)
self.hard_sigmoid = Activation(act_type='hard_sigmoid', inplace=True)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = self.relu1(outputs)
outputs = self.conv2(outputs)
outputs = self.hard_sigmoid(outputs)
outputs = inputs * outputs
return outputs
class ResidualUnit(nn.Module):
def __init__(self,
in_channels,
mid_channels,
out_channels,
kernel_size,
stride,
use_se,
act=None,
name=''):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
act=act,
name=name + "_depthwise")
if self.if_se:
self.mid_se = SEModule(mid_channels, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear")
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = inputs + x
return x
class MobileNetV3(nn.Module):
def __init__(self,
in_channels=3,
model_name='large',
scale=0.5,
disable_se=False,
**kwargs):
"""
the MobilenetV3 backbone network for detection module.
Args:
params(dict): the super parameters for build network
"""
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', 2],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', 2],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert scale in supported_scale, \
"supported scale are {} but input scale is {}".format(supported_scale, scale)
inplanes = 16
# conv1
self.conv = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
if_act=True,
act='hard_swish',
name='conv1')
self.stages = nn.ModuleList()
self.out_channels = []
block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
if s == 2 and i > 2:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
block_list = []
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2)))
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act='hard_swish',
name='conv_last'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
# for i, stage in enumerate(self.stages):
# self.add_sublayer(sublayer=stage, name="stage{}".format(i))
def forward(self, x):
x = self.conv(x)
out_list = []
for stage in self.stages:
x = stage(x)
out_list.append(x)
return out_list
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .det_resnet_vd import DeformableConvV2, ConvBNLayer
class BottleneckBlock(nn.Module):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
is_dcn=False):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
act="relu", )
self.conv1 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
stride=stride,
act="relu",
is_dcn=is_dcn,
# dcn_groups=1,
)
self.conv2 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters * 4,
kernel_size=1,
act=None, )
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters * 4,
kernel_size=1,
stride=stride, )
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = torch.add(short, conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Module):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=3,
stride=stride,
act="relu")
self.conv1 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
act=None)
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
stride=stride)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = torch.add(short, conv1)
y = F.relu(y)
return y
class ResNet(nn.Module):
def __init__(self,
in_channels=3,
layers=50,
out_indices=None,
dcn_stage=None):
super(ResNet, self).__init__()
self.layers = layers
self.input_image_channel = in_channels
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.dcn_stage = dcn_stage if dcn_stage is not None else [
False, False, False, False
]
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
self.conv = ConvBNLayer(
in_channels=self.input_image_channel,
out_channels=64,
kernel_size=7,
stride=2,
act="relu", )
self.pool2d_max = nn.MaxPool2d(
kernel_size=3,
stride=2,
padding=1, )
self.stages = nn.ModuleList()
self.out_channels = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
block_list = nn.Sequential()
is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
is_dcn=is_dcn)
block_list.add_module(conv_name, bottleneck_block)
shortcut = True
if block in self.out_indices:
self.out_channels.append(num_filters[block] * 4)
self.stages.append(block_list)
else:
for block in range(len(depth)):
shortcut = False
block_list = nn.Sequential()
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut)
block_list.add_module(conv_name, basic_block)
shortcut = True
if block in self.out_indices:
self.out_channels.append(num_filters[block])
self.stages.append(block_list)
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
out = []
for i, block in enumerate(self.stages):
y = block(y)
if i in self.out_indices:
out.append(y)
return out
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
import torchvision
class DeformableConvV2(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
weight_attr=None,
bias_attr=None,
lr_scale=1,
regularizer=None,
skip_quant=False,
dcn_bias_regularizer=None,
dcn_bias_lr_scale=2.):
super(DeformableConvV2, self).__init__()
self.offset_channel = 2 * kernel_size**2 * groups
self.mask_channel = kernel_size**2 * groups
if bias_attr:
# in FCOS-DCN head, specifically need learning_rate and regularizer
dcn_bias_attr = True
else:
# in ResNet backbone, do not need bias
dcn_bias_attr = False
self.conv_dcn = torchvision.ops.DeformConv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2 * dilation,
dilation=dilation,
groups=groups//2 if groups > 1 else 1,
bias=dcn_bias_attr)
self.conv_offset = nn.Conv2d(
in_channels,
groups * 3 * kernel_size**2,
kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
bias=True)
if skip_quant:
self.conv_offset.skip_quant = True
def forward(self, x):
offset_mask = self.conv_offset(x)
offset, mask = torch.split(
offset_mask,
split_size_or_sections=[self.offset_channel, self.mask_channel],
dim=1)
mask = torch.sigmoid(mask)
y = self.conv_dcn(x, offset, mask=mask)
return y
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
dcn_groups=1,
is_vd_mode=False,
act=None,
name=None,
is_dcn=False,
):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self.act = act
self._pool2d_avg = nn.AvgPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
if not is_dcn:
self._conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False)
else:
self._conv = DeformableConvV2(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=dcn_groups,
bias_attr=False)
self._batch_norm = nn.BatchNorm2d(
out_channels,
track_running_stats=True,
)
if act is not None:
self._act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class BottleneckBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None,
is_dcn=False,
):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b",
is_dcn=is_dcn,
dcn_groups=2,
)
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = torch.add(short, conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = short + conv1
y = F.relu(y)
return y
class ResNet_vd(nn.Module):
def __init__(self,
in_channels=3,
layers=50,
dcn_stage=None,
out_indices=None,
**kwargs):
super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.dcn_stage = dcn_stage if dcn_stage is not None else [
False, False, False, False
]
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.stages = nn.ModuleList()
self.out_channels = []
if layers >= 50:
for block in range(len(depth)):
# block_list = []
block_list = nn.Sequential()
shortcut = False
is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name,
is_dcn=is_dcn,
)
shortcut = True
block_list.add_module('bb_%d_%d' % (block, i), bottleneck_block)
if block in self.out_indices:
self.out_channels.append(num_filters[block] * 4)
# self.stages.append(nn.Sequential(*block_list))
self.stages.append(block_list)
else:
for block in range(len(depth)):
# block_list = []
block_list = nn.Sequential()
shortcut = False
# is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name)
shortcut = True
block_list.add_module('bb_%d_%d' % (block, i), basic_block)
# block_list.append(basic_block)
if block in self.out_indices:
self.out_channels.append(num_filters[block])
self.stages.append(block_list)
# self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
out = []
for i, block in enumerate(self.stages):
y = block(y)
if i in self.out_indices:
out.append(y)
return out
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
# import paddle
# from paddle import ParamAttr
# import paddle.nn as nn
# import paddle.nn.functional as F
__all__ = ["ResNet_SAST"]
class ConvBNLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm2d(
out_channels,)
self.act = act
if act is not None:
self._act = Activation(act_type=act)
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act:
y = self._act(y)
return y
class BottleneckBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = torch.add(short, conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = torch.add(short, conv1)
y = F.relu(y)
return y
class ResNet_SAST(nn.Module):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet_SAST, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
# num_channels = [64, 256, 512,
# 1024] if layers >= 50 else [64, 64, 128, 256]
# num_filters = [64, 128, 256, 512]
num_channels = [64, 256, 512,
1024, 2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
# self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.stages = nn.ModuleList()
self.out_channels = [3, 64]
if layers >= 50:
for block in range(len(depth)):
# block_list = []
block_list = nn.Sequential()
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = BottleneckBlock(
in_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name
)
shortcut = True
# block_list.append(bottleneck_block)
block_list.add_module('bb_%d_%d' % (block, i), bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
# self.stages.append(nn.Sequential(*block_list))
self.stages.append(block_list)
else:
for block in range(len(depth)):
# block_list = []
block_list = nn.Sequential()
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name)
shortcut = True
# block_list.append(basic_block)
block_list.add_module('bb_%d_%d' % (block, i), basic_block)
self.out_channels.append(num_filters[block])
# self.stages.append(nn.Sequential(*block_list))
self.stages.append(block_list)
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorchocr.modeling.common import Activation
__all__ = ["ResNet"]
class ConvBNLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm2d(out_channels)
self.act = act
if self.act is not None:
self._act = Activation(act_type=self.act, inplace=True)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class BottleneckBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=stride,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = torch.add(short, conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Module):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = torch.add(short, conv1)
y = F.relu(y)
return y
class ResNet(nn.Module):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512, 1024,
2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
stride=2,
act='relu',
name="conv1_1")
self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.stages = nn.ModuleList()
self.out_channels = [3, 64]
# num_filters = [64, 128, 256, 512, 512]
if layers >= 50:
for block in range(len(depth)):
block_list = nn.Sequential()
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneckBlock = BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name)
shortcut = True
block_list.add_module('bb_%d_%d' % (block, i), bottleneckBlock)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(block_list)
else:
for block in range(len(depth)):
block_list = nn.Sequential()
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basicBlock = BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name)
shortcut = True
block_list.add_module('bb_%d_%d' % (block, i), basicBlock)
self.out_channels.append(num_filters[block])
self.stages.append(block_list)
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out
"""
This code is refer from:
https://github.com/LBH1024/CAN/models/densenet.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Bottleneck(nn.Module):
def __init__(self, nChannels, growthRate, use_dropout):
super(Bottleneck, self).__init__()
interChannels = 4 * growthRate
self.bn1 = nn.BatchNorm2d(interChannels)
self.conv1 = nn.Conv2d(
nChannels, interChannels, kernel_size=1,
bias=True) # Xavier initialization
self.bn2 = nn.BatchNorm2d(growthRate)
self.conv2 = nn.Conv2d(
interChannels, growthRate, kernel_size=3, padding=1,
bias=True) # Xavier initialization
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
if self.use_dropout:
out = self.dropout(out)
out = F.relu(self.bn2(self.conv2(out)))
if self.use_dropout:
out = self.dropout(out)
out = torch.cat([x, out], 1)
return out
class SingleLayer(nn.Module):
def __init__(self, nChannels, growthRate, use_dropout):
super(SingleLayer, self).__init__()
self.bn1 = nn.BatchNorm2d(nChannels)
self.conv1 = nn.Conv2d(
nChannels, growthRate, kernel_size=3, padding=1, bias=False)
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
def forward(self, x):
out = self.conv1(F.relu(x))
if self.use_dropout:
out = self.dropout(out)
out = torch.cat([x, out], 1)
return out
class Transition(nn.Module):
def __init__(self, nChannels, out_channels, use_dropout):
super(Transition, self).__init__()
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv1 = nn.Conv2d(
nChannels, out_channels, kernel_size=1, bias=False)
self.use_dropout = use_dropout
self.dropout = nn.Dropout(p=0.2)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
if self.use_dropout:
out = self.dropout(out)
out = F.avg_pool2d(out, 2, ceil_mode=True, count_include_pad=False)
return out
class DenseNet(nn.Module):
def __init__(self, growthRate, reduction, bottleneck, use_dropout,
input_channel, **kwargs):
super(DenseNet, self).__init__()
nDenseBlocks = 16
nChannels = 2 * growthRate
self.conv1 = nn.Conv2d(
input_channel,
nChannels,
kernel_size=7,
padding=3,
stride=2,
bias=False)
self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks,
bottleneck, use_dropout)
nChannels += nDenseBlocks * growthRate
out_channels = int(math.floor(nChannels * reduction))
self.trans1 = Transition(nChannels, out_channels, use_dropout)
nChannels = out_channels
self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks,
bottleneck, use_dropout)
nChannels += nDenseBlocks * growthRate
out_channels = int(math.floor(nChannels * reduction))
self.trans2 = Transition(nChannels, out_channels, use_dropout)
nChannels = out_channels
self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks,
bottleneck, use_dropout)
self.out_channels = out_channels
def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck,
use_dropout):
layers = []
for i in range(int(nDenseBlocks)):
if bottleneck:
layers.append(Bottleneck(nChannels, growthRate, use_dropout))
else:
layers.append(SingleLayer(nChannels, growthRate, use_dropout))
nChannels += growthRate
return nn.Sequential(*layers)
def forward(self, inputs):
x, x_m, y = inputs
out = self.conv1(x)
out = F.relu(out, inplace=True)
out = F.max_pool2d(out, 2, ceil_mode=True)
out = self.dense1(out)
out = self.trans1(out)
out = self.dense2(out)
out = self.trans2(out)
out = self.dense3(out)
return out, x_m, y
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBNAct(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
use_act=True):
super().__init__()
self.use_act = use_act
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding=(kernel_size - 1) // 2,
groups=groups,
bias=False)
self.bn = nn.BatchNorm2d(out_channels)
if self.use_act:
self.act = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
x = self.act(x)
return x
class ESEModule(nn.Module):
def __init__(self, channels):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=1,
stride=1,
padding=0)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv(x)
x = self.sigmoid(x)
return x * identity
class HG_Block(nn.Module):
def __init__(
self,
in_channels,
mid_channels,
out_channels,
layer_num,
identity=False, ):
super().__init__()
self.identity = identity
self.layers = nn.ModuleList()
self.layers.append(
ConvBNAct(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1))
for _ in range(layer_num - 1):
self.layers.append(
ConvBNAct(
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
stride=1))
# feature aggregation
total_channels = in_channels + layer_num * mid_channels
self.aggregation_conv = ConvBNAct(
in_channels=total_channels,
out_channels=out_channels,
kernel_size=1,
stride=1)
self.att = ESEModule(out_channels)
def forward(self, x):
identity = x
output = []
output.append(x)
for layer in self.layers:
x = layer(x)
output.append(x)
x = torch.cat(output, dim=1)
x = self.aggregation_conv(x)
x = self.att(x)
if self.identity:
x += identity
return x
class HG_Stage(nn.Module):
def __init__(self,
in_channels,
mid_channels,
out_channels,
block_num,
layer_num,
downsample=True,
stride=[2, 1]):
super().__init__()
self.downsample = downsample
if downsample:
self.downsample = ConvBNAct(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=stride,
groups=in_channels,
use_act=False)
blocks_list = []
blocks_list.append(
HG_Block(
in_channels,
mid_channels,
out_channels,
layer_num,
identity=False))
for _ in range(block_num - 1):
blocks_list.append(
HG_Block(
out_channels,
mid_channels,
out_channels,
layer_num,
identity=True))
self.blocks = nn.Sequential(*blocks_list)
def forward(self, x):
if self.downsample:
x = self.downsample(x)
x = self.blocks(x)
return x
class PPHGNet(nn.Module):
"""
PPHGNet
Args:
stem_channels: list. Stem channel list of PPHGNet.
stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc.
layer_num: int. Number of layers of HG_Block.
use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
class_num: int=1000. The number of classes.
Returns:
model: nn.Layer. Specific PPHGNet model depends on args.
"""
def __init__(
self,
stem_channels,
stage_config,
layer_num,
in_channels=3,
det=False,
out_indices=None):
super().__init__()
self.det = det
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
# stem
stem_channels.insert(0, in_channels)
self.stem = nn.Sequential(* [
ConvBNAct(
in_channels=stem_channels[i],
out_channels=stem_channels[i + 1],
kernel_size=3,
stride=2 if i == 0 else 1) for i in range(
len(stem_channels) - 1)
])
if self.det:
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# stages
self.stages = nn.ModuleList()
self.out_channels = []
for block_id, k in enumerate(stage_config):
in_channels, mid_channels, out_channels, block_num, downsample, stride = stage_config[
k]
self.stages.append(
HG_Stage(in_channels, mid_channels, out_channels, block_num,
layer_num, downsample, stride))
if block_id in self.out_indices:
self.out_channels.append(out_channels)
if not self.det:
self.out_channels = stage_config["stage4"][2]
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
if self.det:
x = self.pool(x)
out = []
for i, stage in enumerate(self.stages):
x = stage(x)
if self.det and i in self.out_indices:
out.append(x)
if self.det:
return out
if self.training:
x = F.adaptive_avg_pool2d(x, [1, 40])
else:
x = F.avg_pool2d(x, [3, 2])
return x
def PPHGNet_tiny(pretrained=False, use_ssld=False, **kwargs):
"""
PPHGNet_tiny
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_tiny` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [96, 96, 224, 1, False, [2, 1]],
"stage2": [224, 128, 448, 1, True, [1, 2]],
"stage3": [448, 160, 512, 2, True, [2, 1]],
"stage4": [512, 192, 768, 1, True, [2, 1]],
}
model = PPHGNet(
stem_channels=[48, 48, 96],
stage_config=stage_config,
layer_num=5,
**kwargs)
return model
def PPHGNet_small(pretrained=False, use_ssld=False, det=False, **kwargs):
"""
PPHGNet_small
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_small` model depends on args.
"""
stage_config_det = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, False, 2],
"stage2": [256, 160, 512, 1, True, 2],
"stage3": [512, 192, 768, 2, True, 2],
"stage4": [768, 224, 1024, 1, True, 2],
}
stage_config_rec = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [128, 128, 256, 1, True, [2, 1]],
"stage2": [256, 160, 512, 1, True, [1, 2]],
"stage3": [512, 192, 768, 2, True, [2, 1]],
"stage4": [768, 224, 1024, 1, True, [2, 1]],
}
model = PPHGNet(
stem_channels=[64, 64, 128],
stage_config=stage_config_det if det else stage_config_rec,
layer_num=6,
det=det,
**kwargs)
return model
def PPHGNet_base(pretrained=False, use_ssld=True, **kwargs):
"""
PPHGNet_base
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_base` model depends on args.
"""
stage_config = {
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1": [160, 192, 320, 1, False, [2, 1]],
"stage2": [320, 224, 640, 2, True, [1, 2]],
"stage3": [640, 256, 960, 3, True, [2, 1]],
"stage4": [960, 288, 1280, 2, True, [2, 1]],
}
model = PPHGNet(
stem_channels=[96, 96, 160],
stage_config=stage_config,
layer_num=7,
**kwargs)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment