Commit 80aced81 authored by Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents fce82425 896d149e
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
"\n", "\n",
"## 1 背景介绍\n", "## 1 背景介绍\n",
"\n", "\n",
"文本识别是OCR(Optical Character Recognition)的一个子任务,其任务为识别一个固定区域的文本内容。在OCR的两阶段方法里,它接在文本检测后面,将图像信息转换为文字信息。\n", "文本识别是OCR(Optical Character Recognition)的一个子任务,其任务为识别一个固定区域的文本内容。在OCR的两阶段方法里,它接在文本检测后面,将图像信息转换为文字信息。\n",
"\n", "\n",
"具体地,模型输入一张定位好的文本行,由模型预测出图片中的文字内容和置信度,可视化结果如下图所示:\n", "具体地,模型输入一张定位好的文本行,由模型预测出图片中的文字内容和置信度,可视化结果如下图所示:\n",
"\n", "\n",
......
...@@ -2915,7 +2915,7 @@ ...@@ -2915,7 +2915,7 @@
"\n", "\n",
"```yaml\n", "```yaml\n",
"Architecture:\n", "Architecture:\n",
" model_type: &model_type \"rec\" # 模型类别,rec、det等,每个子网络的模型类别都与\n", " model_type: &model_type \"rec\" # 模型类别,rec、det等,每个子网络的模型类别都与\n",
" name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构\n", " name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构\n",
" algorithm: Distillation # 算法名称\n", " algorithm: Distillation # 算法名称\n",
" Models: # 模型,包含子网络的配置信息\n", " Models: # 模型,包含子网络的配置信息\n",
...@@ -2915,7 +2915,7 @@ ...@@ -2915,7 +2915,7 @@
"\n", "\n",
"```yaml\n", "```yaml\n",
"Architecture:\n", "Architecture:\n",
" model_type: &model_type \"rec\" # 模型类别,rec、det等,每个子网络的模型类别都与\n", " model_type: &model_type \"rec\" # 模型类别,rec、det等,每个子网络的模型类别都与\n",
" name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构\n", " name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构\n",
" algorithm: Distillation # 算法名称\n", " algorithm: Distillation # 算法名称\n",
" Models: # 模型,包含子网络的配置信息\n", " Models: # 模型,包含子网络的配置信息\n",
...@@ -1876,11 +1876,11 @@ ...@@ -1876,11 +1876,11 @@
" rec_res)\n", " rec_res)\n",
" filter_boxes, filter_rec_res = [], []\n", " filter_boxes, filter_rec_res = [], []\n",
" # 根据识别得分的阈值对结果进行过滤,如果得分小于阈值,就过滤掉\n", " # 根据识别得分的阈值对结果进行过滤,如果得分小于阈值,就过滤掉\n",
" for box, rec_reuslt in zip(dt_boxes, rec_res):\n", " for box, rec_result in zip(dt_boxes, rec_res):\n",
" text, score = rec_reuslt\n", " text, score = rec_result\n",
" if score >= self.drop_score:\n", " if score >= self.drop_score:\n",
" filter_boxes.append(box)\n", " filter_boxes.append(box)\n",
" filter_rec_res.append(rec_reuslt)\n", " filter_rec_res.append(rec_result)\n",
" return filter_boxes, filter_rec_res\n", " return filter_boxes, filter_rec_res\n",
"\n", "\n",
"def sorted_boxes(dt_boxes):\n", "def sorted_boxes(dt_boxes):\n",
...@@ -327,7 +327,7 @@ ...@@ -327,7 +327,7 @@
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/899470ba601349fbbc402a4c83e6cdaee08aaa10b5004977b1f684f346ebe31f\" width=\"800\"/></center>\n", "<img src=\"https://ai-studio-static-online.cdn.bcebos.com/899470ba601349fbbc402a4c83e6cdaee08aaa10b5004977b1f684f346ebe31f\" width=\"800\"/></center>\n",
"<center>图 18: SER,RE任务示例</center>\n", "<center>图 18: SER,RE任务示例</center>\n",
"\n", "\n",
"一般的KIE方法基于命名实体识别(Named Entity Recognition,NER)[4]来研究,但是这类方法只利用了图像中的文本信息,缺少对视觉和结构信息的使用,因此精度不高。在此基础上,近几年的方法都开始将视觉和结构信息与文本信息融合到一起,按照对多模态信息进行融合时所采用的原理可以将这些方法分为下面三种:\n", "一般的KIE方法基于命名实体识别(Named Entity Recognition,NER)[4]来研究,但是这类方法只利用了图像中的文本信息,缺少对视觉和结构信息的使用,因此精度不高。在此基础上,近几年的方法都开始将视觉和结构信息与文本信息融合到一起,按照对多模态信息进行融合时所采用的原理可以将这些方法分为下面三种:\n",
"\n", "\n",
"1. 基于Grid的方法\n", "1. 基于Grid的方法\n",
"1. 基于Token的方法\n", "1. 基于Token的方法\n",
......
...@@ -136,7 +136,7 @@ ...@@ -136,7 +136,7 @@
"<br><center>Figure 11: LOMO frame diagram</center>\n", "<br><center>Figure 11: LOMO frame diagram</center>\n",
"\n", "\n",
"\n", "\n",
"Contournet [18] is based on the proposed modeling of text contour points to obtain a curved text detection frame. This method first uses Adaptive-RPN to obtain the proposal features of the text area, and then designs a local orthogonal texture perception LOTM module to learn horizontal and vertical textures. The feature is represented by contour points. Finally, by considering the feature responses in two orthogonal directions at the same time, the Point Re-Scoring algorithm can effectively filter out the prediction of strong unidirectional or weak orthogonal activation, and the final text contour can be used as a A group of high-quality contour points are shown.\n", "Contournet [18] is based on the proposed modeling of text contour points to obtain a curved text detection frame. This method first uses Adaptive-RPN to obtain the proposal features of the text area, and then designs a local orthogonal texture perception LOTM module to learn horizontal and vertical textures. The feature is represented by contour points. Finally, by considering the feature responses in two orthogonal directions at the same time, the Point Re-Scoring algorithm can effectively filter out the prediction of strong unidirectional or weak orthogonal activation, and the final text contour can be used as a group of high-quality contour points are shown.\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/1f59ab5db899412f8c70ba71e8dd31d4ea9480d6511f498ea492c97dd2152384\"\n", "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/1f59ab5db899412f8c70ba71e8dd31d4ea9480d6511f498ea492c97dd2152384\"\n",
"width=\"600\" ></center>\n", "width=\"600\" ></center>\n",
"<br><center>Figure 12: Contournet frame diagram</center>\n", "<br><center>Figure 12: Contournet frame diagram</center>\n",
......
...@@ -1886,11 +1886,11 @@ ...@@ -1886,11 +1886,11 @@
" rec_res)\n", " rec_res)\n",
" filter_boxes, filter_rec_res = [], []\n", " filter_boxes, filter_rec_res = [], []\n",
" #Filter the results according to the threshold of the recognition score, if the score is less than the threshold, filter out\n", " #Filter the results according to the threshold of the recognition score, if the score is less than the threshold, filter out\n",
" for box, rec_reuslt in zip(dt_boxes, rec_res):\n", " for box, rec_result in zip(dt_boxes, rec_res):\n",
" text, score = rec_reuslt\n", " text, score = rec_result\n",
" if score >= self.drop_score:\n", " if score >= self.drop_score:\n",
" filter_boxes.append(box)\n", " filter_boxes.append(box)\n",
" filter_rec_res.append(rec_reuslt)\n", " filter_rec_res.append(rec_result)\n",
" return filter_boxes, filter_rec_res\n", " return filter_boxes, filter_rec_res\n",
"\n", "\n",
"def sorted_boxes(dt_boxes):\n", "def sorted_boxes(dt_boxes):\n",
...@@ -47,16 +47,46 @@ __all__ = [ ...@@ -47,16 +47,46 @@ __all__ = [
] ]
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = '2.5' VERSION = '2.5.0.1'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR' DEFAULT_OCR_MODEL_VERSION = 'PP-OCRv3'
SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2'] SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2', 'PP-OCRv3']
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE' DEFAULT_STRUCTURE_MODEL_VERSION = 'PP-STRUCTURE'
SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE'] SUPPORT_STRUCTURE_MODEL_VERSION = ['PP-STRUCTURE']
MODEL_URLS = { MODEL_URLS = {
'OCR': { 'OCR': {
'PP-OCRv3': {
'det': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar',
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar',
},
},
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
},
'cls': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
}
},
},
'PP-OCRv2': { 'PP-OCRv2': {
'det': { 'det': {
'ch': { 'ch': {
...@@ -72,7 +102,7 @@ MODEL_URLS = { ...@@ -72,7 +102,7 @@ MODEL_URLS = {
} }
} }
}, },
DEFAULT_OCR_MODEL_VERSION: { 'PP-OCR': {
'det': { 'det': {
'ch': { 'ch': {
'url': 'url':
...@@ -173,7 +203,7 @@ MODEL_URLS = { ...@@ -173,7 +203,7 @@ MODEL_URLS = {
} }
}, },
'STRUCTURE': { 'STRUCTURE': {
DEFAULT_STRUCTURE_MODEL_VERSION: { 'PP-STRUCTURE': {
'table': { 'table': {
'en': { 'en': {
'url': 'url':
...@@ -198,16 +228,17 @@ def parse_args(mMain=True): ...@@ -198,16 +228,17 @@ def parse_args(mMain=True):
"--ocr_version", "--ocr_version",
type=str, type=str,
choices=SUPPORT_OCR_MODEL_VERSION, choices=SUPPORT_OCR_MODEL_VERSION,
default='PP-OCRv2', default='PP-OCRv3',
help='OCR Model version, the current model support list is as follows: ' help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. ' '1. PP-OCRv3 Support Chinese and English detection and recognition model, and direction classifier model'
'2. PP-OCR support Chinese detection, recognition and direction classifier and multilingual recognition model.' '2. PP-OCRv2 Support Chinese detection and recognition model. '
'3. PP-OCR support Chinese detection, recognition and direction classifier and multilingual recognition model.'
) )
parser.add_argument( parser.add_argument(
"--structure_version", "--structure_version",
type=str, type=str,
choices=SUPPORT_STRUCTURE_MODEL_VERSION, choices=SUPPORT_STRUCTURE_MODEL_VERSION,
default='STRUCTURE', default='PP-STRUCTURE',
help='Model version, the current model support list is as follows:' help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.') ' 1. STRUCTURE Support en table structure model.')
......
...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask ...@@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
......
...@@ -207,25 +207,6 @@ class PRENResizeImg(object): ...@@ -207,25 +207,6 @@ class PRENResizeImg(object):
return data return data
class SVTRRecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0] h = img.shape[0]
...@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape): ...@@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
return np.reshape(img_black, (c, row, col)).astype(np.float32) return np.reshape(img_black, (c, row, col)).astype(np.float32)
def resize_norm_img_svtr(img, image_shape, padding=False):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
if h > 2.0 * w:
image = Image.fromarray(img)
image1 = image.rotate(90, expand=True)
image2 = image.rotate(-90, expand=True)
img1 = np.array(image1)
img2 = np.array(image2)
else:
img1 = copy.deepcopy(img)
img2 = copy.deepcopy(img)
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image1 = cv2.resize(
img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image2 = cv2.resize(
img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image1 = resized_image1.astype('float32')
resized_image2 = resized_image2.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resized_image1 -= 0.5
resized_image1 /= 0.5
resized_image2 -= 0.5
resized_image2 /= 0.5
padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
padding_im[0, :, :, 0:resized_w] = resized_image
padding_im[1, :, :, 0:resized_w] = resized_image1
padding_im[2, :, :, 0:resized_w] = resized_image2
return padding_im
def srn_other_inputs(image_shape, num_heads, max_text_length): def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
......
...@@ -64,9 +64,9 @@ class DetMetric(object): ...@@ -64,9 +64,9 @@ class DetMetric(object):
} }
""" """
metircs = self.evaluator.combine_results(self.results) metrics = self.evaluator.combine_results(self.results)
self.reset() self.reset()
return metircs return metrics
def reset(self): def reset(self):
self.results = [] # clear results self.results = [] # clear results
...@@ -127,20 +127,20 @@ class DetFCEMetric(object): ...@@ -127,20 +127,20 @@ class DetFCEMetric(object):
'thr 0.9':'precision: 0 recall: 0 hmean: 0', 'thr 0.9':'precision: 0 recall: 0 hmean: 0',
} }
""" """
metircs = {} metrics = {}
hmean = 0 hmean = 0
for score_thr in self.results.keys(): for score_thr in self.results.keys():
metirc = self.evaluator.combine_results(self.results[score_thr]) metric = self.evaluator.combine_results(self.results[score_thr])
# for key, value in metirc.items(): # for key, value in metric.items():
# metircs['{}_{}'.format(key, score_thr)] = value # metrics['{}_{}'.format(key, score_thr)] = value
metirc_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format( metric_str = 'precision:{:.5f} recall:{:.5f} hmean:{:.5f}'.format(
metirc['precision'], metirc['recall'], metirc['hmean']) metric['precision'], metric['recall'], metric['hmean'])
metircs['thr {}'.format(score_thr)] = metirc_str metrics['thr {}'.format(score_thr)] = metric_str
hmean = max(hmean, metirc['hmean']) hmean = max(hmean, metric['hmean'])
metircs['hmean'] = hmean metrics['hmean'] = hmean
self.reset() self.reset()
return metircs return metrics
def reset(self): def reset(self):
self.results = { self.results = {
......
...@@ -78,9 +78,9 @@ class E2EMetric(object): ...@@ -78,9 +78,9 @@ class E2EMetric(object):
self.results.append(result) self.results.append(result)
def get_metric(self): def get_metric(self):
metircs = combine_results(self.results) metrics = combine_results(self.results)
self.reset() self.reset()
return metircs return metrics
def reset(self): def reset(self):
self.results = [] # clear results self.results = [] # clear results
...@@ -61,9 +61,9 @@ class KIEMetric(object): ...@@ -61,9 +61,9 @@ class KIEMetric(object):
def get_metric(self): def get_metric(self):
metircs = self.combine_results(self.results) metrics = self.combine_results(self.results)
self.reset() self.reset()
return metircs return metrics
def reset(self): def reset(self):
self.results = [] # clear results self.results = [] # clear results
......
...@@ -34,13 +34,13 @@ class VQASerTokenMetric(object): ...@@ -34,13 +34,13 @@ class VQASerTokenMetric(object):
def get_metric(self): def get_metric(self):
from seqeval.metrics import f1_score, precision_score, recall_score from seqeval.metrics import f1_score, precision_score, recall_score
metircs = { metrics = {
"precision": precision_score(self.gt_list, self.pred_list), "precision": precision_score(self.gt_list, self.pred_list),
"recall": recall_score(self.gt_list, self.pred_list), "recall": recall_score(self.gt_list, self.pred_list),
"hmean": f1_score(self.gt_list, self.pred_list), "hmean": f1_score(self.gt_list, self.pred_list),
} }
self.reset() self.reset()
return metircs return metrics
def reset(self): def reset(self):
self.pred_list = [] self.pred_list = []
......
...@@ -92,6 +92,9 @@ class BaseModel(nn.Layer): ...@@ -92,6 +92,9 @@ class BaseModel(nn.Layer):
else: else:
y["head_out"] = x y["head_out"] = x
if self.return_all_feats: if self.return_all_feats:
return y if self.training:
return y
else:
return {"head_out": y["head_out"]}
else: else:
return x return x
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import Callable
from paddle import ParamAttr from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal from paddle.nn.initializer import KaimingNormal
import numpy as np import numpy as np
...@@ -170,17 +169,14 @@ class Attention(nn.Layer): ...@@ -170,17 +169,14 @@ class Attention(nn.Layer):
self.N = H * W self.N = H * W
self.C = dim self.C = dim
if mixer == 'Local' and HW is not None: if mixer == 'Local' and HW is not None:
hk = local_k[0] hk = local_k[0]
wk = local_k[1] wk = local_k[1]
mask = np.ones([H * W, H * W]) mask = paddle.ones([H * W, H + hk - 1, W + wk - 1], dtype='float32')
for h in range(H): for h in range(0, H):
for w in range(W): for w in range(0, W):
for kh in range(-(hk // 2), (hk // 2) + 1): mask[h * W + w, h:h + hk, w:w + wk] = 0.
for kw in range(-(wk // 2), (wk // 2) + 1): mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
if H > (h + kh) >= 0 and W > (w + kw) >= 0: 2].flatten(1)
mask[h * W + w][(h + kh) * W + (w + kw)] = 0
mask_paddle = paddle.to_tensor(mask, dtype='float32')
mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32') mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32')
mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf) mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf)
self.mask = mask.unsqueeze([0, 1]) self.mask = mask.unsqueeze([0, 1])
...@@ -228,11 +224,8 @@ class Block(nn.Layer): ...@@ -228,11 +224,8 @@ class Block(nn.Layer):
super().__init__() super().__init__()
if isinstance(norm_layer, str): if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm1 = norm_layer(dim)
else: else:
raise TypeError( self.norm1 = norm_layer(dim)
"The norm_layer must be str or paddle.nn.layer.Layer class")
if mixer == 'Global' or mixer == 'Local': if mixer == 'Global' or mixer == 'Local':
self.mixer = Attention( self.mixer = Attention(
dim, dim,
...@@ -250,15 +243,11 @@ class Block(nn.Layer): ...@@ -250,15 +243,11 @@ class Block(nn.Layer):
else: else:
raise TypeError("The mixer must be one of [Global, Local, Conv]") raise TypeError("The mixer must be one of [Global, Local, Conv]")
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
if isinstance(norm_layer, str): if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
elif isinstance(norm_layer, Callable):
self.norm2 = norm_layer(dim)
else: else:
raise TypeError( self.norm2 = norm_layer(dim)
"The norm_layer must be str or paddle.nn.layer.Layer class")
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.mlp = Mlp(in_features=dim, self.mlp = Mlp(in_features=dim,
...@@ -330,8 +319,6 @@ class PatchEmbed(nn.Layer): ...@@ -330,8 +319,6 @@ class PatchEmbed(nn.Layer):
act=nn.GELU, act=nn.GELU,
bias_attr=None), bias_attr=None),
ConvBNLayer( ConvBNLayer(
embed_dim // 2,
embed_dim,
in_channels=embed_dim // 2, in_channels=embed_dim // 2,
out_channels=embed_dim, out_channels=embed_dim,
kernel_size=3, kernel_size=3,
......
...@@ -128,8 +128,6 @@ class STN_ON(nn.Layer): ...@@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
self.out_channels = in_channels self.out_channels = in_channels
def forward(self, image): def forward(self, image):
if len(image.shape)==5:
image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
stn_input = paddle.nn.functional.interpolate( stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True) image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input) stn_img_feat, ctrl_points = self.stn_head(stn_input)
......
...@@ -43,12 +43,15 @@ class Momentum(object): ...@@ -43,12 +43,15 @@ class Momentum(object):
self.grad_clip = grad_clip self.grad_clip = grad_clip
def __call__(self, model): def __call__(self, model):
train_params = [
param for param in model.parameters() if param.trainable is True
]
opt = optim.Momentum( opt = optim.Momentum(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
momentum=self.momentum, momentum=self.momentum,
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
parameters=model.parameters()) parameters=train_params)
return opt return opt
...@@ -76,6 +79,9 @@ class Adam(object): ...@@ -76,6 +79,9 @@ class Adam(object):
self.lazy_mode = lazy_mode self.lazy_mode = lazy_mode
def __call__(self, model): def __call__(self, model):
train_params = [
param for param in model.parameters() if param.trainable is True
]
opt = optim.Adam( opt = optim.Adam(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
beta1=self.beta1, beta1=self.beta1,
...@@ -85,7 +91,7 @@ class Adam(object): ...@@ -85,7 +91,7 @@ class Adam(object):
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
name=self.name, name=self.name,
lazy_mode=self.lazy_mode, lazy_mode=self.lazy_mode,
parameters=model.parameters()) parameters=train_params)
return opt return opt
...@@ -118,6 +124,9 @@ class RMSProp(object): ...@@ -118,6 +124,9 @@ class RMSProp(object):
self.grad_clip = grad_clip self.grad_clip = grad_clip
def __call__(self, model): def __call__(self, model):
train_params = [
param for param in model.parameters() if param.trainable is True
]
opt = optim.RMSProp( opt = optim.RMSProp(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
momentum=self.momentum, momentum=self.momentum,
...@@ -125,7 +134,7 @@ class RMSProp(object): ...@@ -125,7 +134,7 @@ class RMSProp(object):
epsilon=self.epsilon, epsilon=self.epsilon,
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
parameters=model.parameters()) parameters=train_params)
return opt return opt
...@@ -149,6 +158,9 @@ class Adadelta(object): ...@@ -149,6 +158,9 @@ class Adadelta(object):
self.name = name self.name = name
def __call__(self, model): def __call__(self, model):
train_params = [
param for param in model.parameters() if param.trainable is True
]
opt = optim.Adadelta( opt = optim.Adadelta(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
epsilon=self.epsilon, epsilon=self.epsilon,
...@@ -156,7 +168,7 @@ class Adadelta(object): ...@@ -156,7 +168,7 @@ class Adadelta(object):
weight_decay=self.weight_decay, weight_decay=self.weight_decay,
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
name=self.name, name=self.name,
parameters=model.parameters()) parameters=train_params)
return opt return opt
...@@ -190,17 +202,20 @@ class AdamW(object): ...@@ -190,17 +202,20 @@ class AdamW(object):
self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
def __call__(self, model): def __call__(self, model):
parameters = model.parameters() parameters = [
param for param in model.parameters() if param.trainable is True
]
self.no_weight_decay_param_name_list = [ self.no_weight_decay_param_name_list = [
p.name for n, p in model.named_parameters() if any(nd in n for nd in self.no_weight_decay_name_list) p.name for n, p in model.named_parameters()
if any(nd in n for nd in self.no_weight_decay_name_list)
] ]
if self.one_dim_param_no_weight_decay: if self.one_dim_param_no_weight_decay:
self.no_weight_decay_param_name_list += [ self.no_weight_decay_param_name_list += [
p.name for n, p in model.named_parameters() if len(p.shape) == 1 p.name for n, p in model.named_parameters() if len(p.shape) == 1
] ]
opt = optim.AdamW( opt = optim.AdamW(
learning_rate=self.learning_rate, learning_rate=self.learning_rate,
beta1=self.beta1, beta1=self.beta1,
...@@ -216,4 +231,4 @@ class AdamW(object): ...@@ -216,4 +231,4 @@ class AdamW(object):
return opt return opt
def _apply_decay_param_fun(self, name): def _apply_decay_param_fun(self, name):
return name not in self.no_weight_decay_param_name_list return name not in self.no_weight_decay_param_name_list
\ No newline at end of file
...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess ...@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode SEEDLabelDecode, PRENLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): ...@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'SVTRLabelDecode' 'DistillationSARLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode): ...@@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
return text return text
label = self.decode(label) label = self.decode(label)
return text, label return text, label
class SVTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SVTRLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=-1)
preds_prob = preds.max(axis=-1)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
return_text = []
for i in range(0, len(text), 3):
text0 = text[i]
text1 = text[i + 1]
text2 = text[i + 2]
text_pred = [text0[0], text1[0], text2[0]]
text_prob = [text0[1], text1[1], text2[1]]
id_max = text_prob.index(max(text_prob))
return_text.append((text_pred[id_max], text_prob[id_max]))
if label is None:
return return_text
label = self.decode(label)
return return_text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment