Commit c9171d1f authored by zhougaofeng's avatar zhougaofeng
Browse files

Update app.py, count_pdfs.py, LICENSE.md, magic-pdf.template.json,...

Update app.py, count_pdfs.py, LICENSE.md, magic-pdf.template.json, requirements.txt, requirements-docker.txt, requirements-qa.txt, update_version.py, setup.py, magic_pdf/__init__.py, magic_pdf/pdf_parse_by_txt.py, magic_pdf/pdf_parse_by_ocr.py, magic_pdf/user_api.py, magic_pdf/pdf_parse_union_core.py, magic_pdf/__pycache__/pdf_parse_by_ocr.cpython-310.pyc, magic_pdf/__pycache__/__init__.cpython-310.pyc, magic_pdf/__pycache__/pdf_parse_by_txt.cpython-310.pyc, magic_pdf/__pycache__/pdf_parse_union_core.cpython-310.pyc, magic_pdf/__pycache__/user_api.cpython-310.pyc, magic_pdf/dict2md/__init__.py, magic_pdf/dict2md/mkcontent.py, magic_pdf/dict2md/ocr_mkcontent.py, magic_pdf/dict2md/ocr_client.py, magic_pdf/dict2md/ocr_server.py, magic_pdf/dict2md/ocr_server_72.py, magic_pdf/dict2md/tmp.py, magic_pdf/dict2md/__pycache__/__init__.cpython-310.pyc, magic_pdf/dict2md/__pycache__/ocr_client.cpython-310.pyc, magic_pdf/dict2md/__pycache__/ocr_mkcontent.cpython-310.pyc, magic_pdf/filter/__init__.py, magic_pdf/filter/pdf_classify_by_type.py, magic_pdf/filter/pdf_meta_scan.py, magic_pdf/integrations/__init__.py, magic_pdf/integrations/rag/__init__.py, magic_pdf/integrations/rag/api.py, magic_pdf/integrations/rag/type.py, magic_pdf/integrations/rag/utils.py, magic_pdf/layout/__init__.py, magic_pdf/layout/bbox_sort.py, magic_pdf/layout/layout_det_utils.py, magic_pdf/layout/layout_sort.py, magic_pdf/layout/layout_spiler_recog.py, magic_pdf/layout/mcol_sort.py, magic_pdf/libs/__init__.py, magic_pdf/libs/boxbase.py, magic_pdf/libs/calc_span_stats.py, magic_pdf/libs/commons.py, magic_pdf/libs/config_reader.py, magic_pdf/libs/Constants.py, magic_pdf/libs/convert_utils.py, magic_pdf/libs/coordinate_transform.py, magic_pdf/libs/detect_language_from_model.py, magic_pdf/libs/draw_bbox.py, magic_pdf/libs/drop_reason.py, magic_pdf/libs/drop_tag.py, magic_pdf/libs/hash_utils.py, magic_pdf/libs/json_compressor.py, magic_pdf/libs/language.py, magic_pdf/libs/local_math.py, magic_pdf/libs/MakeContentConfig.py, magic_pdf/libs/markdown_utils.py, magic_pdf/libs/nlp_utils.py, magic_pdf/libs/ModelBlockTypeEnum.py, magic_pdf/libs/ocr_content_type.py, magic_pdf/libs/path_utils.py, magic_pdf/libs/pdf_check.py, magic_pdf/libs/pdf_image_tools.py, magic_pdf/libs/safe_filename.py, magic_pdf/libs/textbase.py, magic_pdf/libs/version.py, magic_pdf/libs/vis_utils.py, magic_pdf/model/__init__.py, magic_pdf/model/doc_analyze_by_custom_model.py, magic_pdf/model/magic_model.py, magic_pdf/model/pdf_extract_kit.py, magic_pdf/model/model_list.py, magic_pdf/model/pp_structure_v2.py, magic_pdf/model/ppTableModel.py, magic_pdf/model/pek_sub_modules/__init__.py, magic_pdf/model/pek_sub_modules/post_process.py, magic_pdf/model/pek_sub_modules/self_modify.py, magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py, magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py, magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py, magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py, magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py, magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py, magic_pdf/model/pek_sub_modules/structeqtable/__init__.py, magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py, magic_pdf/para/__init__.py, magic_pdf/para/block_continuation_processor.py, magic_pdf/para/block_termination_processor.py, magic_pdf/para/commons.py, magic_pdf/para/denoise.py, magic_pdf/para/draw.py, magic_pdf/para/exceptions.py, magic_pdf/para/layout_match_processor.py, magic_pdf/para/para_pipeline.py, magic_pdf/para/para_split.py, magic_pdf/para/para_split_v2.py, magic_pdf/para/raw_processor.py, magic_pdf/para/stats.py, magic_pdf/para/title_processor.py, magic_pdf/parse/__init__.py, magic_pdf/parse/common_parse.py, magic_pdf/parse/excel_parse.py, magic_pdf/parse/pdf_client.py, magic_pdf/pipe/__init__.py, magic_pdf/pipe/AbsPipe.py, magic_pdf/pipe/OCRPipe.py, magic_pdf/pipe/TXTPipe.py, magic_pdf/pipe/UNIPipe.py, magic_pdf/post_proc/__init__.py, magic_pdf/post_proc/pdf_post_filter.py, magic_pdf/post_proc/remove_footnote.py, magic_pdf/post_proc/detect_para.py, magic_pdf/pre_proc/__init__.py, magic_pdf/pre_proc/citationmarker_remove.py, magic_pdf/pre_proc/construct_page_dict.py, magic_pdf/pre_proc/cut_image.py, magic_pdf/pre_proc/detect_equation.py, magic_pdf/pre_proc/detect_footer_header_by_statistics.py, magic_pdf/pre_proc/detect_footer_by_model.py, magic_pdf/pre_proc/detect_footnote.py, magic_pdf/pre_proc/detect_header.py, magic_pdf/pre_proc/detect_images.py, magic_pdf/pre_proc/detect_page_number.py, magic_pdf/pre_proc/detect_tables.py, magic_pdf/pre_proc/equations_replace.py, magic_pdf/pre_proc/fix_image.py, magic_pdf/pre_proc/fix_table.py, magic_pdf/pre_proc/main_text_font.py, magic_pdf/pre_proc/ocr_detect_all_bboxes.py, magic_pdf/pre_proc/ocr_detect_layout.py, magic_pdf/pre_proc/ocr_dict_merge.py, magic_pdf/pre_proc/ocr_span_list_modify.py, magic_pdf/pre_proc/post_layout_split.py, magic_pdf/pre_proc/remove_bbox_overlap.py, magic_pdf/pre_proc/remove_colored_strip_bbox.py, magic_pdf/pre_proc/pdf_pre_filter.py, magic_pdf/pre_proc/remove_footer_header.py, magic_pdf/pre_proc/remove_rotate_bbox.py, magic_pdf/pre_proc/resolve_bbox_conflict.py, magic_pdf/pre_proc/solve_line_alien.py, magic_pdf/pre_proc/statistics.py, magic_pdf/resources/fasttext-langdetect/lid.176.ftz, magic_pdf/resources/model_config/model_configs.yaml, magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml, magic_pdf/resources/model_config/UniMERNet/demo.yaml, magic_pdf/rw/__init__.py, magic_pdf/rw/AbsReaderWriter.py, magic_pdf/rw/DiskReaderWriter.py, magic_pdf/rw/S3ReaderWriter.py, magic_pdf/spark/__init__.py, magic_pdf/spark/spark_api.py, magic_pdf/tools/__init__.py, magic_pdf/tools/pdf_client.py, magic_pdf/tools/common.py, magic_pdf/tools/cli_dev.py, magic_pdf/tools/cli.py, magic_pdf/tools/pdf_server.py files
parent 748e3b56
Pipeline #1783 canceled with stages
__version__ = "0.8.1"
from magic_pdf.libs.commons import fitz
import os
def draw_bbox_on_page(raw_pdf_doc: fitz.Document, paras_dict:dict, save_path: str):
"""
在page上画出bbox,保存到save_path
"""
# 检查文件是否存在
is_new_pdf = False
if os.path.exists(save_path):
# 打开现有的 PDF 文件
doc = fitz.open(save_path)
else:
# 创建一个新的空白 PDF 文件
is_new_pdf = True
doc = fitz.open('')
color_map = {
'image': fitz.pdfcolor["yellow"],
'text': fitz.pdfcolor['blue'],
"table": fitz.pdfcolor['green']
}
for k, v in paras_dict.items():
page_idx = v['page_idx']
width = raw_pdf_doc[page_idx].rect.width
height = raw_pdf_doc[page_idx].rect.height
new_page = doc.new_page(width=width, height=height)
shape = new_page.new_shape()
for order, block in enumerate(v['preproc_blocks']):
rect = fitz.Rect(block['bbox'])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=None, fill=color_map['text'], fill_opacity=0.2)
shape.finish()
shape.commit()
for img in v['images']:
# 原始box画上去
rect = fitz.Rect(img['bbox'])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=None, fill=fitz.pdfcolor['yellow'])
shape.finish()
shape.commit()
for img in v['image_backup']:
# 原始box画上去
rect = fitz.Rect(img['bbox'])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=fitz.pdfcolor['yellow'], fill=None)
shape.finish()
shape.commit()
for tb in v['droped_text_block']:
# 原始box画上去
rect = fitz.Rect(tb['bbox'])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=None, fill=fitz.pdfcolor['black'], fill_opacity=0.4)
shape.finish()
shape.commit()
# TODO table
for tb in v['tables']:
rect = fitz.Rect(tb['bbox'])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=None, fill=fitz.pdfcolor['green'], fill_opacity=0.2)
shape.finish()
shape.commit()
parent_dir = os.path.dirname(save_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
if is_new_pdf:
doc.save(save_path)
else:
doc.saveIncr()
doc.close()
def debug_show_bbox(raw_pdf_doc: fitz.Document, page_idx: int, bboxes: list, droped_bboxes:list, expect_drop_bboxes:list, save_path: str, expected_page_id:int):
"""
以覆盖的方式写个临时的pdf,用于debug
"""
if page_idx!=expected_page_id:
return
if os.path.exists(save_path):
# 删除已经存在的文件
os.remove(save_path)
# 创建一个新的空白 PDF 文件
doc = fitz.open('')
width = raw_pdf_doc[page_idx].rect.width
height = raw_pdf_doc[page_idx].rect.height
new_page = doc.new_page(width=width, height=height)
shape = new_page.new_shape()
for bbox in bboxes:
# 原始box画上去
rect = fitz.Rect(*bbox[0:4])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=fitz.pdfcolor['red'], fill=fitz.pdfcolor['blue'], fill_opacity=0.2)
shape.finish()
shape.commit()
for bbox in droped_bboxes:
# 原始box画上去
rect = fitz.Rect(*bbox[0:4])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=None, fill=fitz.pdfcolor['yellow'], fill_opacity=0.2)
shape.finish()
shape.commit()
for bbox in expect_drop_bboxes:
# 原始box画上去
rect = fitz.Rect(*bbox[0:4])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=fitz.pdfcolor['red'], fill=None)
shape.finish()
shape.commit()
# shape.insert_textbox(fitz.Rect(200, 0, 600, 20), f"total bboxes: {len(bboxes)}", fontname="helv", fontsize=12,
# color=(0, 0, 0))
# shape.finish(color=fitz.pdfcolor['black'])
# shape.commit()
parent_dir = os.path.dirname(save_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
doc.save(save_path)
doc.close()
def debug_show_page(page, bboxes1: list,bboxes2: list,bboxes3: list,):
save_path = "./tmp/debug.pdf"
if os.path.exists(save_path):
# 删除已经存在的文件
os.remove(save_path)
# 创建一个新的空白 PDF 文件
doc = fitz.open('')
width = page.rect.width
height = page.rect.height
new_page = doc.new_page(width=width, height=height)
shape = new_page.new_shape()
for bbox in bboxes1:
# 原始box画上去
rect = fitz.Rect(*bbox[0:4])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=fitz.pdfcolor['red'], fill=fitz.pdfcolor['blue'], fill_opacity=0.2)
shape.finish()
shape.commit()
for bbox in bboxes2:
# 原始box画上去
rect = fitz.Rect(*bbox[0:4])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=None, fill=fitz.pdfcolor['yellow'], fill_opacity=0.2)
shape.finish()
shape.commit()
for bbox in bboxes3:
# 原始box画上去
rect = fitz.Rect(*bbox[0:4])
shape = new_page.new_shape()
shape.draw_rect(rect)
shape.finish(color=fitz.pdfcolor['red'], fill=None)
shape.finish()
shape.commit()
parent_dir = os.path.dirname(save_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
doc.save(save_path)
doc.close()
def draw_layout_bbox_on_page(raw_pdf_doc: fitz.Document, paras_dict:dict, header, footer, pdf_path: str):
"""
在page上画出bbox,保存到save_path
"""
# 检查文件是否存在
is_new_pdf = False
if os.path.exists(pdf_path):
# 打开现有的 PDF 文件
doc = fitz.open(pdf_path)
else:
# 创建一个新的空白 PDF 文件
is_new_pdf = True
doc = fitz.open('')
for k, v in paras_dict.items():
page_idx = v['page_idx']
layouts = v['layout_bboxes']
page = doc[page_idx]
shape = page.new_shape()
for order, layout in enumerate(layouts):
border_offset = 1
rect_box = layout['layout_bbox']
layout_label = layout['layout_label']
fill_color = fitz.pdfcolor['pink'] if layout_label=='U' else None
rect_box = [rect_box[0]+1, rect_box[1]-border_offset, rect_box[2]-1, rect_box[3]+border_offset]
rect = fitz.Rect(*rect_box)
shape.draw_rect(rect)
shape.finish(color=fitz.pdfcolor['red'], fill=fill_color, fill_opacity=0.4)
"""
draw order text on layout box
"""
font_size = 10
shape.insert_text((rect_box[0] + 1, rect_box[1] + font_size), f"{order}", fontsize=font_size, color=(0, 0, 0))
"""画上footer header"""
if header:
shape.draw_rect(fitz.Rect(header))
shape.finish(color=None, fill=fitz.pdfcolor['black'], fill_opacity=0.2)
if footer:
shape.draw_rect(fitz.Rect(footer))
shape.finish(color=None, fill=fitz.pdfcolor['black'], fill_opacity=0.2)
shape.commit()
if is_new_pdf:
doc.save(pdf_path)
else:
doc.saveIncr()
doc.close()
@DeprecationWarning
def draw_layout_on_page(raw_pdf_doc: fitz.Document, page_idx: int, page_layout: list, pdf_path: str):
"""
把layout的box用红色边框花在pdf_path的page_idx上
"""
def draw(shape, layout, fill_color=fitz.pdfcolor['pink']):
border_offset = 1
rect_box = layout['layout_bbox']
layout_label = layout['layout_label']
sub_layout = layout['sub_layout']
if len(sub_layout)==0:
fill_color = fill_color if layout_label=='U' else None
rect_box = [rect_box[0]+1, rect_box[1]-border_offset, rect_box[2]-1, rect_box[3]+border_offset]
rect = fitz.Rect(*rect_box)
shape.draw_rect(rect)
shape.finish(color=fitz.pdfcolor['red'], fill=fill_color, fill_opacity=0.2)
# if layout_label=='U':
# bad_boxes = layout.get("bad_boxes", [])
# for bad_box in bad_boxes:
# rect = fitz.Rect(*bad_box)
# shape.draw_rect(rect)
# shape.finish(color=fitz.pdfcolor['red'], fill=fitz.pdfcolor['red'], fill_opacity=0.2)
# else:
# rect = fitz.Rect(*rect_box)
# shape.draw_rect(rect)
# shape.finish(color=fitz.pdfcolor['blue'])
for sub_layout in sub_layout:
draw(shape, sub_layout)
shape.commit()
# 检查文件是否存在
is_new_pdf = False
if os.path.exists(pdf_path):
# 打开现有的 PDF 文件
doc = fitz.open(pdf_path)
else:
# 创建一个新的空白 PDF 文件
is_new_pdf = True
doc = fitz.open('')
page = doc[page_idx]
shape = page.new_shape()
for order, layout in enumerate(page_layout):
draw(shape, layout, fitz.pdfcolor['yellow'])
# shape.insert_textbox(fitz.Rect(200, 0, 600, 20), f"total bboxes: {len(layout)}", fontname="helv", fontsize=12,
# color=(0, 0, 0))
# shape.finish(color=fitz.pdfcolor['black'])
# shape.commit()
parent_dir = os.path.dirname(pdf_path)
if not os.path.exists(parent_dir):
os.makedirs(parent_dir)
if is_new_pdf:
doc.save(pdf_path)
else:
doc.saveIncr()
doc.close()
\ No newline at end of file
__use_inside_model__ = True
__model_mode__ = "full"
import time
import fitz
import numpy as np
from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config
def dict_compare(d1, d2):
return d1.items() == d2.items()
def remove_duplicates_dicts(lst):
unique_dicts = []
for dict_item in lst:
if not any(
dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
):
unique_dicts.append(dict_item)
return unique_dicts
def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
try:
from PIL import Image
except ImportError:
logger.error("Pillow not installed, please install by pip.")
exit(1)
images = []
with fitz.open("pdf", pdf_bytes) as doc:
for index in range(0, doc.page_count):
page = doc[index]
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = page.get_pixmap(matrix=mat, alpha=False)
# If the width or height exceeds 9000 after scaling, do not scale further.
if pm.width > 9000 or pm.height > 9000:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height}
images.append(img_dict)
return images
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(self, ocr: bool, show_log: bool):
key = (ocr, show_log)
if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False):
model = None
if model_config.__model_mode__ == "lite":
logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is "
"not guaranteed to be reliable.")
model = MODEL.Paddle
elif model_config.__model_mode__ == "full":
# 使用 pdf_extract_kit
model = MODEL.PEK
if model_config.__use_inside_model__:
model_init_start = time.time()
if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
table_config = get_table_recog_config()
model_input = {"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_config": table_config}
custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
exit(1)
model_init_cost = time.time() - model_init_start
logger.info(f"model init cost: {model_init_cost}")
else:
logger.error("use_inside_model is False, not allow to use inside model")
exit(1)
return custom_model
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None):
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log)
images = load_images_from_pdf(pdf_bytes)
# end_page_id = end_page_id if end_page_id else len(images) - 1
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(images) - 1
if end_page_id > len(images) - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = len(images) - 1
model_json = []
doc_analyze_start = time.time()
for index, img_dict in enumerate(images):
img = img_dict["img"]
page_width = img_dict["width"]
page_height = img_dict["height"]
if start_page_id <= index <= end_page_id:
result = custom_model(img)
else:
result = []
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
doc_analyze_cost = time.time() - doc_analyze_start
logger.info(f"doc analyze cost: {doc_analyze_cost}")
# logger.info(f'model_json:\n{model_json}')
return model_json
import json
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou,
calculate_overlap_area_in_bbox1_area_ratio,
get_overlap_area)
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
class MagicModel:
"""每个函数没有得到元素的时候返回空list."""
def __fix_axis(self):
for model_page_info in self.__model_list:
need_remove_list = []
page_no = model_page_info['page_info']['page_no']
horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
model_page_info, self.__docs[page_no]
)
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det.get('bbox') is not None:
# 兼容直接输出bbox的模型数据,如paddle
x0, y0, x1, y1 = layout_det['bbox']
else:
# 兼容直接输出poly的模型数据,如xxx
x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
bbox = [
int(x0 / horizontal_scale_ratio),
int(y0 / vertical_scale_ratio),
int(x1 / horizontal_scale_ratio),
int(y1 / vertical_scale_ratio),
]
layout_det['bbox'] = bbox
# 删除高度或者宽度小于等于0的spans
if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
need_remove_list.append(layout_det)
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det['score'] <= 0.05:
need_remove_list.append(layout_det)
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __fix_by_remove_high_iou_and_low_confidence(self):
for model_page_info in self.__model_list:
need_remove_list = []
layout_dets = model_page_info['layout_dets']
for layout_det1 in layout_dets:
for layout_det2 in layout_dets:
if layout_det1 == layout_det2:
continue
if layout_det1['category_id'] in [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
] and layout_det2['category_id'] in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
if (
calculate_iou(layout_det1['bbox'], layout_det2['bbox'])
> 0.9
):
if layout_det1['score'] < layout_det2['score']:
layout_det_need_remove = layout_det1
else:
layout_det_need_remove = layout_det2
if layout_det_need_remove not in need_remove_list:
need_remove_list.append(layout_det_need_remove)
else:
continue
else:
continue
for need_remove in need_remove_list:
layout_dets.remove(need_remove)
def __init__(self, model_list: list, docs: fitz.Document):
self.__model_list = model_list
self.__docs = docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
self.__fix_axis()
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
self.__fix_footnote()
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
for model_page_info in self.__model_list:
footnotes = []
figures = []
tables = []
for obj in model_page_info['layout_dets']:
if obj['category_id'] == 7:
footnotes.append(obj)
elif obj['category_id'] == 3:
figures.append(obj)
elif obj['category_id'] == 5:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
dis_figure_footnote = {}
dis_table_footnote = {}
for i in range(len(footnotes)):
for j in range(len(figures)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], figures[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_figure_footnote[i] = min(
bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']),
dis_figure_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
for j in range(len(tables)):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
footnotes[i]['bbox'], tables[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
dis_table_footnote[i] = min(
bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']),
dis_table_footnote.get(i, float('inf')),
)
for i in range(len(footnotes)):
if i not in dis_figure_footnote:
continue
if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]:
footnotes[i]['category_id'] = CategoryId.ImageFootnote
def __reduct_overlap(self, bboxes):
N = len(bboxes)
keep = [True] * N
for i in range(N):
for j in range(N):
if i == j:
continue
if _is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance(
self, page_no, subject_category_id, object_category_id
):
"""假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
只能属于一个 subject."""
ret = []
MAX_DIS_OF_POINT = 10**9 + 7
"""
subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def search_overlap_between_boxes(
subject_idx, object_idx
):
idxes = [subject_idx, object_idx]
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
merged_bbox = [
min(x0s),
min(y0s),
max(x1s),
max(y1s),
]
ratio = 0
other_objects = list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id']
not in (object_category_id, subject_category_id),
self.__model_list[page_no]['layout_dets'],
),
)
)
for other_object in other_objects:
ratio = max(
ratio,
get_overlap_area(
merged_bbox, other_object['bbox']
) * 1.0 / box_area(all_bboxes[object_idx]['bbox'])
)
if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO:
break
return ratio
def may_find_other_nearest_bbox(subject_idx, object_idx):
ret = float('inf')
x0 = min(
all_bboxes[subject_idx]['bbox'][0], all_bboxes[object_idx]['bbox'][0]
)
y0 = min(
all_bboxes[subject_idx]['bbox'][1], all_bboxes[object_idx]['bbox'][1]
)
x1 = max(
all_bboxes[subject_idx]['bbox'][2], all_bboxes[object_idx]['bbox'][2]
)
y1 = max(
all_bboxes[subject_idx]['bbox'][3], all_bboxes[object_idx]['bbox'][3]
)
object_area = abs(
all_bboxes[object_idx]['bbox'][2] - all_bboxes[object_idx]['bbox'][0]
) * abs(
all_bboxes[object_idx]['bbox'][3] - all_bboxes[object_idx]['bbox'][1]
)
for i in range(len(all_bboxes)):
if (
i == subject_idx
or all_bboxes[i]['category_id'] != subject_category_id
):
continue
if _is_part_overlap([x0, y0, x1, y1], all_bboxes[i]['bbox']) or _is_in(
all_bboxes[i]['bbox'], [x0, y0, x1, y1]
):
i_area = abs(
all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
) * abs(all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1])
if i_area >= object_area:
ret = min(float('inf'), dis[i][object_idx])
return ret
def expand_bbbox(idxes):
x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes]
y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes]
x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes]
y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes]
return min(x0s), min(y0s), max(x1s), max(y1s)
subjects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == subject_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
objects = self.__reduct_overlap(
list(
map(
lambda x: {'bbox': x['bbox'], 'score': x['score']},
filter(
lambda x: x['category_id'] == object_category_id,
self.__model_list[page_no]['layout_dets'],
),
)
)
)
subject_object_relation_map = {}
subjects.sort(
key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2
) # get the distance !
all_bboxes = []
for v in subjects:
all_bboxes.append(
{
'category_id': subject_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
for v in objects:
all_bboxes.append(
{
'category_id': object_category_id,
'bbox': v['bbox'],
'score': v['score'],
}
)
N = len(all_bboxes)
dis = [[MAX_DIS_OF_POINT] * N for _ in range(N)]
for i in range(N):
for j in range(i):
if (
all_bboxes[i]['category_id'] == subject_category_id
and all_bboxes[j]['category_id'] == subject_category_id
):
continue
subject_idx, object_idx = i, j
if all_bboxes[j]['category_id'] == subject_category_id:
subject_idx, object_idx = j, i
if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO:
dis[i][j] = float('inf')
dis[j][i] = dis[i][j]
continue
dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox'])
dis[j][i] = dis[i][j]
used = set()
for i in range(N):
# 求第 i 个 subject 所关联的 object
if all_bboxes[i]['category_id'] != subject_category_id:
continue
seen = set()
candidates = []
arr = []
for j in range(N):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
if (
all_bboxes[j]['category_id'] != object_category_id
or j in used
or dis[i][j] == MAX_DIS_OF_POINT
):
continue
left, right, _, _ = bbox_relative_pos(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
) # 由 pos_flag_count 相关逻辑保证本段逻辑准确性
if left or right:
one_way_dis = all_bboxes[i]['bbox'][2] - all_bboxes[i]['bbox'][0]
else:
one_way_dis = all_bboxes[i]['bbox'][3] - all_bboxes[i]['bbox'][1]
if dis[i][j] > one_way_dis:
continue
arr.append((dis[i][j], j))
arr.sort(key=lambda x: x[0])
if len(arr) > 0:
"""
bug: 离该subject 最近的 object 可能跨越了其它的 subject。
比如 [this subect] [some sbuject] [the nearest object of subject]
"""
if may_find_other_nearest_bbox(i, arr[0][1]) >= arr[0][0]:
candidates.append(arr[0][1])
seen.add(arr[0][1])
# 已经获取初始种子
for j in set(candidates):
tmp = []
for k in range(i + 1, N):
pos_flag_count = sum(
list(
map(
lambda x: 1 if x else 0,
bbox_relative_pos(
all_bboxes[j]['bbox'], all_bboxes[k]['bbox']
),
)
)
)
if pos_flag_count > 1:
continue
if (
all_bboxes[k]['category_id'] != object_category_id
or k in used
or k in seen
or dis[j][k] == MAX_DIS_OF_POINT
or dis[j][k] > dis[i][j]
):
continue
is_nearest = True
for ni in range(i + 1, N):
if ni in (j, k) or ni in used or ni in seen:
continue
if not float_gt(dis[ni][k], dis[j][k]):
is_nearest = False
break
if is_nearest:
nx0, ny0, nx1, ny1 = expand_bbbox(list(seen) + [k])
n_dis = bbox_distance(
all_bboxes[i]['bbox'], [nx0, ny0, nx1, ny1]
)
if float_gt(dis[i][j], n_dis):
continue
tmp.append(k)
seen.add(k)
candidates = tmp
if len(candidates) == 0:
break
# 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
# 先扩一下 bbox,
ox0, oy0, ox1, oy1 = expand_bbbox(list(seen) + [i])
ix0, iy0, ix1, iy1 = all_bboxes[i]['bbox']
# 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
caption_poses = [
[ox0, oy0, ix0, oy1],
[ox0, oy0, ox1, iy0],
[ox0, iy1, ox1, oy1],
[ix1, oy0, ox1, oy1],
]
caption_areas = []
for bbox in caption_poses:
embed_arr = []
for idx in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[idx]['bbox'], bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
embed_arr.append(idx)
if len(embed_arr) > 0:
embed_x0 = min([all_bboxes[idx]['bbox'][0] for idx in embed_arr])
embed_y0 = min([all_bboxes[idx]['bbox'][1] for idx in embed_arr])
embed_x1 = max([all_bboxes[idx]['bbox'][2] for idx in embed_arr])
embed_y1 = max([all_bboxes[idx]['bbox'][3] for idx in embed_arr])
caption_areas.append(
int(abs(embed_x1 - embed_x0) * abs(embed_y1 - embed_y0))
)
else:
caption_areas.append(0)
subject_object_relation_map[i] = []
if max(caption_areas) > 0:
max_area_idx = caption_areas.index(max(caption_areas))
caption_bbox = caption_poses[max_area_idx]
for j in seen:
if (
calculate_overlap_area_in_bbox1_area_ratio(
all_bboxes[j]['bbox'], caption_bbox
)
> CAPATION_OVERLAP_AREA_RATIO
):
used.add(j)
subject_object_relation_map[i].append(j)
for i in sorted(subject_object_relation_map.keys()):
result = {
'subject_body': all_bboxes[i]['bbox'],
'all': all_bboxes[i]['bbox'],
'score': all_bboxes[i]['score'],
}
if len(subject_object_relation_map[i]) > 0:
x0 = min(
[all_bboxes[j]['bbox'][0] for j in subject_object_relation_map[i]]
)
y0 = min(
[all_bboxes[j]['bbox'][1] for j in subject_object_relation_map[i]]
)
x1 = max(
[all_bboxes[j]['bbox'][2] for j in subject_object_relation_map[i]]
)
y1 = max(
[all_bboxes[j]['bbox'][3] for j in subject_object_relation_map[i]]
)
result['object_body'] = [x0, y0, x1, y1]
result['all'] = [
min(x0, all_bboxes[i]['bbox'][0]),
min(y0, all_bboxes[i]['bbox'][1]),
max(x1, all_bboxes[i]['bbox'][2]),
max(y1, all_bboxes[i]['bbox'][3]),
]
ret.append(result)
total_subject_object_dis = 0
# 计算已经配对的 distance 距离
for i in subject_object_relation_map.keys():
for j in subject_object_relation_map[i]:
total_subject_object_dis += bbox_distance(
all_bboxes[i]['bbox'], all_bboxes[j]['bbox']
)
# 计算未匹配的 subject 和 object 的距离(非精确版)
with_caption_subject = set(
[
key
for key in subject_object_relation_map.keys()
if len(subject_object_relation_map[i]) > 0
]
)
for i in range(N):
if all_bboxes[i]['category_id'] != object_category_id or i in used:
continue
candidates = []
for j in range(N):
if (
all_bboxes[j]['category_id'] != subject_category_id
or j in with_caption_subject
):
continue
candidates.append((dis[i][j], j))
if len(candidates) > 0:
candidates.sort(key=lambda x: x[0])
total_subject_object_dis += candidates[0][1]
with_caption_subject.add(j)
return ret, total_subject_object_dis
def get_imgs(self, page_no: int):
with_captions, _ = self.__tie_up_category_by_distance(page_no, 3, 4)
with_footnotes, _ = self.__tie_up_category_by_distance(
page_no, 3, CategoryId.ImageFootnote
)
ret = []
N, M = len(with_captions), len(with_footnotes)
assert N == M
for i in range(N):
record = {
'score': with_captions[i]['score'],
'img_caption_bbox': with_captions[i].get('object_body', None),
'img_body_bbox': with_captions[i]['subject_body'],
'img_footnote_bbox': with_footnotes[i].get('object_body', None),
}
x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
record['bbox'] = [x0, y0, x1, y1]
ret.append(record)
return ret
def get_tables(
self, page_no: int
) -> list: # 3个坐标, caption, table主体,table-note
with_captions, _ = self.__tie_up_category_by_distance(page_no, 5, 6)
with_footnotes, _ = self.__tie_up_category_by_distance(page_no, 5, 7)
ret = []
N, M = len(with_captions), len(with_footnotes)
assert N == M
for i in range(N):
record = {
'score': with_captions[i]['score'],
'table_caption_bbox': with_captions[i].get('object_body', None),
'table_body_bbox': with_captions[i]['subject_body'],
'table_footnote_bbox': with_footnotes[i].get('object_body', None),
}
x0 = min(with_captions[i]['all'][0], with_footnotes[i]['all'][0])
y0 = min(with_captions[i]['all'][1], with_footnotes[i]['all'][1])
x1 = max(with_captions[i]['all'][2], with_footnotes[i]['all'][2])
y1 = max(with_captions[i]['all'][3], with_footnotes[i]['all'][3])
record['bbox'] = [x0, y0, x1, y1]
ret.append(record)
return ret
def get_equations(self, page_no: int) -> list: # 有坐标,也有字
inline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
)
interline_equations = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
)
interline_equations_blocks = self.__get_blocks_by_type(
ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
)
return inline_equations, interline_equations, interline_equations_blocks
def get_discarded(self, page_no: int) -> list: # 自研模型,只有坐标
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no)
return blocks
def get_text_blocks(self, page_no: int) -> list: # 自研模型搞的,只有坐标,没有字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no)
return blocks
def get_title_blocks(self, page_no: int) -> list: # 自研模型,只有坐标,没字
blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no)
return blocks
def get_ocr_text(self, page_no: int) -> list: # paddle 搞的,有字也有坐标
text_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info['layout_dets']
for layout_det in layout_dets:
if layout_det['category_id'] == '15':
span = {
'bbox': layout_det['bbox'],
'content': layout_det['text'],
}
text_spans.append(span)
return text_spans
def get_all_spans(self, page_no: int) -> list:
def remove_duplicate_spans(spans):
new_spans = []
for span in spans:
if not any(span == existing_span for existing_span in new_spans):
new_spans.append(span)
return new_spans
all_spans = []
model_page_info = self.__model_list[page_no]
layout_dets = model_page_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15]
"""当成span拼接的"""
# 3: 'image', # 图片
# 5: 'table', # 表格
# 13: 'inline_equation', # 行内公式
# 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本
for layout_det in layout_dets:
category_id = layout_det['category_id']
if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3:
span['type'] = ContentType.Image
elif category_id == 5:
# 获取table模型结果
latex = layout_det.get('latex', None)
html = layout_det.get('html', None)
if latex:
span['latex'] = latex
elif html:
span['html'] = html
span['type'] = ContentType.Table
elif category_id == 13:
span['content'] = layout_det['latex']
span['type'] = ContentType.InlineEquation
elif category_id == 14:
span['content'] = layout_det['latex']
span['type'] = ContentType.InterlineEquation
elif category_id == 15:
span['content'] = layout_det['text']
span['type'] = ContentType.Text
all_spans.append(span)
return remove_duplicate_spans(all_spans)
def get_page_size(self, page_no: int): # 获取页面宽高
# 获取当前页的page对象
page = self.__docs[page_no]
# 获取当前页的宽高
page_w = page.rect.width
page_h = page.rect.height
return page_w, page_h
def __get_blocks_by_type(
self, type: int, page_no: int, extra_col: list[str] = []
) -> list:
blocks = []
for page_dict in self.__model_list:
layout_dets = page_dict.get('layout_dets', [])
page_info = page_dict.get('page_info', {})
page_number = page_info.get('page_no', -1)
if page_no != page_number:
continue
for item in layout_dets:
category_id = item.get('category_id', -1)
bbox = item.get('bbox', None)
if category_id == type:
block = {
'bbox': bbox,
'score': item.get('score'),
}
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
return blocks
def get_model_list(self, page_no):
return self.__model_list[page_no]
if __name__ == '__main__':
drw = DiskReaderWriter(r'D:/project/20231108code-clean')
if 0:
pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r'linshixuqiu\19983-00_new.json'
pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
model_list = json.loads(model_json_txt)
write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = 'imgs'
img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
if 1:
model_list = json.loads(
drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
)
pdf_bytes = drw.read(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
)
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
for i in range(7):
print(magic_model.get_imgs(i))
class MODEL:
Paddle = "pp_structure_v2"
PEK = "pdf_extract_kit"
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
Table = "table"
from loguru import logger
import os
import time
from magic_pdf.libs.Constants import *
from magic_pdf.model.model_list import AtomicModel
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
try:
import cv2
import yaml
import argparse
import numpy as np
import torch
# import torchtext
#
# if torchtext.__version__ >= "0.18.0":
# torchtext.disable_torchtext_deprecation_warning()
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
# from unimernet.common.config import Config
# import unimernet.tasks as tasks
# from unimernet.processors import load_processor
except ImportError as e:
logger.exception(e)
logger.error(
'Required dependency not installed, please install by \n'
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
exit(1)
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
if table_model_type == STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
else:
config = {
"model_dir": model_path,
"device": _device_
}
table_model = ppTableModel(config)
return table_model
def mfd_model_init(weight):
mfd_model = YOLO(weight)
return mfd_model
# def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
# args = argparse.Namespace(cfg_path=cfg_path, options=None)
# cfg = Config(args)
# cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
# cfg.config.model.model_config.model_name = weight_dir
# cfg.config.model.tokenizer_config.path = weight_dir
# task = tasks.setup_task(cfg)
# model = task.build_model(cfg)
# model = model.to(_device_)
# vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
# mfr_transform = transforms.Compose([vis_processor, ])
# return [model, mfr_transform]
def layout_model_init(weight, config_file, device):
model = Layoutlmv3_Predictor(weight, config_file, device)
return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
return model
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# if not pil image, then convert to pil image
if isinstance(self.image_paths[idx], str):
raw_image = Image.open(self.image_paths[idx])
else:
raw_image = self.image_paths[idx]
if self.transform:
image = self.transform(raw_image)
return image
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
if atom_model_name not in self._models:
self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[atom_model_name]
def atom_model_init(model_name: str, **kwargs):
if model_name == AtomicModel.Layout:
atom_model = layout_model_init(
kwargs.get("layout_weights"),
kwargs.get("layout_config_file"),
kwargs.get("device")
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get("mfd_weights")
)
# elif model_name == AtomicModel.MFR:
# atom_model = mfr_model_init(
# kwargs.get("mfr_weight_dir"),
# kwargs.get("mfr_cfg_path"),
# kwargs.get("device")
# )
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh")
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get("table_model_type"),
kwargs.get("table_model_path"),
kwargs.get("table_max_time"),
kwargs.get("device")
)
else:
logger.error("model name not allow")
exit(1)
return atom_model
class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
"""
======== model init ========
"""
# 获取当前文件(即 pdf_extract_kit.py)的绝对路径
current_file_path = os.path.abspath(__file__)
# 获取当前文件所在的目录(model)
current_dir = os.path.dirname(current_file_path)
# 上一级目录(magic_pdf)
root_dir = os.path.dirname(current_dir)
# model_config目录
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
# 构建 model_configs.yaml 文件的完整路径
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, "r", encoding='utf-8') as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader)
# 初始化解析配置
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
# table config
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
self.apply_table = self.table_config.get("is_table_recog_enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_type = self.table_config.get("model", TABLE_MASTER)
self.apply_ocr = ocr
logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
)
)
assert self.apply_layout, "DocAnalysis must contain layout model."
# 初始化解析方案
self.device = kwargs.get("device", self.configs["config"]["device"])
logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
logger.info("using models_dir: {}".format(models_dir))
atom_model_manager = AtomModelSingleton()
# 初始化公式识别
if self.apply_formula:
# 初始化公式检测模型
# self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
)
# 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
# mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
# self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
# self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
# self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
# atom_model_name=AtomicModel.MFR,
# mfr_weight_dir=mfr_weight_dir,
# mfr_cfg_path=mfr_cfg_path,
# device=self.device
# )
# 初始化layout模型
# self.layout_model = Layoutlmv3_Predictor(
# str(os.path.join(models_dir, self.configs['weights']['layout'])),
# str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
# device=self.device
# )
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device
)
# 初始化ocr
if self.apply_ocr:
# self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3
)
# init table model
if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_type]
# self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
# max_time=self.table_max_time, _device_=self.device)
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_type=self.table_model_type,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
logger.info('DocAnalysis init done!')
def __call__(self, image):
latex_filling_list = []
mf_image_list = []
# layout检测
layout_start = time.time()
layout_res = self.layout_model(image, ignore_catids=[])
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection cost: {layout_cost}")
if self.apply_formula:
# 公式检测
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': 13 + int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 2),
'latex': '',
}
layout_res.append(new_item)
latex_filling_list.append(new_item)
bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
mf_image_list.append(bbox_img)
# 公式识别
mfr_start = time.time()
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
output = self.mfr_model.generate({'image': mf_img})
mfr_res.extend(output['pred_str'])
for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# Select regions for OCR / formula regions / table regions
ocr_res_list = []
table_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
single_page_mfdetrec_res.append({
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
elif int(res['category_id']) in [5]:
table_res_list.append(res)
# Unified crop img logic
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
# Create a white background with an additional width and height of 50
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
# Crop image
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
cropped_img = input_pil_img.crop(crop_box)
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
return return_image, return_list
pil_img = Image.fromarray(image)
#logger.info(f'是否ocr识别:{self.apply_ocr}')
# ocr识别
if self.apply_ocr:
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# Adjust the coordinates of the formula area
adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y
# Filter formula blocks outside the graph
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
continue
else:
adjusted_mfdetrec_res.append({
"bbox": [x0, y0, x1, y1],
})
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
# logger.info(f'------------------------------------orc_res:\n{ocr_res}\n------------------------------------')
# Integration results
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
# Convert the coordinates back to the original coordinate system
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
layout_res.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr cost: {ocr_cost}")
#logger.info(f'是否表格识别:{self.apply_table}')
# 表格识别 table recognition
if self.apply_table:
table_start = time.time()
for res in table_res_list:
#logger.info(f'------------------------------table_res\n{res}\n----------------------------------')
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
latex_code = None
html_code = None
if self.table_model_type == STRUCT_EQTABLE:
with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0]
else:
html_code = self.table_model.img2html(new_image)
run_time = time.time() - single_table_start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
# 判断是否返回正常
if latex_code:
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
'end{table}')
if expected_ending:
res["latex"] = latex_code
else:
logger.warning(f"------------table recognition processing fails----------")
elif html_code:
res["html"] = html_code
else:
logger.warning(f"------------table recognition processing fails----------")
table_cost = round(time.time() - table_start, 2)
logger.info(f"table cost: {table_cost}")
return layout_res
# --------------------------------------------------------------------------------
# VIT: Multi-Path Vision Transformer for Dense Prediction
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
# All Rights Reserved.
# Written by Youngwan Lee
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# CoaT: https://github.com/mlpc-ucsd/CoaT
# --------------------------------------------------------------------------------
import torch
from detectron2.layers import (
ShapeSpec,
)
from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
from .deit import deit_base_patch16, mae_base_patch16
from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
from transformers import AutoConfig
__all__ = [
"build_vit_fpn_backbone",
]
class VIT_Backbone(Backbone):
"""
Implement VIT backbone.
"""
def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=None, image_only=False, cfg=None):
super().__init__()
self._out_features = out_features
if 'base' in name:
self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
else:
self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
if name == 'beit_base_patch16':
model_func = beit_base_patch16
elif name == 'dit_base_patch16':
model_func = dit_base_patch16
elif name == "deit_base_patch16":
model_func = deit_base_patch16
elif name == "mae_base_patch16":
model_func = mae_base_patch16
elif name == "dit_large_patch16":
model_func = dit_large_patch16
elif name == "beit_large_patch16":
model_func = beit_large_patch16
if 'beit' in name or 'dit' in name:
if pos_type == "abs":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_abs_pos_emb=True,
**model_kwargs)
elif pos_type == "shared_rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_shared_rel_pos_bias=True,
**model_kwargs)
elif pos_type == "rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_rel_pos_bias=True,
**model_kwargs)
else:
raise ValueError()
elif "layoutlmv3" in name:
config = AutoConfig.from_pretrained(config_path)
# disable relative bias as DiT
config.has_spatial_attention_bias = False
config.has_relative_attention_bias = False
self.backbone = LayoutLMv3Model(config, detection=True,
out_features=out_features, image_only=image_only)
else:
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
**model_kwargs)
self.name = name
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
if "layoutlmv3" in self.name:
return self.backbone.forward(
input_ids=x["input_ids"] if "input_ids" in x else None,
bbox=x["bbox"] if "bbox" in x else None,
images=x["images"] if "images" in x else None,
attention_mask=x["attention_mask"] if "attention_mask" in x else None,
# output_hidden_states=True,
)
assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
return self.backbone.forward_features(x)
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def build_VIT_backbone(cfg):
"""
Create a VIT instance from config.
Args:
cfg: a detectron2 CfgNode
Returns:
A VIT backbone instance.
"""
# fmt: off
name = cfg.MODEL.VIT.NAME
out_features = cfg.MODEL.VIT.OUT_FEATURES
drop_path = cfg.MODEL.VIT.DROP_PATH
img_size = cfg.MODEL.VIT.IMG_SIZE
pos_type = cfg.MODEL.VIT.POS_TYPE
model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
if 'layoutlmv3' in name:
if cfg.MODEL.CONFIG_PATH != '':
config_path = cfg.MODEL.CONFIG_PATH
else:
config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '') # layoutlmv3 pre-trained models
config_path = config_path.replace('model_final.pth', '') # detection fine-tuned models
else:
config_path = None
return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
Create a VIT w/ FPN backbone.
Args:
cfg: a detectron2 CfgNode
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
bottom_up = build_VIT_backbone(cfg)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelMaxPool(),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
""" Vision Transformer (ViT) in PyTorch
A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
"""
import warnings
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.0)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None, training_window_size=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
if training_window_size == self.window_size:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
else:
training_window_size = tuple(training_window_size.tolist())
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
# new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
new_relative_position_bias_table = F.interpolate(
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
2 * self.window_size[0] - 1,
2 * self.window_size[1] - 1),
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
align_corners=False)
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
new_num_relative_distance - 3).permute(
1, 0)
new_relative_position_bias_table = torch.cat(
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(training_window_size[0])
coords_w = torch.arange(training_window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += training_window_size[1] - 1
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
relative_position_index = \
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = new_num_relative_distance - 3
relative_position_index[0:, 0] = new_num_relative_distance - 2
relative_position_index[0, 0] = new_num_relative_distance - 1
relative_position_bias = \
new_relative_position_bias_table[relative_position_index.view(-1)].view(
training_window_size[0] * training_window_size[1] + 1,
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values is not None:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None, training_window_size=None):
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
training_window_size=training_window_size))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w = self.patch_shape[0]
self.num_patches_h = self.patch_shape[1]
# the so-called patch_shape is the patch shape during pre-training
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, position_embedding=None, **kwargs):
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
Hp, Wp = x.shape[2], x.shape[3]
if position_embedding is not None:
# interpolate the position embedding to the corresponding size
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
1, 2)
position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
x = x + position_embedding
x = x.flatten(2).transpose(1, 2)
return x, (Hp, Wp)
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_heads = num_heads
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self, training_window_size):
if training_window_size == self.window_size:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
else:
training_window_size = tuple(training_window_size.tolist())
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
# new_num_relative_dis 为 所有可能的相对位置选项,包含cls-cls,tok-cls,与cls-tok
new_relative_position_bias_table = F.interpolate(
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
2 * self.window_size[0] - 1,
2 * self.window_size[1] - 1),
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
align_corners=False)
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
new_num_relative_distance - 3).permute(
1, 0)
new_relative_position_bias_table = torch.cat(
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(training_window_size[0])
coords_w = torch.arange(training_window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += training_window_size[1] - 1
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
relative_position_index = \
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = new_num_relative_distance - 3
relative_position_index[0:, 0] = new_num_relative_distance - 2
relative_position_index[0, 0] = new_num_relative_distance - 1
relative_position_bias = \
new_relative_position_bias_table[relative_position_index.view(-1)].view(
training_window_size[0] * training_window_size[1] + 1,
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias
class BEiT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=[224, 224],
patch_size=16,
in_chans=3,
num_classes=80,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=None,
init_values=None,
use_abs_pos_emb=False,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
use_checkpoint=True,
pretrained=None,
out_features=None,
):
super(BEiT, self).__init__()
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.use_checkpoint = use_checkpoint
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
for i in range(depth)])
# trunc_normal_(self.mask_token, std=.02)
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
# nn.SyncBatchNorm(embed_dim),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
logger = get_root_logger()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self,
filename=self.init_cfg['checkpoint'],
strict=False,
logger=logger,
beit_spec_expand_rel_pos = self.use_rel_pos_bias,
)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
B, C, H, W = x.shape
x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
# Hp, Wp are HW for patches
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
if self.pos_embed is not None:
cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x)
features = []
training_window_size = torch.tensor([Hp, Wp])
rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
else:
x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
if i in self.out_indices:
xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def beit_base_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None,
**kwargs)
model.default_cfg = _cfg()
return model
def beit_large_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None,
**kwargs)
model.default_cfg = _cfg()
return model
def dit_base_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=0.1,
**kwargs)
model.default_cfg = _cfg()
return model
def dit_large_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=1e-5,
**kwargs)
model.default_cfg = _cfg()
return model
if __name__ == '__main__':
model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
model = model.to("cuda:0")
input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
output1 = model(input1)
output2 = model(input2)
output3 = model(input3)
print("all done")
"""
Mostly copy-paste from DINO and timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import warnings
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import trunc_normal_, drop_path, to_2tuple
from functools import partial
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w, self.num_patches_h = self.window_size
self.num_patches = self.window_size[0] * self.window_size[1]
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(
1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class ViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
model_name='vit_base_patch16_224',
img_size=384,
patch_size=16,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
num_classes=19,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.1,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
norm_cfg=None,
pos_embed_interp=False,
random_init=False,
align_corners=False,
use_checkpoint=False,
num_extra_tokens=1,
out_features=None,
**kwargs,
):
super(ViT, self).__init__()
self.model_name = model_name
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.depth = depth
self.num_heads = num_heads
self.num_classes = num_classes
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.hybrid_backbone = hybrid_backbone
self.norm_layer = norm_layer
self.norm_cfg = norm_cfg
self.pos_embed_interp = pos_embed_interp
self.random_init = random_init
self.align_corners = align_corners
self.use_checkpoint = use_checkpoint
self.num_extra_tokens = num_extra_tokens
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
# self.num_stages = self.depth
# self.out_indices = tuple(range(self.num_stages))
if self.hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
self.num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
if self.num_extra_tokens == 2:
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + self.num_extra_tokens, self.embed_dim))
self.pos_drop = nn.Dropout(p=self.drop_rate)
# self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
self.depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
for i in range(self.depth)])
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
# self.repr = nn.Linear(embed_dim, representation_size)
# self.repr_act = nn.Tanh()
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
if self.num_extra_tokens==2:
trunc_normal_(self.dist_token, std=0.2)
self.apply(self._init_weights)
# self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
logger = get_root_logger()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _conv_filter(self, state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
def to_2D(self, x):
n, hw, c = x.shape
h = w = int(math.sqrt(hw))
x = x.transpose(1, 2).reshape(n, c, h, w)
return x
def to_1D(self, x):
n, c, h, w = x.shape
x = x.reshape(n, c, -1).transpose(1, 2)
return x
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - self.num_extra_tokens
N = self.pos_embed.shape[1] - self.num_extra_tokens
if npatch == N and w == h:
return self.pos_embed
class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size[0]
h0 = h // self.patch_embed.patch_size[1]
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
def prepare_tokens(self, x, mask=None):
B, nc, w, h = x.shape
# patch linear embedding
x = self.patch_embed(x)
# mask image modeling
if mask is not None:
x = self.mask_model(x, mask)
x = x.flatten(2).transpose(1, 2)
# add the [CLS] token to the embed patch tokens
all_tokens = [self.cls_token.expand(B, -1, -1)]
if self.num_extra_tokens == 2:
dist_tokens = self.dist_token.expand(B, -1, -1)
all_tokens.append(dist_tokens)
all_tokens.append(x)
x = torch.cat(all_tokens, dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward_features(self, x):
# print(f"==========shape of x is {x.shape}==========")
B, _, H, W = x.shape
Hp, Wp = H // self.patch_size, W // self.patch_size
x = self.prepare_tokens(x)
features = []
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if i in self.out_indices:
xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def deit_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=2,
**kwargs)
model.default_cfg = _cfg()
return model
def mae_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=1,
**kwargs)
model.default_cfg = _cfg()
return model
\ No newline at end of file
from .models import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
# flake8: noqa
from .data_collator import DataCollatorForKeyValueExtraction
'''
Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
'''
import json
import os
from pathlib import Path
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{park2019cord,
title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk}
booktitle={Document Intelligence Workshop at Neural Information Processing Systems}
year={2019}
}
"""
_DESCRIPTION = """\
https://github.com/clovaai/cord/
"""
def quad_to_box(quad):
# test 87 is wrongly annotated
box = (
max(0, quad["x1"]),
max(0, quad["y1"]),
quad["x3"],
quad["y3"]
)
if box[3] < box[1]:
bbox = list(box)
tmp = bbox[3]
bbox[3] = bbox[1]
bbox[1] = tmp
box = tuple(bbox)
if box[2] < box[0]:
bbox = list(box)
tmp = bbox[2]
bbox[2] = bbox[0]
bbox[0] = tmp
box = tuple(bbox)
return box
def _get_drive_url(url):
base_url = 'https://drive.google.com/uc?id='
split_url = url.split('/')
return base_url + split_url[5]
_URLS = [
_get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
_get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
# If you failed to download the dataset through the automatic downloader,
# you can download it manually and modify the code to get the local dataset.
# Or you can use the following links. Please follow the original LICENSE of CORD for usage.
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
]
class CordConfig(datasets.BuilderConfig):
"""BuilderConfig for CORD"""
def __init__(self, **kwargs):
"""BuilderConfig for CORD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(CordConfig, self).__init__(**kwargs)
class Cord(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"words": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
citation=_CITATION,
homepage="https://github.com/clovaai/cord/",
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
"""Uses local files located with data_dir"""
downloaded_file = dl_manager.download_and_extract(_URLS)
# move files from the second URL together with files from the first one.
dest = Path(downloaded_file[0])/"CORD"
for split in ["train", "dev", "test"]:
for file_type in ["image", "json"]:
if split == "test" and file_type == "json":
continue
files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
for f in files:
os.rename(f, dest/split/file_type/f.name)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "json")
img_dir = os.path.join(filepath, "image")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
words = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["valid_line"]:
cur_line_bboxes = []
line_words, label = item["words"], item["category"]
line_words = [w for w in line_words if w["text"].strip() != ""]
if len(line_words) == 0:
continue
if label == "other":
for w in line_words:
words.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
else:
words.append(line_words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
for w in line_words[1:]:
words.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
# by default: --segment_level_layout 1
# if do not want to use segment_level_layout, comment the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
bboxes.extend(cur_line_bboxes)
# yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.data.data_collator import (
DataCollatorMixin,
_torch_collate_batch,
)
from transformers.file_utils import PaddingStrategy
from typing import NewType
InputDataClass = NewType("InputDataClass", Any)
def pre_calc_rel_mat(segment_ids):
valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),
device=segment_ids.device, dtype=torch.bool)
for i in range(segment_ids.shape[0]):
for j in range(segment_ids.shape[1]):
valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]
return valid_span
@dataclass
class DataCollatorForKeyValueExtraction(DataCollatorMixin):
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
images = None
if "images" in features[0]:
images = torch.stack([torch.tensor(d.pop("images")) for d in features])
IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)
if images is not None:
batch["images"] = images
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
for k, v in batch.items()}
visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)
batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)
if labels is None:
return batch
has_bbox_input = "bbox" in features[0]
has_position_input = "position_ids" in features[0]
padding_idx=self.tokenizer.pad_token_id
sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
if has_bbox_input:
batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
for position_id in batch["position_ids"]]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
if has_bbox_input:
batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id))
+ position_id for position_id in batch["position_ids"]]
if 'segment_ids' in batch:
assert 'position_ids' in batch
for i in range(len(batch['segment_ids'])):
batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [
batch['segment_ids'][i][-1] + 2] * IMAGE_LEN
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
if 'segment_ids' in batch:
valid_span = pre_calc_rel_mat(
segment_ids=batch['segment_ids']
)
batch['valid_span'] = valid_span
del batch['segment_ids']
if images is not None:
visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100
batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)
return batch
# coding=utf-8
'''
Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
'''
import json
import os
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""
_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""
class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""
def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)
class Funsd(datasets.GeneratorBasedBuilder):
"""Conll2003 dataset."""
BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"tokens": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
tokens = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
cur_line_bboxes = []
words, label = item["words"], item["label"]
words = [w for w in words if w["text"].strip() != ""]
if len(words) == 0:
continue
if label == "other":
for w in words:
tokens.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(w["box"], size))
else:
tokens.append(words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
for w in words[1:]:
tokens.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(w["box"], size))
# by default: --segment_level_layout 1
# if do not want to use segment_level_layout, comment the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
# box = normalize_bbox(item["box"], size)
# cur_line_bboxes = [box for _ in range(len(words))]
bboxes.extend(cur_line_bboxes)
yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
\ No newline at end of file
import torchvision.transforms.functional as F
import warnings
import math
import random
import numpy as np
from PIL import Image
import torch
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import ResizeTransform, TransformList
def normalize_bbox(bbox, size):
return [
int(1000 * bbox[0] / size[0]),
int(1000 * bbox[1] / size[1]),
int(1000 * bbox[2] / size[0]),
int(1000 * bbox[3] / size[1]),
]
def load_image(image_path):
image = read_image(image_path, format="BGR")
h = image.shape[0]
w = image.shape[1]
img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable
return image, (w, h)
def crop(image, i, j, h, w, boxes=None):
cropped_image = F.crop(image, i, j, h, w)
if boxes is not None:
# Currently we cannot use this case since when some boxes is out of the cropped image,
# it may be better to drop out these boxes along with their text input (instead of min or clamp)
# which haven't been implemented here
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
boxes = cropped_boxes.reshape(-1, 4)
return cropped_image, boxes
def resize(image, size, interpolation, boxes=None):
# It seems that we do not need to resize boxes here, since the boxes will be resized to 1000x1000 finally,
# which is compatible with a square image size of 224x224
rescaled_image = F.resize(image, size, interpolation)
if boxes is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
# boxes = boxes.copy()
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
return rescaled_image, scaled_boxes
def clamp(num, min_value, max_value):
return max(min(num, max_value), min_value)
def get_bb(bb, page_size):
bbs = [float(j) for j in bb]
xs, ys = [], []
for i, b in enumerate(bbs):
if i % 2 == 0:
xs.append(b)
else:
ys.append(b)
(width, height) = page_size
return_bb = [
clamp(min(xs), 0, width - 1),
clamp(min(ys), 0, height - 1),
clamp(max(xs), 0, width - 1),
clamp(max(ys), 0, height - 1),
]
return_bb = [
int(1000 * return_bb[0] / width),
int(1000 * return_bb[1] / height),
int(1000 * return_bb[2] / width),
int(1000 * return_bb[3] / height),
]
return return_bb
class ToNumpy:
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
class ToTensor:
def __init__(self, dtype=torch.float32):
self.dtype = dtype
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return torch.from_numpy(np_img).to(dtype=self.dtype)
_pil_interpolation_to_str = {
F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',
F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',
F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',
F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',
F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING',
F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',
}
def _pil_interp(method):
if method == 'bicubic':
return F.InterpolationMode.BICUBIC
elif method == 'lanczos':
return F.InterpolationMode.LANCZOS
elif method == 'hamming':
return F.InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return F.InterpolationMode.BILINEAR
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, augmentation=False, box=None):
for t in self.transforms:
img = t(img, augmentation, box)
return img
class RandomResizedCropAndInterpolationWithTwoPic:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
is finally resized to given size.
This is popularly used to train the Inception networks.
Args:
size: expected output size of each edge
scale: range of size of the origin size cropped
ratio: range of aspect ratio of the origin aspect ratio cropped
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear', second_interpolation='lanczos'):
if isinstance(size, tuple):
self.size = size
else:
self.size = (size, size)
if second_size is not None:
if isinstance(second_size, tuple):
self.second_size = second_size
else:
self.second_size = (second_size, second_size)
else:
self.second_size = None
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
self.interpolation = _pil_interp(interpolation)
self.second_interpolation = _pil_interp(second_interpolation)
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img, scale, ratio):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image): Image to be cropped.
scale (tuple): range of size of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
def __call__(self, img, augmentation=False, box=None):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Randomly cropped and resized image.
"""
if augmentation:
i, j, h, w = self.get_params(img, self.scale, self.ratio)
img = F.crop(img, i, j, h, w)
# img, box = crop(img, i, j, h, w, box)
img = F.resize(img, self.size, self.interpolation)
second_img = F.resize(img, self.second_size, self.second_interpolation) \
if self.second_size is not None else None
return img, second_img
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
else:
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0}'.format(interpolate_str)
if self.second_size is not None:
format_string += ', second_size={0}'.format(self.second_size)
format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
format_string += ')'
return format_string
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
import os
import json
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic
XFund_label2ids = {
"O":0,
'B-HEADER':1,
'I-HEADER':2,
'B-QUESTION':3,
'I-QUESTION':4,
'B-ANSWER':5,
'I-ANSWER':6,
}
class xfund_dataset(Dataset):
def box_norm(self, box, width, height):
def clip(min_num, num, max_num):
return min(max(num, min_num), max_num)
x0, y0, x1, y1 = box
x0 = clip(0, int((x0 / width) * 1000), 1000)
y0 = clip(0, int((y0 / height) * 1000), 1000)
x1 = clip(0, int((x1 / width) * 1000), 1000)
y1 = clip(0, int((y1 / height) * 1000), 1000)
assert x1 >= x0
assert y1 >= y0
return [x0, y0, x1, y1]
def get_segment_ids(self, bboxs):
segment_ids = []
for i in range(len(bboxs)):
if i == 0:
segment_ids.append(0)
else:
if bboxs[i - 1] == bboxs[i]:
segment_ids.append(segment_ids[-1])
else:
segment_ids.append(segment_ids[-1] + 1)
return segment_ids
def get_position_ids(self, segment_ids):
position_ids = []
for i in range(len(segment_ids)):
if i == 0:
position_ids.append(2)
else:
if segment_ids[i] == segment_ids[i - 1]:
position_ids.append(position_ids[-1] + 1)
else:
position_ids.append(2)
return position_ids
def load_data(
self,
data_file,
):
# re-org data format
total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
for i in range(len(data_file['documents'])):
width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][
'height']
cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []
for j in range(len(data_file['documents'][i]['document'])):
cur_item = data_file['documents'][i]['document'][j]
cur_doc_lines.append(cur_item['text'])
cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
cur_doc_ner_tags.append(cur_item['label'])
total_data['id'] += [len(total_data['id'])]
total_data['lines'] += [cur_doc_lines]
total_data['bboxes'] += [cur_doc_bboxes]
total_data['ner_tags'] += [cur_doc_ner_tags]
total_data['image_path'] += [data_file['documents'][i]['img']['fname']]
# tokenize text and get bbox/label
total_input_ids, total_bboxs, total_label_ids = [], [], []
for i in range(len(total_data['lines'])):
cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
for j in range(len(total_data['lines'][i])):
cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']
if len(cur_input_ids) == 0: continue
cur_label = total_data['ner_tags'][i][j].upper()
if cur_label == 'OTHER':
cur_labels = ["O"] * len(cur_input_ids)
for k in range(len(cur_labels)):
cur_labels[k] = self.label2ids[cur_labels[k]]
else:
cur_labels = [cur_label] * len(cur_input_ids)
cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
for k in range(1, len(cur_labels)):
cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)
cur_doc_input_ids += cur_input_ids
cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
cur_doc_labels += cur_labels
assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
assert len(cur_doc_input_ids) > 0
total_input_ids.append(cur_doc_input_ids)
total_bboxs.append(cur_doc_bboxs)
total_label_ids.append(cur_doc_labels)
assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)
# split text to several slices because of over-length
input_ids, bboxs, labels = [], [], []
segment_ids, position_ids = [], []
image_path = []
for i in range(len(total_input_ids)):
start = 0
cur_iter = 0
while start < len(total_input_ids[i]):
end = min(start + 510, len(total_input_ids[i]))
input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])
bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])
labels.append([-100] + total_label_ids[i][start: end] + [-100])
cur_segment_ids = self.get_segment_ids(bboxs[-1])
cur_position_ids = self.get_position_ids(cur_segment_ids)
segment_ids.append(cur_segment_ids)
position_ids.append(cur_position_ids)
image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
start = end
cur_iter += 1
assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
assert len(segment_ids) == len(image_path)
res = {
'input_ids': input_ids,
'bbox': bboxs,
'labels': labels,
'segment_ids': segment_ids,
'position_ids': position_ids,
'image_path': image_path,
}
return res
def __init__(
self,
args,
tokenizer,
mode
):
self.args = args
self.mode = mode
self.cur_la = args.language
self.tokenizer = tokenizer
self.label2ids = XFund_label2ids
self.common_transform = Compose([
RandomResizedCropAndInterpolationWithTwoPic(
size=args.input_size, interpolation=args.train_interpolation,
),
])
self.patch_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor((0.5, 0.5, 0.5)),
std=torch.tensor((0.5, 0.5, 0.5)))
])
data_file = json.load(
open(os.path.join(args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val')),
'r'))
self.feature = self.load_data(data_file)
def __len__(self):
return len(self.feature['input_ids'])
def __getitem__(self, index):
input_ids = self.feature["input_ids"][index]
# attention_mask = self.feature["attention_mask"][index]
attention_mask = [1] * len(input_ids)
labels = self.feature["labels"][index]
bbox = self.feature["bbox"][index]
segment_ids = self.feature['segment_ids'][index]
position_ids = self.feature['position_ids'][index]
img = pil_loader(self.feature['image_path'][index])
for_patches, _ = self.common_transform(img, augmentation=False)
patch = self.patch_transform(for_patches)
assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
res = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"bbox": bbox,
"segment_ids": segment_ids,
"position_ids": position_ids,
"images": patch,
}
return res
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
\ No newline at end of file
from .layoutlmv3 import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment