Unverified Commit e93735a2 authored by MissPenguin's avatar MissPenguin Committed by GitHub
Browse files

Merge pull request #3083 from WenmuZhou/table1

[DO NOT MERGE]Table
parents 6127aad9 b2260182
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
# This is where we handle translating css styles into openpyxl styles
# and cascading those from parent to child in the dom.
from openpyxl.cell import cell
from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
from openpyxl.styles.fills import FILL_SOLID
from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
from openpyxl.styles.colors import BLACK
FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
def colormap(color):
"""
Convenience for looking up known colors
"""
cmap = {'black': BLACK}
return cmap.get(color, color)
def style_string_to_dict(style):
"""
Convert css style string to a python dictionary
"""
def clean_split(string, delim):
return (s.strip() for s in string.split(delim))
styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
return dict(styles)
def get_side(style, name):
return {'border_style': style.get('border-{}-style'.format(name)),
'color': colormap(style.get('border-{}-color'.format(name)))}
known_styles = {}
def style_dict_to_named_style(style_dict, number_format=None):
"""
Change css style (stored in a python dictionary) to openpyxl NamedStyle
"""
style_and_format_string = str({
'style_dict': style_dict,
'parent': style_dict.parent,
'number_format': number_format,
})
if style_and_format_string not in known_styles:
# Font
font = Font(bold=style_dict.get('font-weight') == 'bold',
color=style_dict.get_color('color', None),
size=style_dict.get('font-size'))
# Alignment
alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
vertical=style_dict.get('vertical-align'),
wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')
# Fill
bg_color = style_dict.get_color('background-color')
fg_color = style_dict.get_color('foreground-color', Color())
fill_type = style_dict.get('fill-type')
if bg_color and bg_color != 'transparent':
fill = PatternFill(fill_type=fill_type or FILL_SOLID,
start_color=bg_color,
end_color=fg_color)
else:
fill = PatternFill()
# Border
border = Border(left=Side(**get_side(style_dict, 'left')),
right=Side(**get_side(style_dict, 'right')),
top=Side(**get_side(style_dict, 'top')),
bottom=Side(**get_side(style_dict, 'bottom')),
diagonal=Side(**get_side(style_dict, 'diagonal')),
diagonal_direction=None,
outline=Side(**get_side(style_dict, 'outline')),
vertical=None,
horizontal=None)
name = 'Style {}'.format(len(known_styles) + 1)
pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
number_format=number_format)
known_styles[style_and_format_string] = pyxl_style
return known_styles[style_and_format_string]
class StyleDict(dict):
"""
It's like a dictionary, but it looks for items in the parent dictionary
"""
def __init__(self, *args, **kwargs):
self.parent = kwargs.pop('parent', None)
super(StyleDict, self).__init__(*args, **kwargs)
def __getitem__(self, item):
if item in self:
return super(StyleDict, self).__getitem__(item)
elif self.parent:
return self.parent[item]
else:
raise KeyError('{} not found'.format(item))
def __hash__(self):
return hash(tuple([(k, self.get(k)) for k in self._keys()]))
# Yielding the keys avoids creating unnecessary data structures
# and happily works with both python2 and python3 where the
# .keys() method is a dictionary_view in python3 and a list in python2.
def _keys(self):
yielded = set()
for k in self.keys():
yielded.add(k)
yield k
if self.parent:
for k in self.parent._keys():
if k not in yielded:
yielded.add(k)
yield k
def get(self, k, d=None):
try:
return self[k]
except KeyError:
return d
def get_color(self, k, d=None):
"""
Strip leading # off colors if necessary
"""
color = self.get(k, d)
if hasattr(color, 'startswith') and color.startswith('#'):
color = color[1:]
if len(color) == 3: # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
color = ''.join(2 * c for c in color)
return color
class Element(object):
"""
Our base class for representing an html element along with a cascading style.
The element is created along with a parent so that the StyleDict that we store
can point to the parent's StyleDict.
"""
def __init__(self, element, parent=None):
self.element = element
self.number_format = None
parent_style = parent.style_dict if parent else None
self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
self._style_cache = None
def style(self):
"""
Turn the css styles for this element into an openpyxl NamedStyle.
"""
if not self._style_cache:
self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
return self._style_cache
def get_dimension(self, dimension_key):
"""
Extracts the dimension from the style dict of the Element and returns it as a float.
"""
dimension = self.style_dict.get(dimension_key)
if dimension:
if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
dimension = dimension[:-2]
dimension = float(dimension)
return dimension
class Table(Element):
"""
The concrete implementations of Elements are semantically named for the types of elements we are interested in.
This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
"""
def __init__(self, table):
"""
takes an html table object (from lxml)
"""
super(Table, self).__init__(table)
table_head = table.find('thead')
self.head = TableHead(table_head, parent=self) if table_head is not None else None
table_body = table.find('tbody')
self.body = TableBody(table_body if table_body is not None else table, parent=self)
class TableHead(Element):
"""
This class maps to the `<th>` element of the html table.
"""
def __init__(self, head, parent=None):
super(TableHead, self).__init__(head, parent=parent)
self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]
class TableBody(Element):
"""
This class maps to the `<tbody>` element of the html table.
"""
def __init__(self, body, parent=None):
super(TableBody, self).__init__(body, parent=parent)
self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]
class TableRow(Element):
"""
This class maps to the `<tr>` element of the html table.
"""
def __init__(self, tr, parent=None):
super(TableRow, self).__init__(tr, parent=parent)
self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]
def element_to_string(el):
return _element_to_string(el).strip()
def _element_to_string(el):
string = ''
for x in el.iterchildren():
string += '\n' + _element_to_string(x)
text = el.text.strip() if el.text else ''
tail = el.tail.strip() if el.tail else ''
return text + string + '\n' + tail
class TableCell(Element):
"""
This class maps to the `<td>` element of the html table.
"""
CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}
def __init__(self, cell, parent=None):
super(TableCell, self).__init__(cell, parent=parent)
self.value = element_to_string(cell)
self.number_format = self.get_number_format()
def data_type(self):
cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())
if cell_types:
if 'TYPE_FORMULA' in cell_types:
# Make sure TYPE_FORMULA takes precedence over the other classes in the set.
cell_type = 'TYPE_FORMULA'
elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
cell_type = 'TYPE_NUMERIC'
else:
cell_type = cell_types.pop()
else:
cell_type = 'TYPE_STRING'
return getattr(cell, cell_type)
def get_number_format(self):
if 'TYPE_CURRENCY' in self.element.get('class', '').split():
return FORMAT_CURRENCY_USD_SIMPLE
if 'TYPE_INTEGER' in self.element.get('class', '').split():
return '#,##0'
if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
return FORMAT_PERCENTAGE
if 'TYPE_DATE' in self.element.get('class', '').split():
return FORMAT_DATE_MMDDYYYY
if self.data_type() == cell.TYPE_NUMERIC:
try:
int(self.value)
except ValueError:
return '#,##0.##'
else:
return '#,##0'
def format(self, cell):
cell.style = self.style()
data_type = self.data_type()
if data_type:
cell.data_type = data_type
\ No newline at end of file
# Do imports like python3 so our package works for 2 and 3
from __future__ import absolute_import
from lxml import html
from openpyxl import Workbook
from openpyxl.utils import get_column_letter
from premailer import Premailer
from tablepyxl.style import Table
def string_to_int(s):
if s.isdigit():
return int(s)
return 0
def get_Tables(doc):
tree = html.fromstring(doc)
comments = tree.xpath('//comment()')
for comment in comments:
comment.drop_tag()
return [Table(table) for table in tree.xpath('//table')]
def write_rows(worksheet, elem, row, column=1):
"""
Writes every tr child element of elem to a row in the worksheet
returns the next row after all rows are written
"""
from openpyxl.cell.cell import MergedCell
initial_column = column
for table_row in elem.rows:
for table_cell in table_row.cells:
cell = worksheet.cell(row=row, column=column)
while isinstance(cell, MergedCell):
column += 1
cell = worksheet.cell(row=row, column=column)
colspan = string_to_int(table_cell.element.get("colspan", "1"))
rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
if rowspan > 1 or colspan > 1:
worksheet.merge_cells(start_row=row, start_column=column,
end_row=row + rowspan - 1, end_column=column + colspan - 1)
cell.value = table_cell.value
table_cell.format(cell)
min_width = table_cell.get_dimension('min-width')
max_width = table_cell.get_dimension('max-width')
if colspan == 1:
# Initially, when iterating for the first time through the loop, the width of all the cells is None.
# As we start filling in contents, the initial width of the cell (which can be retrieved by:
# worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
# cell in the same column (i.e. width of A2 = width of A1)
width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
if max_width and width > max_width:
width = max_width
elif min_width and width < min_width:
width = min_width
worksheet.column_dimensions[get_column_letter(column)].width = width
column += colspan
row += 1
column = initial_column
return row
def table_to_sheet(table, wb):
"""
Takes a table and workbook and writes the table to a new sheet.
The sheet title will be the same as the table attribute name.
"""
ws = wb.create_sheet(title=table.element.get('name'))
insert_table(table, ws, 1, 1)
def document_to_workbook(doc, wb=None, base_url=None):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document.
The workbook is returned
"""
if not wb:
wb = Workbook()
wb.remove(wb.active)
inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()
tables = get_Tables(inline_styles_doc)
for table in tables:
table_to_sheet(table, wb)
return wb
def document_to_xl(doc, filename, base_url=None):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document. The workbook is written out to a file called filename
"""
wb = document_to_workbook(doc, base_url=base_url)
wb.save(filename)
def insert_table(table, worksheet, column, row):
if table.head:
row = write_rows(worksheet, table.head, row, column)
if table.body:
row = write_rows(worksheet, table.body, row, column)
def insert_table_at_cell(table, cell):
"""
Inserts a table at the location of an openpyxl Cell object.
"""
ws = cell.parent
column, row = cell.column, cell.row
insert_table(table, ws, column, row)
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args
def init_args():
parser = infer_args()
# params for output
parser.add_argument("--output", type=str, default='./output/table')
# params for table structure
parser.add_argument("--structure_max_len", type=int, default=488)
parser.add_argument("--structure_model_dir", type=str)
parser.add_argument("--structure_char_type", type=str, default='en')
parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
def draw_result(image, result, font_path):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
boxes, txts, scores = [], [], []
for region in result:
if region['type'] == 'Table':
pass
elif region['type'] == 'Figure':
pass
else:
for box, rec_res in zip(region['res'][0], region['res'][1]):
boxes.append(np.array(box).reshape(-1, 2))
txts.append(rec_res[0])
scores.append(rec_res[1])
im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
return im_show
\ No newline at end of file
...@@ -43,7 +43,7 @@ class TextDetector(object): ...@@ -43,7 +43,7 @@ class TextDetector(object):
pre_process_list = [{ pre_process_list = [{
'DetResizeForTest': { 'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len, 'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type 'limit_type': args.det_limit_type,
} }
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
......
...@@ -24,6 +24,7 @@ import cv2 ...@@ -24,6 +24,7 @@ import cv2
import copy import copy
import numpy as np import numpy as np
import time import time
import logging
from PIL import Image from PIL import Image
import tools.infer.utility as utility import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
...@@ -38,6 +39,9 @@ logger = get_logger() ...@@ -38,6 +39,9 @@ logger = get_logger()
class TextSystem(object): class TextSystem(object):
def __init__(self, args): def __init__(self, args):
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls self.use_angle_cls = args.use_angle_cls
...@@ -88,7 +92,7 @@ class TextSystem(object): ...@@ -88,7 +92,7 @@ class TextSystem(object):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format( logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse)) len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
...@@ -104,11 +108,11 @@ class TextSystem(object): ...@@ -104,11 +108,11 @@ class TextSystem(object):
if self.use_angle_cls and cls: if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list) img_crop_list)
logger.info("cls num : {}, elapse : {}".format( logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse)) len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
logger.info("rec_res num : {}, elapse : {}".format( logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], [] filter_boxes, filter_rec_res = [], []
......
...@@ -109,11 +109,12 @@ def init_args(): ...@@ -109,11 +109,12 @@ def init_args():
parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0) parser.add_argument("--process_id", type=int, default=0)
parser.add_argument("--benchmark", type=bool, default=False) parser.add_argument("--benchmark", type=bool, default=False)
parser.add_argument("--save_log_path", type=str, default="./log_output/") parser.add_argument("--save_log_path", type=str, default="./log_output/")
parser.add_argument("--show_log", type=str2bool, default=True)
return parser return parser
...@@ -199,6 +200,8 @@ def create_predictor(args, mode, logger): ...@@ -199,6 +200,8 @@ def create_predictor(args, mode, logger):
model_dir = args.cls_model_dir model_dir = args.cls_model_dir
elif mode == 'rec': elif mode == 'rec':
model_dir = args.rec_model_dir model_dir = args.rec_model_dir
elif mode == 'structure':
model_dir = args.structure_model_dir
else: else:
model_dir = args.e2e_model_dir model_dir = args.e2e_model_dir
...@@ -328,7 +331,9 @@ def create_predictor(args, mode, logger): ...@@ -328,7 +331,9 @@ def create_predictor(args, mode, logger):
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
config.switch_ir_optim(True)
if mode == 'structure':
config.switch_ir_optim(False)
# create predictor # create predictor
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment