"examples/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "274330505ebb9329418bf7c5f2bb4d43bb803113"
Unverified Commit c402c010 authored by Xiaomeng Zhao, committed by GitHub
Browse files

Merge pull request #2072 from opendatalab/dev

refactor(ocr): remove redundant code and improve code quality 
parents 01d1e086 bb30f32e
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
import copy import copy
import os.path import os.path
import warnings
from pathlib import Path from pathlib import Path
import cv2 import cv2
...@@ -92,45 +93,46 @@ class PytorchPaddleOCR(TextSystem): ...@@ -92,45 +93,46 @@ class PytorchPaddleOCR(TextSystem):
exit(0) exit(0)
img = check_img(img) img = check_img(img)
imgs = [img] imgs = [img]
with warnings.catch_warnings():
if det and rec: warnings.simplefilter("ignore", category=RuntimeWarning)
ocr_res = [] if det and rec:
for img in imgs: ocr_res = []
img = preprocess_image(img) for img in imgs:
dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res) img = preprocess_image(img)
if not dt_boxes and not rec_res: dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
ocr_res.append(None) if not dt_boxes and not rec_res:
continue ocr_res.append(None)
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] continue
ocr_res.append(tmp_res) tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
return ocr_res ocr_res.append(tmp_res)
elif det and not rec: return ocr_res
ocr_res = [] elif det and not rec:
for img in imgs: ocr_res = []
img = preprocess_image(img) for img in imgs:
dt_boxes, elapse = self.text_detector(img)
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
if dt_boxes is None:
ocr_res.append(None)
continue
dt_boxes = sorted_boxes(dt_boxes)
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
elif not det and rec:
ocr_res = []
for img in imgs:
if not isinstance(img, list):
img = preprocess_image(img) img = preprocess_image(img)
img = [img] dt_boxes, elapse = self.text_detector(img)
rec_res, elapse = self.text_recognizer(img) # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse)) if dt_boxes is None:
ocr_res.append(rec_res) ocr_res.append(None)
return ocr_res continue
dt_boxes = sorted_boxes(dt_boxes)
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
elif not det and rec:
ocr_res = []
for img in imgs:
if not isinstance(img, list):
img = preprocess_image(img)
img = [img]
rec_res, elapse = self.text_recognizer(img)
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
ocr_res.append(rec_res)
return ocr_res
def __call__(self, img, mfd_res=None): def __call__(self, img, mfd_res=None):
......
...@@ -371,12 +371,6 @@ class TextRecognizer(BaseOCRV20): ...@@ -371,12 +371,6 @@ class TextRecognizer(BaseOCRV20):
gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list) gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list) gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
# if self.use_gpu:
# inp = inp.cuda()
# encoder_word_pos_inp = encoder_word_pos_inp.cuda()
# gsrm_word_pos_inp = gsrm_word_pos_inp.cuda()
# gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.cuda()
# gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.cuda()
inp = inp.to(self.device) inp = inp.to(self.device)
encoder_word_pos_inp = encoder_word_pos_inp.to(self.device) encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device) gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
...@@ -398,8 +392,6 @@ class TextRecognizer(BaseOCRV20): ...@@ -398,8 +392,6 @@ class TextRecognizer(BaseOCRV20):
with torch.no_grad(): with torch.no_grad():
inp = torch.from_numpy(norm_img_batch) inp = torch.from_numpy(norm_img_batch)
# if self.use_gpu:
# inp = inp.cuda()
inp = inp.to(self.device) inp = inp.to(self.device)
preds = self.net(inp) preds = self.net(inp)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment