"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "62e00968177b2a185f2dc6c2015fdf74e0f038eb"
Unverified Commit aa1d40b1 authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge branch 'dygraph' into bm_dyg

parents b796c916 2c2918f2
include LICENSE.txt
include LICENSE
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
......
......@@ -465,8 +465,12 @@ public class MainActivity extends AppCompatActivity {
}
public void btn_load_model_click(View view) {
tvStatus.setText("STATUS: load model ......");
loadModel();
if (predictor.isLoaded()){
tvStatus.setText("STATUS: model has been loaded");
}else{
tvStatus.setText("STATUS: load model ......");
loadModel();
}
}
public void btn_run_model_click(View view) {
......
......@@ -194,26 +194,25 @@ public class Predictor {
"supported!");
return false;
}
int[] channelStride = new int[]{width * height, width * height * 2};
int p = scaleImage.getPixel(scaleImage.getWidth() - 1, scaleImage.getHeight() - 1);
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int color = scaleImage.getPixel(x, y);
float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
(float) blue(color) / 255.0f};
inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
}
int[] channelStride = new int[]{width * height, width * height * 2};
int[] pixels=new int[width*height];
scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
for (int i = 0; i < pixels.length; i++) {
int color = pixels[i];
float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
(float) blue(color) / 255.0f};
inputData[i] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
inputData[i + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
inputData[i+ channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
}
} else if (channels == 1) {
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int color = inputImage.getPixel(x, y);
float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0];
}
int[] pixels=new int[width*height];
scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
for (int i = 0; i < pixels.length; i++) {
int color = pixels[i];
float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
inputData[i] = (gray - inputMean[0]) / inputStd[0];
}
} else {
Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " +
......
......@@ -111,9 +111,9 @@
| 字段 | 用途 | 默认值 | 备注 |
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
| **dataset** | 每次迭代返回一个样本 | - | - |
| name | dataset类名 | SimpleDataSet | 目前支持`SimpleDataSet``LMDBDateSet` |
| name | dataset类名 | SimpleDataSet | 目前支持`SimpleDataSet``LMDBDataSet` |
| data_dir | 数据集图片存放路径 | ./train_data | \ |
| label_file_list | 数据标签路径 | ["./train_data/train_list.txt"] | dataset为LMDBDateSet时不需要此参数 |
| label_file_list | 数据标签路径 | ["./train_data/train_list.txt"] | dataset为LMDBDataSet时不需要此参数 |
| ratio_list | 数据集的比例 | [1.0] | 若label_file_list中有两个train_list,且ratio_list为[0.4,0.6],则从train_list1中采样40%,从train_list2中采样60%组合整个dataset |
| transforms | 对图片和标签进行变换的方法列表 | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | 见[ppocr/data/imaug](../../ppocr/data/imaug) |
| **loader** | dataloader相关 | - | |
......
......@@ -243,7 +243,7 @@ Optimizer:
Train:
dataset:
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
name: SimpleDataSet
# 数据集路径
data_dir: ./train_data/
......@@ -263,7 +263,7 @@ Train:
Eval:
dataset:
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
name: SimpleDataSet
# 数据集路径
data_dir: ./train_data
......@@ -393,7 +393,7 @@ Global:
Train:
dataset:
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
name: SimpleDataSet
# 数据集路径
data_dir: ./train_data/
......@@ -403,7 +403,7 @@ Train:
Eval:
dataset:
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
name: SimpleDataSet
# 数据集路径
data_dir: ./train_data
......
......@@ -110,9 +110,9 @@ In ppocr, the network is divided into four stages: Transform, Backbone, Neck and
| Parameter | Use | Defaults | Note |
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
| **dataset** | Return one sample per iteration | - | - |
| name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDateSet` |
| name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDataSet` |
| data_dir | Image folder path | ./train_data | \ |
| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDateSet |
| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDataSet |
| ratio_list | Ratio of data set | [1.0] | If there are two train_lists in label_file_list and ratio_list is [0.4,0.6], 40% will be sampled from train_list1, and 60% will be sampled from train_list2 to combine the entire dataset |
| transforms | List of methods to transform images and labels | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | see[ppocr/data/imaug](../../ppocr/data/imaug) |
| **loader** | dataloader related | - | |
......
......@@ -237,7 +237,7 @@ Optimizer:
Train:
dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet
# Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet
# Path of dataset
data_dir: ./train_data/
......@@ -257,7 +257,7 @@ Train:
Eval:
dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet
# Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet
# Path of dataset
data_dir: ./train_data
......@@ -394,7 +394,7 @@ Global:
Train:
dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet
# Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet
# Path of dataset
data_dir: ./train_data/
......@@ -404,7 +404,7 @@ Train:
Eval:
dataset:
# Type of dataset,we support LMDBDateSet and SimpleDataSet
# Type of dataset,we support LMDBDataSet and SimpleDataSet
name: SimpleDataSet
# Path of dataset
data_dir: ./train_data
......
doc/joinus.PNG

102 KB | W: | H:

doc/joinus.PNG

78.1 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
......@@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import logging
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system
from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar
from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR']
......@@ -37,84 +36,84 @@ __all__ = ['PaddleOCR']
model_urls = {
'det': {
'ch':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
},
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
}
},
'cls':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
}
SUPPORT_DET_MODEL = ['DB']
......@@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args(mMain=True):
import argparse
parser = init_args()
......@@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_params = parse_args(mMain=False)
postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
self.use_angle_cls = params.use_angle_cls
lang = params.lang
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
......@@ -223,46 +180,46 @@ class PaddleOCR(predict_system.TextSystem):
lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
model_urls['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
else:
det_lang = "en"
use_inner_dict = False
if postprocess_params.rec_char_dict_path is None:
if params.rec_char_dict_path is None:
use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir
if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
if params.det_model_dir is None:
params.det_model_dir = os.path.join(BASE_DIR, VERSION,
'det', det_lang)
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
if params.rec_model_dir is None:
params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
if params.cls_model_dir is None:
params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
# download model
maybe_download(postprocess_params.det_model_dir,
maybe_download(params.det_model_dir,
model_urls['det'][det_lang])
maybe_download(postprocess_params.rec_model_dir,
maybe_download(params.rec_model_dir,
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
maybe_download(params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0)
if use_inner_dict:
postprocess_params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path)
params.rec_char_dict_path = str(
Path(__file__).parent / params.rec_char_dict_path)
print(params)
# init det_model and rec_model
super().__init__(postprocess_params)
super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True):
"""
......
......@@ -81,7 +81,7 @@ class NormalizeImage(object):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
img.astype('float32') * self.scale - self.mean) / self.std
return data
......@@ -163,7 +163,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, _ = img.shape
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
......@@ -174,7 +174,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
else:
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
......@@ -182,6 +182,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
......
......@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
self.character_str = []
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str += line
self.character_str.append(line)
if use_space_char:
self.character_str += " "
self.character_str.append(" ")
dict_character = list(self.character_str)
else:
......@@ -288,3 +288,156 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class TableLabelDecode(object):
""" """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_sp_tokens(self):
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num])
return sp_tokens
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
</overline>
α

$
ω
ψ
χ
(
υ
σ
,
ρ
ε
0
4
8
b
<
Ψ
Ω
D
3
Π
H
</strike>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
t
</sub>
x
Β
Γ
Δ
|
ǂ
ɛ
j
̧
̌
«
#
</b>
'
Ι
+
/
·
7
;
?
C
÷
G
K
<sup>
O
S
С
W
Α
[
_
c
z
g
<i>
o
<sub>
s
w
φ
ʹ
{
»
̆
e
ˆ
τ
ι
Ø
ß
×
˃
˂
"
i
&
π
*
æ
.
ø
Q
6
:
>
a
B
F
J
̄
N
R
V
<overline>
Z
^
¤
¥
§
<underline>
¢
£
­
Λ
©
n
r
°
±
v
<b>
k
~
̇
@
ł
®
!
</sup>
%
)
-
1
5
9
=
А
A
Σ
E
I
M
m
̨
</i>
U
Y
]
̸
2
̂
̀
́
̊
̈
q
u
ı
y
</underline>
̃
}
ν
This diff is collapsed.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tarfile
import requests
from tqdm import tqdm
from ppocr.utils.logging import get_logger
def download_with_progressbar(url, save_path):
logger = get_logger()
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
assert url.endswith('.tar'), 'Only supports tar compressed package'
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def is_link(s):
return s is not None and s.startswith('http')
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
url = default_url
if model_dir is None or is_link(model_dir):
if is_link(model_dir):
url = model_dir
file_name = url.split('/')[-1][:-4]
model_dir = default_model_dir
model_dir = os.path.join(model_dir, file_name)
return model_dir, url
include LICENSE
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include ppstructure *.py
# TableStructurer
1. 代码使用
```python
import cv2
from paddlestructure import PaddleStructure,draw_result
table_engine = PaddleStructure(
output='./output/table',
show_log=True)
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
print(line)
from PIL import Image
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
2. 命令行使用
```bash
paddlestructure --image_dir=../doc/table/1.png
```
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .paddlestructure import PaddleStructure, draw_result, to_excel
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment