Commit 826086d2 authored by zhougaofeng

Deleted magic_pdf/__pycache__/__init__.cpython-310.pyc, magic_pdf/__pycache__/pdf_parse_by_ocr.cpython-310.pyc, magic_pdf/__pycache__/pdf_parse_by_txt.cpython-310.pyc, magic_pdf/__pycache__/pdf_parse_union_core.cpython-310.pyc, magic_pdf/__pycache__/user_api.cpython-310.pyc, magic_pdf/dict2md/__pycache__/__init__.cpython-310.pyc, magic_pdf/dict2md/__pycache__/ocr_client.cpython-310.pyc, magic_pdf/dict2md/__pycache__/ocr_mkcontent.cpython-310.pyc, magic_pdf/dict2md/__init__.py, magic_pdf/dict2md/mkcontent.py, magic_pdf/dict2md/ocr_client.py, magic_pdf/dict2md/ocr_mkcontent.py, magic_pdf/dict2md/ocr_server.py, magic_pdf/filter/__init__.py, magic_pdf/filter/pdf_classify_by_type.py, magic_pdf/filter/pdf_meta_scan.py, magic_pdf/integrations/rag/__init__.py, magic_pdf/integrations/rag/api.py, magic_pdf/integrations/rag/type.py, magic_pdf/integrations/rag/utils.py, magic_pdf/integrations/__init__.py, magic_pdf/layout/__init__.py, magic_pdf/layout/bbox_sort.py, magic_pdf/layout/layout_det_utils.py, magic_pdf/layout/layout_sort.py, magic_pdf/layout/layout_spiler_recog.py, magic_pdf/layout/mcol_sort.py, magic_pdf/libs/Constants.py, magic_pdf/libs/MakeContentConfig.py, magic_pdf/libs/ModelBlockTypeEnum.py, magic_pdf/libs/__init__.py, magic_pdf/libs/boxbase.py, magic_pdf/libs/calc_span_stats.py, magic_pdf/libs/commons.py, magic_pdf/libs/config_reader.py, magic_pdf/libs/convert_utils.py, magic_pdf/libs/coordinate_transform.py, magic_pdf/libs/detect_language_from_model.py, magic_pdf/libs/draw_bbox.py, magic_pdf/libs/drop_reason.py, magic_pdf/libs/drop_tag.py, magic_pdf/libs/hash_utils.py, magic_pdf/libs/json_compressor.py, magic_pdf/libs/language.py, magic_pdf/libs/local_math.py, magic_pdf/libs/markdown_utils.py, magic_pdf/libs/nlp_utils.py, magic_pdf/libs/ocr_content_type.py, magic_pdf/libs/path_utils.py, magic_pdf/libs/pdf_check.py, magic_pdf/libs/pdf_image_tools.py, magic_pdf/libs/safe_filename.py, magic_pdf/libs/textbase.py, magic_pdf/libs/version.py, magic_pdf/libs/vis_utils.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py, magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py, magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py, magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py, magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py, magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py, magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py, magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py, 
magic_pdf/model/pek_sub_modules/structeqtable/__init__.py, magic_pdf/model/pek_sub_modules/__init__.py, magic_pdf/model/pek_sub_modules/post_process.py, magic_pdf/model/pek_sub_modules/self_modify.py, magic_pdf/model/__init__.py, magic_pdf/model/doc_analyze_by_custom_model.py, magic_pdf/model/magic_model.py, magic_pdf/model/model_list.py, magic_pdf/model/pdf_extract_kit.py, magic_pdf/model/ppTableModel.py, magic_pdf/model/pp_structure_v2.py, magic_pdf/para/__init__.py, magic_pdf/para/block_continuation_processor.py, magic_pdf/para/block_termination_processor.py, magic_pdf/para/commons.py, magic_pdf/para/denoise.py, magic_pdf/para/draw.py, magic_pdf/para/exceptions.py, magic_pdf/para/layout_match_processor.py, magic_pdf/para/para_pipeline.py, magic_pdf/para/para_split.py, magic_pdf/para/para_split_v2.py, magic_pdf/para/raw_processor.py, magic_pdf/para/stats.py, magic_pdf/para/title_processor.py, magic_pdf/parse/__init__.py, magic_pdf/parse/common_parse.py, magic_pdf/parse/excel_parse.py, magic_pdf/parse/pdf_client.py, magic_pdf/pipe/AbsPipe.py, magic_pdf/pipe/OCRPipe.py, magic_pdf/pipe/TXTPipe.py, magic_pdf/pipe/UNIPipe.py, magic_pdf/pipe/__init__.py, magic_pdf/post_proc/__init__.py, magic_pdf/post_proc/detect_para.py, magic_pdf/post_proc/pdf_post_filter.py, magic_pdf/post_proc/remove_footnote.py, magic_pdf/pre_proc/__init__.py, magic_pdf/pre_proc/citationmarker_remove.py, magic_pdf/pre_proc/construct_page_dict.py, magic_pdf/pre_proc/cut_image.py, magic_pdf/pre_proc/detect_equation.py, magic_pdf/pre_proc/detect_footer_by_model.py, magic_pdf/pre_proc/detect_footer_header_by_statistics.py, magic_pdf/pre_proc/detect_footnote.py, magic_pdf/pre_proc/detect_header.py, magic_pdf/pre_proc/detect_images.py, magic_pdf/pre_proc/detect_page_number.py, magic_pdf/pre_proc/detect_tables.py, magic_pdf/pre_proc/equations_replace.py, magic_pdf/pre_proc/fix_image.py, magic_pdf/pre_proc/fix_table.py, magic_pdf/pre_proc/main_text_font.py, magic_pdf/pre_proc/ocr_detect_all_bboxes.py, magic_pdf/pre_proc/ocr_detect_layout.py, magic_pdf/pre_proc/ocr_dict_merge.py, magic_pdf/pre_proc/ocr_span_list_modify.py, magic_pdf/pre_proc/pdf_pre_filter.py, magic_pdf/pre_proc/post_layout_split.py, magic_pdf/pre_proc/remove_bbox_overlap.py, magic_pdf/pre_proc/remove_colored_strip_bbox.py, magic_pdf/pre_proc/remove_footer_header.py, magic_pdf/pre_proc/remove_rotate_bbox.py, magic_pdf/pre_proc/resolve_bbox_conflict.py, magic_pdf/pre_proc/solve_line_alien.py, magic_pdf/pre_proc/statistics.py, magic_pdf/resources/fasttext-langdetect/lid.176.ftz, magic_pdf/resources/model_config/UniMERNet/demo.yaml, magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml, magic_pdf/resources/model_config/model_configs.yaml, magic_pdf/rw/AbsReaderWriter.py, magic_pdf/rw/DiskReaderWriter.py, magic_pdf/rw/S3ReaderWriter.py, magic_pdf/rw/__init__.py, magic_pdf/spark/__init__.py, magic_pdf/spark/spark_api.py, magic_pdf/tools/__init__.py, magic_pdf/tools/cli.py, magic_pdf/tools/cli_dev.py, magic_pdf/tools/common.py, magic_pdf/tools/pdf_server.py, magic_pdf/__init__.py, magic_pdf/config.ini, magic_pdf/pdf_parse_by_ocr.py, magic_pdf/pdf_parse_by_txt.py, magic_pdf/pdf_parse_union_core.py, magic_pdf/user_api.py files
parent 57aaa1cf
class MODEL:
Paddle = "pp_structure_v2"
PEK = "pdf_extract_kit"
class AtomicModel:
Layout = "layout"
MFD = "mfd"
MFR = "mfr"
OCR = "ocr"
Table = "table"
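# Note (added for clarity): both classes above are plain string-constant namespaces
# used as lightweight enums elsewhere in the project, e.g.
#
#     if model_type == MODEL.PEK:
#         ...  # run the pdf-extract-kit pipeline
#     manager.get_atom_model(atom_model_name=AtomicModel.Layout, ...)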
from loguru import logger
import os
import time
from magic_pdf.libs.Constants import *
from magic_pdf.model.model_list import AtomicModel
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
try:
import cv2
import yaml
import argparse
import numpy as np
import torch
# import torchtext
#
# if torchtext.__version__ >= "0.18.0":
# torchtext.disable_torchtext_deprecation_warning()
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
# from unimernet.common.config import Config
# import unimernet.tasks as tasks
# from unimernet.processors import load_processor
except ImportError as e:
logger.exception(e)
logger.error(
'Required dependency not installed, please install by \n'
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
exit(1)
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
if table_model_type == STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
else:
config = {
"model_dir": model_path,
"device": _device_
}
table_model = ppTableModel(config)
return table_model
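# Usage sketch (illustrative, not part of the original file): the dispatch above
# selects the StructEqTable backend when table_model_type == STRUCT_EQTABLE and the
# ppTableModel wrapper otherwise. STRUCT_EQTABLE / TABLE_MASTER / TABLE_MAX_TIME_VALUE
# come from magic_pdf.libs.Constants; the weights path below is a placeholder.
#
#     table_model = table_model_init(
#         table_model_type=TABLE_MASTER,
#         model_path="/models/TableMaster",      # hypothetical directory
#         max_time=TABLE_MAX_TIME_VALUE,
#         _device_="cuda",
#     )
#     html = table_model.img2html(pil_table_image)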
def mfd_model_init(weight):
mfd_model = YOLO(weight)
return mfd_model
# def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
# args = argparse.Namespace(cfg_path=cfg_path, options=None)
# cfg = Config(args)
# cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
# cfg.config.model.model_config.model_name = weight_dir
# cfg.config.model.tokenizer_config.path = weight_dir
# task = tasks.setup_task(cfg)
# model = task.build_model(cfg)
# model = model.to(_device_)
# vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
# mfr_transform = transforms.Compose([vis_processor, ])
# return [model, mfr_transform]
def layout_model_init(weight, config_file, device):
model = Layoutlmv3_Predictor(weight, config_file, device)
return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
return model
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# if the entry is a file path, load it as a PIL image; otherwise it is already a PIL image
if isinstance(self.image_paths[idx], str):
raw_image = Image.open(self.image_paths[idx])
else:
raw_image = self.image_paths[idx]
if self.transform:
    raw_image = self.transform(raw_image)
return raw_image
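# Usage sketch (illustrative, not part of the original file): MathDataset accepts
# either file paths or PIL images, so a minimal formula-recognition loader could look
# like the following. The transform is an assumption, not the project's actual MFR
# preprocessing, and mf_image_list stands for a list of cropped formula images.
#
#     demo_transform = transforms.Compose([
#         transforms.Resize((192, 672)),
#         transforms.ToTensor(),
#     ])
#     demo_loader = DataLoader(MathDataset(mf_image_list, transform=demo_transform),
#                              batch_size=64, num_workers=0)
#     for batch in demo_loader:
#         ...  # feed the batch to the formula recognition model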
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
if atom_model_name not in self._models:
self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[atom_model_name]
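# Usage sketch (illustrative, not part of the original file): the singleton caches one
# instance per atom model name, so weights are loaded only once per process; note that
# kwargs passed on a cache hit are ignored.
#
#     manager = AtomModelSingleton()
#     ocr_a = manager.get_atom_model(atom_model_name=AtomicModel.OCR,
#                                    ocr_show_log=False, det_db_box_thresh=0.3)
#     ocr_b = manager.get_atom_model(atom_model_name=AtomicModel.OCR)
#     assert ocr_a is ocr_b  # same cached instance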
def atom_model_init(model_name: str, **kwargs):
if model_name == AtomicModel.Layout:
atom_model = layout_model_init(
kwargs.get("layout_weights"),
kwargs.get("layout_config_file"),
kwargs.get("device")
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get("mfd_weights")
)
# elif model_name == AtomicModel.MFR:
# atom_model = mfr_model_init(
# kwargs.get("mfr_weight_dir"),
# kwargs.get("mfr_cfg_path"),
# kwargs.get("device")
# )
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh")
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get("table_model_type"),
kwargs.get("table_model_path"),
kwargs.get("table_max_time"),
kwargs.get("device")
)
else:
logger.error("model name not allow")
exit(1)
return atom_model
class CustomPEKModel:
def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
"""
======== model init ========
"""
# absolute path of the current file (pdf_extract_kit.py)
current_file_path = os.path.abspath(__file__)
# directory containing the current file (model)
current_dir = os.path.dirname(current_file_path)
# parent directory (magic_pdf)
root_dir = os.path.dirname(current_dir)
# the model_config directory
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
# build the full path to model_configs.yaml
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, "r", encoding='utf-8') as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader)
# initialize the parsing configuration
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
# table config
self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
self.apply_table = self.table_config.get("is_table_recog_enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_type = self.table_config.get("model", TABLE_MASTER)
self.apply_ocr = ocr
logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
)
)
assert self.apply_layout, "DocAnalysis must contain layout model."
# initialize the parsing setup
self.device = kwargs.get("device", self.configs["config"]["device"])
logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
logger.info("using models_dir: {}".format(models_dir))
atom_model_manager = AtomModelSingleton()
# initialize formula recognition
if self.apply_formula:
# initialize the formula detection (MFD) model
# self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
)
# initialize the formula recognition (MFR) model
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
# mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
# self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
# self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
# self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
# atom_model_name=AtomicModel.MFR,
# mfr_weight_dir=mfr_weight_dir,
# mfr_cfg_path=mfr_cfg_path,
# device=self.device
# )
# initialize the layout model
# self.layout_model = Layoutlmv3_Predictor(
# str(os.path.join(models_dir, self.configs['weights']['layout'])),
# str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
# device=self.device
# )
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device
)
# initialize OCR
if self.apply_ocr:
# self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3
)
# init table model
if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_type]
# self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
# max_time=self.table_max_time, _device_=self.device)
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_type=self.table_model_type,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
logger.info('DocAnalysis init done!')
def __call__(self, image,index,end_page_id):
latex_filling_list = []
mf_image_list = []
# layout detection
layout_start = time.time()
layout_res = self.layout_model(image, ignore_catids=[])
layout_cost = round(time.time() - layout_start, 2)
# logger.info(f"layout detection cost: {layout_cost}")
total_cost = layout_cost
if self.apply_formula:
# formula detection
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': 13 + int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 2),
'latex': '',
}
layout_res.append(new_item)
latex_filling_list.append(new_item)
bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
mf_image_list.append(bbox_img)
# formula recognition
mfr_start = time.time()
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
output = self.mfr_model.generate({'image': mf_img})
mfr_res.extend(output['pred_str'])
for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
mfr_cost = round(time.time() - mfr_start, 2)
# logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# Select regions for OCR / formula regions / table regions
ocr_res_list = []
table_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
single_page_mfdetrec_res.append({
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
elif int(res['category_id']) in [5]:
table_res_list.append(res)
#logger.info(f'table_res_list:\n{table_res_list}')
# Unified crop img logic
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
# Create a white background expanded by crop_paste_x / crop_paste_y on each side
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
# Crop image
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
cropped_img = input_pil_img.crop(crop_box)
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
return return_image, return_list
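# Note on crop_img (descriptive comment, added for clarity): useful_list packs
# [paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height]. A point (x, y)
# in the padded crop maps back to the original page as
# (x - paste_x + xmin, y - paste_y + ymin), which is the inverse of the adjustment
# applied to the formula boxes and OCR boxes below.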
pil_img = Image.fromarray(image)
# logger.info(f'apply_ocr: {self.apply_ocr}')
# OCR recognition
if self.apply_ocr:
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
# logger.info(f'------new_image:{new_image}')
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# Adjust the coordinates of the formula area
adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y
# Skip formula boxes that fall completely outside the cropped region
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
continue
else:
adjusted_mfdetrec_res.append({
"bbox": [x0, y0, x1, y1],
})
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
#logger.info(f'new_image:{new_image}')
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
# logger.info(f'------------------------------------orc_res:\n{ocr_res}\n------------------------------------')
# Integrate the OCR results back into the page-level layout results
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
# Convert the coordinates back to the original coordinate system
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
layout_res.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
ocr_cost = round(time.time() - ocr_start, 2)
# logger.info(f"ocr cost: {ocr_cost}")
total_cost = round(total_cost + ocr_cost,2)
index = index + 1
end_page_id = end_page_id + 1
logger.info(f'Parsing page {index} / {end_page_id}, time cost: {total_cost}s')
# logger.info(f'apply_table: {self.apply_table}')
# table recognition
if self.apply_table:
table_start = time.time()
for res in table_res_list:
#logger.info(f'------------------------------table_res\n{res}\n----------------------------------')
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
latex_code = None
html_code = None
if self.table_model_type == STRUCT_EQTABLE:
with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0]
else:
html_code = self.table_model.img2html(new_image)
run_time = time.time() - single_table_start_time
logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
# check whether the recognition output is well-formed
if latex_code:
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
'end{table}')
if expected_ending:
res["latex"] = latex_code
else:
logger.warning(f"------------table recognition processing fails----------")
elif html_code:
res["html"] = html_code
else:
logger.warning(f"------------table recognition processing fails----------")
table_cost = round(time.time() - table_start, 2)
logger.info(f"table cost: {table_cost}")
#logger.info(f'layout_res:{layout_res}')
return layout_res
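# Usage sketch (illustrative, not part of the original module): a minimal single-page
# run, assuming page_image is an RGB numpy array of a rendered PDF page. In practice
# this class is driven per page by the doc-analyze pipeline. apply_formula is turned
# off here because the MFR model initialisation is commented out in this revision.
#
#     model = CustomPEKModel(ocr=True, show_log=False, apply_formula=False)
#     layout_res = model(page_image, index=0, end_page_id=0)
#     for item in layout_res:
#         print(item["category_id"], item.get("text", item.get("html", "")))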
# --------------------------------------------------------------------------------
# VIT: Multi-Path Vision Transformer for Dense Prediction
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
# All Rights Reserved.
# Written by Youngwan Lee
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# CoaT: https://github.com/mlpc-ucsd/CoaT
# --------------------------------------------------------------------------------
import torch
from detectron2.layers import (
ShapeSpec,
)
from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
from .deit import deit_base_patch16, mae_base_patch16
from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
from transformers import AutoConfig
__all__ = [
"build_vit_fpn_backbone",
]
class VIT_Backbone(Backbone):
"""
Implement VIT backbone.
"""
def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=None, image_only=False, cfg=None):
super().__init__()
self._out_features = out_features
if 'base' in name:
self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
else:
self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
if name == 'beit_base_patch16':
model_func = beit_base_patch16
elif name == 'dit_base_patch16':
model_func = dit_base_patch16
elif name == "deit_base_patch16":
model_func = deit_base_patch16
elif name == "mae_base_patch16":
model_func = mae_base_patch16
elif name == "dit_large_patch16":
model_func = dit_large_patch16
elif name == "beit_large_patch16":
model_func = beit_large_patch16
if 'beit' in name or 'dit' in name:
if pos_type == "abs":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_abs_pos_emb=True,
**model_kwargs)
elif pos_type == "shared_rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_shared_rel_pos_bias=True,
**model_kwargs)
elif pos_type == "rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_rel_pos_bias=True,
**model_kwargs)
else:
raise ValueError(f"unsupported pos_type: {pos_type}")
elif "layoutlmv3" in name:
config = AutoConfig.from_pretrained(config_path)
# disable relative bias as DiT
config.has_spatial_attention_bias = False
config.has_relative_attention_bias = False
self.backbone = LayoutLMv3Model(config, detection=True,
out_features=out_features, image_only=image_only)
else:
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
**model_kwargs)
self.name = name
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H and W must be multiples of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
if "layoutlmv3" in self.name:
return self.backbone.forward(
input_ids=x["input_ids"] if "input_ids" in x else None,
bbox=x["bbox"] if "bbox" in x else None,
images=x["images"] if "images" in x else None,
attention_mask=x["attention_mask"] if "attention_mask" in x else None,
# output_hidden_states=True,
)
assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
return self.backbone.forward_features(x)
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def build_VIT_backbone(cfg):
"""
Create a VIT instance from config.
Args:
cfg: a detectron2 CfgNode
Returns:
A VIT backbone instance.
"""
# fmt: off
name = cfg.MODEL.VIT.NAME
out_features = cfg.MODEL.VIT.OUT_FEATURES
drop_path = cfg.MODEL.VIT.DROP_PATH
img_size = cfg.MODEL.VIT.IMG_SIZE
pos_type = cfg.MODEL.VIT.POS_TYPE
model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
if 'layoutlmv3' in name:
if cfg.MODEL.CONFIG_PATH != '':
config_path = cfg.MODEL.CONFIG_PATH
else:
config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '') # layoutlmv3 pre-trained models
config_path = config_path.replace('model_final.pth', '') # detection fine-tuned models
else:
config_path = None
return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
Create a VIT w/ FPN backbone.
Args:
cfg: a detectron2 CfgNode
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
bottom_up = build_VIT_backbone(cfg)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelMaxPool(),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
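# Usage sketch (illustrative, not part of the original module): importing this file
# registers build_vit_fpn_backbone in detectron2's BACKBONE_REGISTRY. The keys under
# MODEL.VIT (NAME, OUT_FEATURES, DROP_PATH, IMG_SIZE, POS_TYPE, MODEL_KWARGS) are
# custom additions and must be added to the cfg before building; values below are
# placeholders.
#
#     from detectron2.config import get_cfg
#     from detectron2.modeling import build_backbone
#
#     cfg = get_cfg()
#     # ... add the custom MODEL.VIT / MODEL.IMAGE_ONLY / MODEL.CONFIG_PATH keys ...
#     cfg.MODEL.BACKBONE.NAME = "build_vit_fpn_backbone"
#     backbone = build_backbone(cfg)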
""" Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
"""
import warnings
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# (dropout here is commented out to match the original BERT implementation)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.0)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None, training_window_size=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
if training_window_size == self.window_size:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
else:
training_window_size = tuple(training_window_size.tolist())
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
# new_num_relative_distance covers all possible relative-position entries, including cls-to-cls, token-to-cls and cls-to-token
new_relative_position_bias_table = F.interpolate(
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
2 * self.window_size[0] - 1,
2 * self.window_size[1] - 1),
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
align_corners=False)
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
new_num_relative_distance - 3).permute(
1, 0)
new_relative_position_bias_table = torch.cat(
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(training_window_size[0])
coords_w = torch.arange(training_window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += training_window_size[1] - 1
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
relative_position_index = \
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = new_num_relative_distance - 3
relative_position_index[0:, 0] = new_num_relative_distance - 2
relative_position_index[0, 0] = new_num_relative_distance - 1
relative_position_bias = \
new_relative_position_bias_table[relative_position_index.view(-1)].view(
training_window_size[0] * training_window_size[1] + 1,
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values is not None:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None, training_window_size=None):
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
training_window_size=training_window_size))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w = self.patch_shape[0]
self.num_patches_h = self.patch_shape[1]
# the so-called patch_shape is the patch shape during pre-training
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, position_embedding=None, **kwargs):
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
Hp, Wp = x.shape[2], x.shape[3]
if position_embedding is not None:
# interpolate the position embedding to the corresponding size
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
1, 2)
position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
x = x + position_embedding
x = x.flatten(2).transpose(1, 2)
return x, (Hp, Wp)
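# Note (added for clarity): unlike a fixed-size ViT patch embed, this forward
# interpolates the pre-training position embedding to the actual (Hp, Wp) patch grid,
# which is what lets the detection backbone accept variable input sizes. Rough shape
# walk-through, assuming a 768-dim base model with 16x16 patches and a 224x224
# pre-training grid:
#
#     x: (B, 3, 512, 768) --proj--> (B, 768, 32, 48), so (Hp, Wp) = (32, 48)
#     pos_embed: (1, 14*14, 768) -> (1, 768, 14, 14) --bicubic--> (1, 768, 32, 48)
#     returned tokens: (B, 32*48, 768)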
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_heads = num_heads
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.02)
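# Worked example (added for clarity): for a 14x14 pre-training patch grid,
# num_relative_distance = (2*14 - 1) * (2*14 - 1) + 3 = 732; the last three table
# rows hold the cls-to-token, token-to-cls and cls-to-cls biases, which is why the
# interpolation in forward() excludes them ([:-3, :]) and concatenates them back.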
def forward(self, training_window_size):
if training_window_size == self.window_size:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
else:
training_window_size = tuple(training_window_size.tolist())
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
# new_num_relative_distance covers all possible relative-position entries, including cls-to-cls, token-to-cls and cls-to-token
new_relative_position_bias_table = F.interpolate(
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
2 * self.window_size[0] - 1,
2 * self.window_size[1] - 1),
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
align_corners=False)
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
new_num_relative_distance - 3).permute(
1, 0)
new_relative_position_bias_table = torch.cat(
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(training_window_size[0])
coords_w = torch.arange(training_window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += training_window_size[1] - 1
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
relative_position_index = \
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = new_num_relative_distance - 3
relative_position_index[0:, 0] = new_num_relative_distance - 2
relative_position_index[0, 0] = new_num_relative_distance - 1
relative_position_bias = \
new_relative_position_bias_table[relative_position_index.view(-1)].view(
training_window_size[0] * training_window_size[1] + 1,
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias
class BEiT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=[224, 224],
patch_size=16,
in_chans=3,
num_classes=80,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=None,
init_values=None,
use_abs_pos_emb=False,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
use_checkpoint=True,
pretrained=None,
out_features=None,
):
super(BEiT, self).__init__()
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.use_checkpoint = use_checkpoint
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
for i in range(depth)])
# trunc_normal_(self.mask_token, std=.02)
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
# nn.SyncBatchNorm(embed_dim),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
logger = get_root_logger()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self,
filename=self.init_cfg['checkpoint'],
strict=False,
logger=logger,
beit_spec_expand_rel_pos = self.use_rel_pos_bias,
)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
B, C, H, W = x.shape
x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
# Hp, Wp are HW for patches
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
if self.pos_embed is not None:
cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x)
features = []
training_window_size = torch.tensor([Hp, Wp])
rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
else:
x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
if i in self.out_indices:
xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def beit_base_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None,
**kwargs)
model.default_cfg = _cfg()
return model
def beit_large_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None,
**kwargs)
model.default_cfg = _cfg()
return model
def dit_base_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=0.1,
**kwargs)
model.default_cfg = _cfg()
return model
def dit_large_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=1e-5,
**kwargs)
model.default_cfg = _cfg()
return model
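# Usage sketch (illustrative, not part of the original file): the detection backbone
# wraps these constructors with out_features naming the transformer blocks to tap;
# for the base model the FPN reads layers 3, 5, 7 and 11 (strides 4/8/16/32 after the
# fpn1..fpn4 ops). Input sizes only need to be multiples of the patch size.
#
#     model = dit_base_patch16(
#         img_size=(224, 224),
#         out_features=["layer3", "layer5", "layer7", "layer11"],
#         drop_path_rate=0.1,
#         use_abs_pos_emb=True,
#     )
#     feats = model(torch.rand(1, 3, 512, 768))   # dict: layer name -> feature map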
if __name__ == '__main__':
model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
model = model.to("cuda:0")
input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
output1 = model(input1)
output2 = model(input2)
output3 = model(input3)
print("all done")
"""
Mostly copy-paste from DINO and timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import warnings
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import trunc_normal_, drop_path, to_2tuple
from functools import partial
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w, self.num_patches_h = self.window_size
self.num_patches = self.window_size[0] * self.window_size[1]
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(
1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class ViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
model_name='vit_base_patch16_224',
img_size=384,
patch_size=16,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
num_classes=19,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.1,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
norm_cfg=None,
pos_embed_interp=False,
random_init=False,
align_corners=False,
use_checkpoint=False,
num_extra_tokens=1,
out_features=None,
**kwargs,
):
super(ViT, self).__init__()
self.model_name = model_name
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.depth = depth
self.num_heads = num_heads
self.num_classes = num_classes
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.hybrid_backbone = hybrid_backbone
self.norm_layer = norm_layer
self.norm_cfg = norm_cfg
self.pos_embed_interp = pos_embed_interp
self.random_init = random_init
self.align_corners = align_corners
self.use_checkpoint = use_checkpoint
self.num_extra_tokens = num_extra_tokens
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
# self.num_stages = self.depth
# self.out_indices = tuple(range(self.num_stages))
if self.hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
self.num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
if self.num_extra_tokens == 2:
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + self.num_extra_tokens, self.embed_dim))
self.pos_drop = nn.Dropout(p=self.drop_rate)
# self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
self.depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
for i in range(self.depth)])
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
# self.repr = nn.Linear(embed_dim, representation_size)
# self.repr_act = nn.Tanh()
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
if self.num_extra_tokens==2:
trunc_normal_(self.dist_token, std=.02)
self.apply(self._init_weights)
# self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
logger = get_root_logger()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _conv_filter(self, state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
def to_2D(self, x):
n, hw, c = x.shape
h = w = int(math.sqrt(hw))
x = x.transpose(1, 2).reshape(n, c, h, w)
return x
def to_1D(self, x):
n, c, h, w = x.shape
x = x.reshape(n, c, -1).transpose(1, 2)
return x
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - self.num_extra_tokens
N = self.pos_embed.shape[1] - self.num_extra_tokens
if npatch == N and w == h:
return self.pos_embed
class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size[0]
h0 = h // self.patch_embed.patch_size[1]
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
def prepare_tokens(self, x, mask=None):
B, nc, w, h = x.shape
# patch linear embedding
x = self.patch_embed(x)
# mask image modeling
if mask is not None:
x = self.mask_model(x, mask)
x = x.flatten(2).transpose(1, 2)
# add the [CLS] token to the embed patch tokens
all_tokens = [self.cls_token.expand(B, -1, -1)]
if self.num_extra_tokens == 2:
dist_tokens = self.dist_token.expand(B, -1, -1)
all_tokens.append(dist_tokens)
all_tokens.append(x)
x = torch.cat(all_tokens, dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward_features(self, x):
# print(f"==========shape of x is {x.shape}==========")
B, _, H, W = x.shape
Hp, Wp = H // self.patch_size, W // self.patch_size
x = self.prepare_tokens(x)
features = []
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if i in self.out_indices:
xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def deit_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=2,
**kwargs)
model.default_cfg = _cfg()
return model
def mae_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=1,
**kwargs)
model.default_cfg = _cfg()
return model
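# Minimal smoke test of the backbone above (a sketch, not part of the original module).
# The out_features names are an assumption that follows the "layer<idx>" convention
# parsed by `int(name[5:])` in ViT.__init__; SyncBatchNorm inside fpn1 requires CUDA,
# so the check only runs on a GPU machine.
if __name__ == "__main__" and torch.cuda.is_available():
    backbone = deit_base_patch16(img_size=448, out_features=["layer3", "layer5", "layer7", "layer11"]).cuda().eval()
    with torch.no_grad():
        feats = backbone(torch.zeros(1, 3, 448, 448, device="cuda"))
    for name, feat in feats.items():
        # prints a 4-level pyramid: layer3 is upsampled 4x, layer11 downsampled 2x relative to the 28x28 patch grid
        print(name, tuple(feat.shape))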
from .models import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
# flake8: noqa
from .data_collator import DataCollatorForKeyValueExtraction
'''
Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
'''
import json
import os
from pathlib import Path
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{park2019cord,
title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk},
booktitle={Document Intelligence Workshop at Neural Information Processing Systems},
year={2019}
}
"""
_DESCRIPTION = """\
https://github.com/clovaai/cord/
"""
def quad_to_box(quad):
# test 87 is wrongly annotated
box = (
max(0, quad["x1"]),
max(0, quad["y1"]),
quad["x3"],
quad["y3"]
)
if box[3] < box[1]:
box = (box[0], box[3], box[2], box[1])
if box[2] < box[0]:
box = (box[2], box[1], box[0], box[3])
return box
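# Worked example (illustrative values, not taken from the dataset): the (x1, y1) and
# (x3, y3) corners of the annotated quadrilateral become the axis-aligned box, and an
# inverted edge is swapped back into (x0, y0, x1, y1) order:
#   quad_to_box({"x1": 10, "y1": 40, "x2": 90, "y2": 40,
#                "x3": 90, "y3": 20, "x4": 10, "y4": 20})  ->  (10, 20, 90, 40)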
def _get_drive_url(url):
base_url = 'https://drive.google.com/uc?id='
split_url = url.split('/')
return base_url + split_url[5]
_URLS = [
_get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
_get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
# If you failed to download the dataset through the automatic downloader,
# you can download it manually and modify the code to get the local dataset.
# Or you can use the following links. Please follow the original LICENSE of CORD for usage.
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
]
class CordConfig(datasets.BuilderConfig):
"""BuilderConfig for CORD"""
def __init__(self, **kwargs):
"""BuilderConfig for CORD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(CordConfig, self).__init__(**kwargs)
class Cord(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"words": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
citation=_CITATION,
homepage="https://github.com/clovaai/cord/",
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
"""Uses local files located with data_dir"""
downloaded_file = dl_manager.download_and_extract(_URLS)
# move files from the second URL together with files from the first one.
dest = Path(downloaded_file[0])/"CORD"
for split in ["train", "dev", "test"]:
for file_type in ["image", "json"]:
if split == "test" and file_type == "json":
continue
files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
for f in files:
os.rename(f, dest/split/file_type/f.name)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
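# Worked example (illustrative values): word boxes [[1, 2, 3, 4], [2, 1, 5, 3]] share the
# line-level envelope [1, 1, 5, 4], repeated once per word so that every token on the
# line carries the same segment-level box:
#   get_line_bbox([[1, 2, 3, 4], [2, 1, 5, 3]])  ->  [[1, 1, 5, 4], [1, 1, 5, 4]]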
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "json")
img_dir = os.path.join(filepath, "image")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
words = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["valid_line"]:
cur_line_bboxes = []
line_words, label = item["words"], item["category"]
line_words = [w for w in line_words if w["text"].strip() != ""]
if len(line_words) == 0:
continue
if label == "other":
for w in line_words:
words.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
else:
words.append(line_words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
for w in line_words[1:]:
words.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
# by default: --segment_level_layout 1
# if you do not want to use segment_level_layout, comment out the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
bboxes.extend(cur_line_bboxes)
# yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.data.data_collator import (
DataCollatorMixin,
_torch_collate_batch,
)
from transformers.file_utils import PaddingStrategy
from typing import NewType
InputDataClass = NewType("InputDataClass", Any)
def pre_calc_rel_mat(segment_ids):
valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),
device=segment_ids.device, dtype=torch.bool)
for i in range(segment_ids.shape[0]):
for j in range(segment_ids.shape[1]):
valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]
return valid_span
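# Sanity-check sketch (not in the original file): for segment_ids = [[0, 0, 1]] the mask
# marks pairs of tokens that share a segment id, i.e.
#   [[[True, True, False],
#     [True, True, False],
#     [False, False, True]]]
# The loop above is equivalent to the vectorized expression
# segment_ids.unsqueeze(1) == segment_ids.unsqueeze(2).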
@dataclass
class DataCollatorForKeyValueExtraction(DataCollatorMixin):
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
images = None
if "images" in features[0]:
images = torch.stack([torch.tensor(d.pop("images")) for d in features])
IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)
if images is not None:
batch["images"] = images
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
for k, v in batch.items()}
visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)
batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)
if labels is None:
return batch
has_bbox_input = "bbox" in features[0]
has_position_input = "position_ids" in features[0]
padding_idx=self.tokenizer.pad_token_id
sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
if has_bbox_input:
batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
for position_id in batch["position_ids"]]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
if has_bbox_input:
batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id))
+ position_id for position_id in batch["position_ids"]]
if 'segment_ids' in batch:
assert 'position_ids' in batch
for i in range(len(batch['segment_ids'])):
batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [
batch['segment_ids'][i][-1] + 2] * IMAGE_LEN
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
if 'segment_ids' in batch:
valid_span = pre_calc_rel_mat(
segment_ids=batch['segment_ids']
)
batch['valid_span'] = valid_span
del batch['segment_ids']
if images is not None:
visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100
batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)
return batch
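# Rough usage note (a sketch, not part of the original file): with 224x224 input images
# and the hard-coded 16-pixel patch grid above, IMAGE_LEN = (224 // 16) ** 2 + 1 = 197,
# so 197 visual positions are appended to `attention_mask` (as ones) and to `labels`
# (as -100) after the text is padded; `bbox` and `position_ids` are padded to the text
# length, while `segment_ids` get IMAGE_LEN extra ids before being turned into `valid_span`.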
# coding=utf-8
'''
Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
'''
import json
import os
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""
_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""
class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""
def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)
class Funsd(datasets.GeneratorBasedBuilder):
"""Conll2003 dataset."""
BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"tokens": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
tokens = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
cur_line_bboxes = []
words, label = item["words"], item["label"]
words = [w for w in words if w["text"].strip() != ""]
if len(words) == 0:
continue
if label == "other":
for w in words:
tokens.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(w["box"], size))
else:
tokens.append(words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
for w in words[1:]:
tokens.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(w["box"], size))
# by default: --segment_level_layout 1
# if you do not want to use segment_level_layout, comment out the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
# box = normalize_bbox(item["box"], size)
# cur_line_bboxes = [box for _ in range(len(words))]
bboxes.extend(cur_line_bboxes)
yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torchvision.transforms.functional as F
import warnings
import math
import random
import numpy as np
from PIL import Image
import torch
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import ResizeTransform, TransformList
def normalize_bbox(bbox, size):
return [
int(1000 * bbox[0] / size[0]),
int(1000 * bbox[1] / size[1]),
int(1000 * bbox[2] / size[0]),
int(1000 * bbox[3] / size[1]),
]
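# Worked example (illustrative values): a box (x0, y0, x1, y1) = (10, 20, 30, 40) on a
# page of size (width, height) = (100, 200) is scaled into the 0-1000 coordinate space
# used by LayoutLM-style models:
#   normalize_bbox([10, 20, 30, 40], (100, 200))  ->  [100, 100, 300, 200]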
def load_image(image_path):
image = read_image(image_path, format="BGR")
h = image.shape[0]
w = image.shape[1]
img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable
return image, (w, h)
def crop(image, i, j, h, w, boxes=None):
cropped_image = F.crop(image, i, j, h, w)
if boxes is not None:
# Currently we cannot use this case: when some boxes fall outside the cropped image,
# it may be better to drop those boxes along with their text input (instead of taking
# the min or clamping), but that has not been implemented here.
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
boxes = cropped_boxes.reshape(-1, 4)
return cropped_image, boxes
def resize(image, size, interpolation, boxes=None):
# It seems that we do not need to resize boxes here, since the boxes will be resized to 1000x1000 finally,
# which is compatible with a square image size of 224x224
rescaled_image = F.resize(image, size, interpolation)
if boxes is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
# boxes = boxes.copy()
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
return rescaled_image, scaled_boxes
def clamp(num, min_value, max_value):
return max(min(num, max_value), min_value)
def get_bb(bb, page_size):
bbs = [float(j) for j in bb]
xs, ys = [], []
for i, b in enumerate(bbs):
if i % 2 == 0:
xs.append(b)
else:
ys.append(b)
(width, height) = page_size
return_bb = [
clamp(min(xs), 0, width - 1),
clamp(min(ys), 0, height - 1),
clamp(max(xs), 0, width - 1),
clamp(max(ys), 0, height - 1),
]
return_bb = [
int(1000 * return_bb[0] / width),
int(1000 * return_bb[1] / height),
int(1000 * return_bb[2] / width),
int(1000 * return_bb[3] / height),
]
return return_bb
class ToNumpy:
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
class ToTensor:
def __init__(self, dtype=torch.float32):
self.dtype = dtype
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return torch.from_numpy(np_img).to(dtype=self.dtype)
_pil_interpolation_to_str = {
F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',
F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',
F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',
F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',
F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING',
F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',
}
def _pil_interp(method):
if method == 'bicubic':
return F.InterpolationMode.BICUBIC
elif method == 'lanczos':
return F.InterpolationMode.LANCZOS
elif method == 'hamming':
return F.InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return F.InterpolationMode.BILINEAR
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
require `lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, augmentation=False, box=None):
for t in self.transforms:
img = t(img, augmentation, box)
return img
class RandomResizedCropAndInterpolationWithTwoPic:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
A crop of random size (default: 0.08 to 1.0 of the original size) and a random
aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
is finally resized to the given size.
This is popularly used to train the Inception networks.
Args:
size: expected output size of each edge
scale: range of size of the origin size cropped
ratio: range of aspect ratio of the origin aspect ratio cropped
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear', second_interpolation='lanczos'):
if isinstance(size, tuple):
self.size = size
else:
self.size = (size, size)
if second_size is not None:
if isinstance(second_size, tuple):
self.second_size = second_size
else:
self.second_size = (second_size, second_size)
else:
self.second_size = None
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
self.interpolation = _pil_interp(interpolation)
self.second_interpolation = _pil_interp(second_interpolation)
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img, scale, ratio):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image): Image to be cropped.
scale (tuple): range of size of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
def __call__(self, img, augmentation=False, box=None):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Randomly cropped and resized image.
"""
if augmentation:
i, j, h, w = self.get_params(img, self.scale, self.ratio)
img = F.crop(img, i, j, h, w)
# img, box = crop(img, i, j, h, w, box)
img = F.resize(img, self.size, self.interpolation)
second_img = F.resize(img, self.second_size, self.second_interpolation) \
if self.second_size is not None else None
return img, second_img
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
else:
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0}'.format(interpolate_str)
if self.second_size is not None:
format_string += ', second_size={0}'.format(self.second_size)
format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
format_string += ')'
return format_string
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
import os
import json
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic
XFund_label2ids = {
"O":0,
'B-HEADER':1,
'I-HEADER':2,
'B-QUESTION':3,
'I-QUESTION':4,
'B-ANSWER':5,
'I-ANSWER':6,
}
class xfund_dataset(Dataset):
def box_norm(self, box, width, height):
def clip(min_num, num, max_num):
return min(max(num, min_num), max_num)
x0, y0, x1, y1 = box
x0 = clip(0, int((x0 / width) * 1000), 1000)
y0 = clip(0, int((y0 / height) * 1000), 1000)
x1 = clip(0, int((x1 / width) * 1000), 1000)
y1 = clip(0, int((y1 / height) * 1000), 1000)
assert x1 >= x0
assert y1 >= y0
return [x0, y0, x1, y1]
def get_segment_ids(self, bboxs):
segment_ids = []
for i in range(len(bboxs)):
if i == 0:
segment_ids.append(0)
else:
if bboxs[i - 1] == bboxs[i]:
segment_ids.append(segment_ids[-1])
else:
segment_ids.append(segment_ids[-1] + 1)
return segment_ids
def get_position_ids(self, segment_ids):
position_ids = []
for i in range(len(segment_ids)):
if i == 0:
position_ids.append(2)
else:
if segment_ids[i] == segment_ids[i - 1]:
position_ids.append(position_ids[-1] + 1)
else:
position_ids.append(2)
return position_ids
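# Worked example (illustrative values): tokens whose boxes repeat belong to the same
# segment, and positions restart at 2 whenever a new segment starts (2 = padding_idx + 1
# for a RoBERTa-style tokenizer is the assumption here):
#   get_segment_ids([[0, 0, 5, 5], [0, 0, 5, 5], [9, 9, 12, 12]])  ->  [0, 0, 1]
#   get_position_ids([0, 0, 1])                                    ->  [2, 3, 2]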
def load_data(
self,
data_file,
):
# re-org data format
total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
for i in range(len(data_file['documents'])):
width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][
'height']
cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []
for j in range(len(data_file['documents'][i]['document'])):
cur_item = data_file['documents'][i]['document'][j]
cur_doc_lines.append(cur_item['text'])
cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
cur_doc_ner_tags.append(cur_item['label'])
total_data['id'] += [len(total_data['id'])]
total_data['lines'] += [cur_doc_lines]
total_data['bboxes'] += [cur_doc_bboxes]
total_data['ner_tags'] += [cur_doc_ner_tags]
total_data['image_path'] += [data_file['documents'][i]['img']['fname']]
# tokenize text and get bbox/label
total_input_ids, total_bboxs, total_label_ids = [], [], []
for i in range(len(total_data['lines'])):
cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
for j in range(len(total_data['lines'][i])):
cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']
if len(cur_input_ids) == 0: continue
cur_label = total_data['ner_tags'][i][j].upper()
if cur_label == 'OTHER':
cur_labels = ["O"] * len(cur_input_ids)
for k in range(len(cur_labels)):
cur_labels[k] = self.label2ids[cur_labels[k]]
else:
cur_labels = [cur_label] * len(cur_input_ids)
cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
for k in range(1, len(cur_labels)):
cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)
cur_doc_input_ids += cur_input_ids
cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
cur_doc_labels += cur_labels
assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
assert len(cur_doc_input_ids) > 0
total_input_ids.append(cur_doc_input_ids)
total_bboxs.append(cur_doc_bboxs)
total_label_ids.append(cur_doc_labels)
assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)
# split the text into slices of at most 510 tokens (plus [CLS]/[SEP]) because documents can be over-long
input_ids, bboxs, labels = [], [], []
segment_ids, position_ids = [], []
image_path = []
for i in range(len(total_input_ids)):
start = 0
cur_iter = 0
while start < len(total_input_ids[i]):
end = min(start + 510, len(total_input_ids[i]))
input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])
bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])
labels.append([-100] + total_label_ids[i][start: end] + [-100])
cur_segment_ids = self.get_segment_ids(bboxs[-1])
cur_position_ids = self.get_position_ids(cur_segment_ids)
segment_ids.append(cur_segment_ids)
position_ids.append(cur_position_ids)
image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
start = end
cur_iter += 1
assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
assert len(segment_ids) == len(image_path)
res = {
'input_ids': input_ids,
'bbox': bboxs,
'labels': labels,
'segment_ids': segment_ids,
'position_ids': position_ids,
'image_path': image_path,
}
return res
def __init__(
self,
args,
tokenizer,
mode
):
self.args = args
self.mode = mode
self.cur_la = args.language
self.tokenizer = tokenizer
self.label2ids = XFund_label2ids
self.common_transform = Compose([
RandomResizedCropAndInterpolationWithTwoPic(
size=args.input_size, interpolation=args.train_interpolation,
),
])
self.patch_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor((0.5, 0.5, 0.5)),
std=torch.tensor((0.5, 0.5, 0.5)))
])
with open(os.path.join(args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val')), 'r') as f:
    data_file = json.load(f)
self.feature = self.load_data(data_file)
def __len__(self):
return len(self.feature['input_ids'])
def __getitem__(self, index):
input_ids = self.feature["input_ids"][index]
# attention_mask = self.feature["attention_mask"][index]
attention_mask = [1] * len(input_ids)
labels = self.feature["labels"][index]
bbox = self.feature["bbox"][index]
segment_ids = self.feature['segment_ids'][index]
position_ids = self.feature['position_ids'][index]
img = pil_loader(self.feature['image_path'][index])
for_patches, _ = self.common_transform(img, augmentation=False)
patch = self.patch_transform(for_patches)
assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
res = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"bbox": bbox,
"segment_ids": segment_ids,
"position_ids": position_ids,
"images": patch,
}
return res
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
from .layoutlmv3 import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \
AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter
from .configuration_layoutlmv3 import LayoutLMv3Config
from .modeling_layoutlmv3 import (
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Model,
)
from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
#AutoConfig.register("layoutlmv3", LayoutLMv3Config)
#AutoModel.register(LayoutLMv3Config, LayoutLMv3Model)
#AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification)
#AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering)
#AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification)
#AutoTokenizer.register(
# LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast
#)
SLOW_TO_FAST_CONVERTERS.update({"LayoutLMv3Tokenizer": RobertaConverter})
# coding=utf-8
from transformers.models.bert.configuration_bert import BertConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
"layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/resolve/main/config.json",
# See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3
}
class LayoutLMv3Config(BertConfig):
model_type = "layoutlmv3"
def __init__(
self,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
max_2d_position_embeddings=1024,
coordinate_size=None,
shape_size=None,
has_relative_attention_bias=False,
rel_pos_bins=32,
max_rel_pos=128,
has_spatial_attention_bias=False,
rel_2d_pos_bins=64,
max_rel_2d_pos=256,
visual_embed=True,
mim=False,
wpa_task=False,
discrete_vae_weight_path='',
discrete_vae_type='dall-e',
input_size=224,
second_input_size=112,
device='cuda',
**kwargs
):
"""Constructs RobertaConfig."""
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.max_2d_position_embeddings = max_2d_position_embeddings
self.coordinate_size = coordinate_size
self.shape_size = shape_size
self.has_relative_attention_bias = has_relative_attention_bias
self.rel_pos_bins = rel_pos_bins
self.max_rel_pos = max_rel_pos
self.has_spatial_attention_bias = has_spatial_attention_bias
self.rel_2d_pos_bins = rel_2d_pos_bins
self.max_rel_2d_pos = max_rel_2d_pos
self.visual_embed = visual_embed
self.mim = mim
self.wpa_task = wpa_task
self.discrete_vae_weight_path = discrete_vae_weight_path
self.discrete_vae_type = discrete_vae_type
self.input_size = input_size
self.second_input_size = second_input_size
self.device = device
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LayoutLMv3 model. """
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import apply_chunking_to_forward
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
TokenClassifierOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.models.roberta.modeling_roberta import (
RobertaIntermediate,
RobertaLMHead,
RobertaOutput,
RobertaSelfOutput,
)
from transformers.utils import logging
from .configuration_layoutlmv3 import LayoutLMv3Config
from timm.models.layers import to_2tuple
logger = logging.get_logger(__name__)
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# The following variables are used by the detection code (mycheckpointer.py)
self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.num_patches_w = self.patch_shape[0]
self.num_patches_h = self.patch_shape[1]
def forward(self, x, position_embedding=None):
x = self.proj(x)
if position_embedding is not None:
# interpolate the position embedding to the corresponding size
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3, 1, 2)
Hp, Wp = x.shape[2], x.shape[3]
position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
x = x + position_embedding
x = x.flatten(2).transpose(1, 2)
return x
class LayoutLMv3Embeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
# End copy
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
)
self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
def _calc_spatial_position_embeddings(self, bbox):
try:
assert torch.all(0 <= bbox) and torch.all(bbox <= 1023)
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
except IndexError as e:
raise IndexError("The :obj:`bbox` coordinate values should be within 0-1000 range.") from e
h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))
w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))
# below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)
spatial_position_embeddings = torch.cat(
[
left_position_embeddings,
upper_position_embeddings,
right_position_embeddings,
lower_position_embeddings,
h_position_embeddings,
w_position_embeddings,
],
dim=-1,
)
return spatial_position_embeddings
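# Dimensional note (a sketch, not from the original file): the concatenation yields
# 4 * coordinate_size + 2 * shape_size features per token; with the commonly used base
# configuration (coordinate_size = shape_size = 128) this equals 768, i.e. hidden_size,
# so the spatial embeddings can be added directly to the word embeddings in forward().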
def create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids: torch.Tensor
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
return incremental_indices.long() + padding_idx
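# Worked example (illustrative values, padding_idx = 1 as in the RoBERTa vocabulary):
#   input_ids = [[5, 7, 1, 1]] -> mask [[1, 1, 0, 0]] -> position_ids [[2, 3, 1, 1]]
# i.e. real tokens count up from padding_idx + 1 while padded positions stay at padding_idx.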
def forward(
self,
input_ids=None,
bbox=None,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
past_key_values_length=0,
):
if position_ids is None:
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(
input_ids, self.padding_idx, past_key_values_length).to(input_ids.device)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
spatial_position_embeddings = self._calc_spatial_position_embeddings(bbox)
embeddings = embeddings + spatial_position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
"""
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
Args:
inputs_embeds: torch.Tensor
Returns: torch.Tensor
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
)
return position_ids.unsqueeze(0).expand(input_shape)
class LayoutLMv3PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LayoutLMv3Config
base_model_prefix = "layoutlmv3"
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class LayoutLMv3SelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def cogview_attn(self, attention_scores, alpha=32):
'''
https://arxiv.org/pdf/2105.13290.pdf
Section 2.4 Stabilization of training: Precision Bottleneck Relaxation (PB-Relax).
A replacement of the original nn.Softmax(dim=-1)(attention_scores)
The new attention_probs seem to be slightly slower and to introduce a small numerical difference.
torch.allclose(standard_attention_probs, cogview_attention_probs, atol=1e-08) can be used for comparison;
the smaller the atol it passes with (e.g., 1e-08), the closer the two are.
'''
scaled_attention_scores = attention_scores / alpha
max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
# max_value = scaled_attention_scores.amax(dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1)
new_attention_scores = (scaled_attention_scores - max_value) * alpha
return nn.Softmax(dim=-1)(new_attention_scores)
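# Numerical note (a sketch, not from the original file): softmax is invariant to a
# per-row constant shift, so softmax(((x / alpha) - max(x / alpha)) * alpha) == softmax(x)
# mathematically; the alpha detour only keeps the intermediate logits small enough to
# avoid fp16 overflow, as described in the PB-Relax section of the paper cited above.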
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
rel_pos=None,
rel_2d_pos=None,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
# The attention scores (Q^T K)/√d could be significantly larger than the input elements and result in overflow.
# Changing the computation order to Q^T (K/√d) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf)
attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
if self.has_relative_attention_bias and self.has_spatial_attention_bias:
attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
elif self.has_relative_attention_bias:
attention_scores += rel_pos / math.sqrt(self.attention_head_size)
# if self.has_relative_attention_bias:
# attention_scores += rel_pos
# if self.has_spatial_attention_bias:
# attention_scores += rel_2d_pos
# attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
# attention_probs = nn.Softmax(dim=-1)(attention_scores)  # use this line instead of the one below for a small speedup
attention_probs = self.cogview_attn(attention_scores)  # to stabilize training
# assert torch.allclose(attention_probs, nn.Softmax(dim=-1)(attention_scores), atol=1e-8)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class LayoutLMv3Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = LayoutLMv3SelfAttention(config)
self.output = RobertaSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
rel_pos=None,
rel_2d_pos=None,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class LayoutLMv3Layer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = LayoutLMv3Attention(config)
assert not config.is_decoder and not config.add_cross_attention, \
"This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder."
self.intermediate = RobertaIntermediate(config)
self.output = RobertaOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
rel_pos=None,
rel_2d_pos=None,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class LayoutLMv3Encoder(nn.Module):
def __init__(self, config, detection=False, out_features=None):
super().__init__()
self.config = config
self.detection = detection
self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.has_relative_attention_bias = config.has_relative_attention_bias
self.has_spatial_attention_bias = config.has_spatial_attention_bias
if self.has_relative_attention_bias:
self.rel_pos_bins = config.rel_pos_bins
self.max_rel_pos = config.max_rel_pos
self.rel_pos_onehot_size = config.rel_pos_bins
self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False)
if self.has_spatial_attention_bias:
self.max_rel_2d_pos = config.max_rel_2d_pos
self.rel_2d_pos_bins = config.rel_2d_pos_bins
self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
if self.detection:
self.gradient_checkpointing = True
embed_dim = self.config.hidden_size
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
# nn.SyncBatchNorm(embed_dim),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
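# fpn1-fpn4 above turn selected transformer feature maps into a feature pyramid for
# the detection backbone: fpn1 upsamples the patch grid 4x (two stride-2 transposed
# convolutions), fpn2 upsamples 2x, fpn3 keeps the native resolution, and fpn4
# downsamples 2x. With a patch size of 16 this yields feature strides of roughly
# 4, 8, 16 and 32 pixels.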
def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
ret = 0
if bidirectional:
num_buckets //= 2
ret += (relative_position > 0).long() * num_buckets
n = torch.abs(relative_position)
else:
n = torch.max(-relative_position, torch.zeros_like(relative_position))
# now n is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = n < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
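# Illustration of the bucketing above (a sketch under the assumed defaults
# num_buckets=32, max_distance=128, bidirectional=True): the 32 buckets are split
# into 16 per sign, distances 0..7 keep exact buckets, and larger distances are
# spread logarithmically up to max_distance, e.g. for positive offsets
#     self.relative_position_bucket(torch.tensor([0, 1, 7, 8, 30, 300]))
#     # -> tensor([ 0, 17, 23, 24, 27, 31])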
def _cal_1d_pos_emb(self, hidden_states, position_ids, valid_span):
VISUAL_NUM = 196 + 1  # 14 x 14 image patches plus the visual CLS token
rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
if valid_span is not None:
# for the text part, if two words are not in the same line,
# set their distance to the max value (position_ids.shape[-1])
rel_pos_mat[(rel_pos_mat > 0) & (valid_span == False)] = position_ids.shape[1]
rel_pos_mat[(rel_pos_mat < 0) & (valid_span == False)] = -position_ids.shape[1]
# image-text, minimum distance
rel_pos_mat[:, -VISUAL_NUM:, :-VISUAL_NUM] = 0
rel_pos_mat[:, :-VISUAL_NUM, -VISUAL_NUM:] = 0
rel_pos = self.relative_position_bucket(
rel_pos_mat,
num_buckets=self.rel_pos_bins,
max_distance=self.max_rel_pos,
)
rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)
rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)
rel_pos = rel_pos.contiguous()
return rel_pos
def _cal_2d_pos_emb(self, hidden_states, bbox):
position_coord_x = bbox[:, :, 0]
position_coord_y = bbox[:, :, 3]
rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
rel_pos_x = self.relative_position_bucket(
rel_pos_x_2d_mat,
num_buckets=self.rel_2d_pos_bins,
max_distance=self.max_rel_2d_pos,
)
rel_pos_y = self.relative_position_bucket(
rel_pos_y_2d_mat,
num_buckets=self.rel_2d_pos_bins,
max_distance=self.max_rel_2d_pos,
)
rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2)
rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2)
rel_pos_x = rel_pos_x.contiguous()
rel_pos_y = rel_pos_y.contiguous()
rel_2d_pos = rel_pos_x + rel_pos_y
return rel_2d_pos
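# Both bias tensors computed above end up with shape
# (batch, num_attention_heads, seq_len, seq_len): the one-hot bucket encodings are
# projected to one scalar per head and permuted so they can be added directly to
# the attention scores in LayoutLMv3SelfAttention.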
def forward(
self,
hidden_states,
bbox=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
position_ids=None,
Hp=None,
Wp=None,
valid_span=None,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids, valid_span) if self.has_relative_attention_bias else None
rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
if self.detection:
feat_out = {}
j = 0
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
# return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos)
# The above line will cause error:
# RuntimeError: Trying to backward through the graph a second time
# (or directly access saved tensors after they have already been freed).
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
rel_pos,
rel_2d_pos
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if self.detection and i in self.out_indices:
xp = hidden_states[:, -Hp*Wp:, :].permute(0, 2, 1).reshape(len(hidden_states), -1, Hp, Wp)
feat_out[self.out_features[j]] = self.ops[j](xp.contiguous())
j += 1
if self.detection:
return feat_out
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
"""
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
def __init__(self, config, detection=False, out_features=None, image_only=False):
super().__init__(config)
self.config = config
assert not config.is_decoder and not config.add_cross_attention, \
"This version do not support decoder. Please refer to RoBERTa for implementation of is_decoder."
self.detection = detection
if not self.detection:
self.image_only = False
else:
assert config.visual_embed
self.image_only = image_only
if not self.image_only:
self.embeddings = LayoutLMv3Embeddings(config)
self.encoder = LayoutLMv3Encoder(config, detection=detection, out_features=out_features)
if config.visual_embed:
embed_dim = self.config.hidden_size
# use the default pre-training parameters for fine-tuning (e.g., input_size)
# when the input_size is larger in fine-tuning, we will interpolate the position embedding in forward
self.patch_embed = PatchEmbed(embed_dim=embed_dim)
patch_size = 16
size = int(self.config.input_size / patch_size)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, embed_dim))
self.pos_drop = nn.Dropout(p=0.)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
self._init_visual_bbox(img_size=(size, size))
from functools import partial
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm = norm_layer(embed_dim)
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def _init_visual_bbox(self, img_size=(14, 14), max_len=1000):
visual_bbox_x = torch.div(torch.arange(0, max_len * (img_size[1] + 1), max_len),
img_size[1], rounding_mode='trunc')
visual_bbox_y = torch.div(torch.arange(0, max_len * (img_size[0] + 1), max_len),
img_size[0], rounding_mode='trunc')
visual_bbox = torch.stack(
[
visual_bbox_x[:-1].repeat(img_size[0], 1),
visual_bbox_y[:-1].repeat(img_size[1], 1).transpose(0, 1),
visual_bbox_x[1:].repeat(img_size[0], 1),
visual_bbox_y[1:].repeat(img_size[1], 1).transpose(0, 1),
],
dim=-1,
).view(-1, 4)
cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)
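# Example of the grid built above (a sketch under the defaults img_size=(14, 14),
# max_len=1000): each of the 196 patches gets a box on a 0..1000 coordinate grid
# (the first patch box is [0, 0, 71, 71]), and the prepended CLS box is
# [1, 1, 999, 999], so self.visual_bbox has shape (197, 4).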
def _calc_visual_bbox(self, device, dtype, bsz): # , img_size=(14, 14), max_len=1000):
visual_bbox = self.visual_bbox.repeat(bsz, 1, 1)
visual_bbox = visual_bbox.to(device).type(dtype)
return visual_bbox
def forward_image(self, x):
if self.detection:
x = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
else:
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
if self.pos_embed is not None and self.detection:
cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None and not self.detection:
x = x + self.pos_embed
x = self.pos_drop(x)
x = self.norm(x)
return x
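# forward_image returns a (batch, 1 + Hp*Wp, hidden_size) tensor of visual tokens
# (a CLS token plus one token per 16x16 patch); for a 224x224 input that is the
# 197 = 196 + 1 visual tokens assumed by VISUAL_NUM in the encoder.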
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
def forward(
self,
input_ids=None,
bbox=None,
attention_mask=None,
token_type_ids=None,
valid_span=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
images=None,
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = False
# if input_ids is not None and inputs_embeds is not None:
# raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
if input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = inputs_embeds.device
elif images is not None:
batch_size = len(images)
device = images.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or images")
if not self.image_only:
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
# extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if not self.image_only:
if bbox is None:
bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
embedding_output = self.embeddings(
input_ids=input_ids,
bbox=bbox,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
final_bbox = final_position_ids = None
Hp = Wp = None
if images is not None:
patch_size = 16
Hp, Wp = int(images.shape[2] / patch_size), int(images.shape[3] / patch_size)
visual_emb = self.forward_image(images)
if self.detection:
visual_attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device)
if self.image_only:
attention_mask = visual_attention_mask
else:
attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
elif self.image_only:
attention_mask = torch.ones((batch_size, visual_emb.shape[1]), dtype=torch.long, device=device)
if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
if self.config.has_spatial_attention_bias:
visual_bbox = self._calc_visual_bbox(device, dtype=torch.long, bsz=batch_size)
if self.image_only:
final_bbox = visual_bbox
else:
final_bbox = torch.cat([bbox, visual_bbox], dim=1)
visual_position_ids = torch.arange(0, visual_emb.shape[1], dtype=torch.long, device=device).repeat(
batch_size, 1)
if self.image_only:
final_position_ids = visual_position_ids
else:
position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
position_ids = position_ids.expand_as(input_ids)
final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
if self.image_only:
embedding_output = visual_emb
else:
embedding_output = torch.cat([embedding_output, visual_emb], dim=1)
embedding_output = self.LayerNorm(embedding_output)
embedding_output = self.dropout(embedding_output)
elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
if self.config.has_spatial_attention_bias:
final_bbox = bbox
if self.config.has_relative_attention_bias:
position_ids = self.embeddings.position_ids[:, :input_shape[1]]
position_ids = position_ids.expand_as(input_ids)
final_position_ids = position_ids
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device)
encoder_outputs = self.encoder(
embedding_output,
bbox=final_bbox,
position_ids=final_position_ids,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
Hp=Hp,
Wp=Wp,
valid_span=valid_span,
)
if self.detection:
return encoder_outputs
sequence_output = encoder_outputs[0]
pooled_output = None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
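# Usage sketch for the text+image path (illustrative only; LayoutLMv3Config is
# assumed to come from configuration_layoutlmv3.py, and the shapes below are
# made-up examples, not values required by this file):
#     from .configuration_layoutlmv3 import LayoutLMv3Config
#     config = LayoutLMv3Config()                      # assumes visual_embed is enabled
#     model = LayoutLMv3Model(config)
#     outputs = model(
#         input_ids=torch.randint(0, config.vocab_size, (1, 32)),
#         bbox=torch.zeros(1, 32, 4, dtype=torch.long),
#         attention_mask=torch.ones(1, 32, dtype=torch.long),
#         images=torch.randn(1, 3, 224, 224),
#     )
#     # outputs.last_hidden_state: (1, 32 + 197, hidden_size) when images are given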
class LayoutLMv3ClassificationHead(nn.Module):
"""
Head for sentence-level classification tasks.
Reference: RobertaClassificationHead
"""
def __init__(self, config, pool_feature=False):
super().__init__()
self.pool_feature = pool_feature
if pool_feature:
self.dense = nn.Linear(config.hidden_size*3, config.hidden_size)
else:
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, x):
# x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.layoutlmv3 = LayoutLMv3Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
if config.num_labels < 10:
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
else:
self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
self.init_weights()
def forward(
self,
input_ids=None,
bbox=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
valid_span=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
images=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlmv3(
input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
images=images,
valid_span=valid_span,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.layoutlmv3 = LayoutLMv3Model(config)
# self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
self.init_weights()
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
valid_span=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
bbox=None,
images=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlmv3(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
bbox=bbox,
images=images,
valid_span=valid_span,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, splitting adds a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# Sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.layoutlmv3 = LayoutLMv3Model(config)
self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
self.init_weights()
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
valid_span=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
bbox=None,
images=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1`, a regression loss is computed (Mean-Square loss);
if :obj:`config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.layoutlmv3(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
bbox=bbox,
images=images,
valid_span=valid_span,
)
sequence_output = outputs[0][:, 0, :]
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LayoutLMv3, refer to RoBERTa."""
from transformers.models.roberta import RobertaTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
class LayoutLMv3Tokenizer(RobertaTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
# pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization classes for LayoutLMv3, refer to RoBERTa."""
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
from transformers.utils import logging
from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
class LayoutLMv3TokenizerFast(RobertaTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
# pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
slow_tokenizer_class = LayoutLMv3Tokenizer