".circleci/vscode:/vscode.git/clone" did not exist on "fa1aa52dda3b8d7bca48223addd7a8c254bd1d3e"
Commit bd927919 authored by myhloli's avatar myhloli
Browse files

refactor: rename init file and update app.py to enable parsing method

parent f5016508
import random
from loguru import logger
try:
from paddleocr import PPStructure
except ImportError:
logger.error('paddleocr not installed, please install by "pip install magic-pdf[lite]"')
exit(1)
def region_to_bbox(region):
x0 = region[0][0]
y0 = region[0][1]
x1 = region[2][0]
y1 = region[2][1]
return [x0, y0, x1, y1]
class CustomPaddleModel:
def __init__(self,
ocr: bool = False,
show_log: bool = False,
lang=None,
det_db_box_thresh=0.3,
use_dilation=True,
det_db_unclip_ratio=1.8
):
if lang is not None:
self.model = PPStructure(table=False,
ocr=True,
show_log=show_log,
lang=lang,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
)
else:
self.model = PPStructure(table=False,
ocr=True,
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
)
def __call__(self, img):
try:
import cv2
except ImportError:
logger.error("opencv-python not installed, please install by pip.")
exit(1)
# Convert the RGB image to BGR to match Paddle's expected input
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
result = self.model(img)
spans = []
for line in result:
line.pop("img")
"""
Map Paddle output types to category numbers:
title: 0 # title
text: 1 # text
header: 2 # abandon
footer: 2 # abandon
reference: 1 # text or abandon
equation: 8 # interline equation (block)
equation: 14 # interline equation (text)
figure: 3 # figure
figure_caption: 4 # figure caption
table: 5 # table
table_caption: 6 # table caption
"""
if line["type"] == "title":
line["category_id"] = 0
elif line["type"] in ["text", "reference"]:
line["category_id"] = 1
elif line["type"] == "figure":
line["category_id"] = 3
elif line["type"] == "figure_caption":
line["category_id"] = 4
elif line["type"] == "table":
line["category_id"] = 5
elif line["type"] == "table_caption":
line["category_id"] = 6
elif line["type"] == "equation":
line["category_id"] = 8
elif line["type"] in ["header", "footer"]:
line["category_id"] = 2
else:
logger.warning(f"unknown type: {line['type']}")
# Compatible with paddleocr versions that do not output a score
if line.get("score") is None:
line["score"] = 0.5 + random.random() * 0.5
res = line.pop("res", None)
if res is not None and len(res) > 0:
for span in res:
new_span = {
"category_id": 15,
"bbox": region_to_bbox(span["text_region"]),
"score": span["confidence"],
"text": span["text"],
}
spans.append(new_span)
if len(spans) > 0:
result.extend(spans)
return result
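# A minimal usage sketch, assuming Pillow is available and "demo_page.png" is a
# hypothetical path to a rendered page image: the model is called with an RGB
# NumPy array and returns layout blocks plus OCR text spans (category_id 15).
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    demo_img = np.asarray(Image.open("demo_page.png").convert("RGB"))
    demo_model = CustomPaddleModel(ocr=True, lang="en")
    for block in demo_model(demo_img):
        print(block.get("category_id"), block.get("score"))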
# Copyright (c) Opendatalab. All rights reserved.
import os
from pathlib import Path
import yaml
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # disable the albumentations update check
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.data.utils import load_images_from_pdf
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.pdf_check import extract_pages
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
def get_model_config():
local_models_dir = get_local_models_dir()
device = get_device()
current_file_path = os.path.abspath(__file__)
root_dir = Path(current_file_path).parents[3]
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, 'r', encoding='utf-8') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)
return root_dir, local_models_dir, device, configs
def get_text_images(simple_images):
_, local_models_dir, device, configs = get_model_config()
atom_model_manager = AtomModelSingleton()
temp_layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(
os.path.join(
local_models_dir, configs['weights'][MODEL_NAME.DocLayout_YOLO]
)
),
device=device,
)
text_images = []
for simple_image in simple_images:
image = simple_image['img']
layout_res = temp_layout_model.predict(image)
# crop each text block from the page image
for res in layout_res:
if res['category_id'] in [1]:
x1, y1, _, _, x2, y2, _, _ = res['poly']
# preliminary filtering: skip blocks whose width and height are both under 100
if x2 - x1 < 100 and y2 - y1 < 100:
continue
text_images.append(image[y1:y2, x1:x2])
return text_images
def auto_detect_lang(pdf_bytes: bytes):
sample_docs = extract_pages(pdf_bytes)
sample_pdf_bytes = sample_docs.tobytes()
simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
text_images = get_text_images(simple_images)
langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
lang = langdetect_model.do_detect(text_images)
return lang
def model_init(model_name: str):
atom_model_manager = AtomModelSingleton()
if model_name == MODEL_NAME.YOLO_V11_LangDetect:
root_dir, _, device, _ = get_model_config()
model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.LangDetect,
langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
langdetect_model_weight=str(os.path.join(root_dir, 'resources', 'yolov11-langdetect', 'yolo_v11_ft.pt')),
device=device,
)
else:
raise ValueError(f"model_name {model_name} not found")
return model
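# A minimal usage sketch, assuming "sample.pdf" is a hypothetical local file and
# that the magic-pdf models directory and device are already configured, since
# auto_detect_lang loads the layout and language-detection weights via
# get_model_config().
if __name__ == "__main__":
    with open("sample.pdf", "rb") as f:
        detected_lang = auto_detect_lang(f.read())
    print(f"detected language: {detected_lang}")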
# Copyright (c) Opendatalab. All rights reserved.
import time
from collections import Counter
from uuid import uuid4
import cv2
import numpy as np
import torch
from loguru import logger
from ultralytics import YOLO
language_dict = {
"ch": "中文简体",
"en": "英语",
"japan": "日语",
"korean": "韩语",
"fr": "法语",
"german": "德语",
"ar": "阿拉伯语",
"ru": "俄语"
}
def split_images(image, result_images=None):
"""
Split an image recursively: if its longer side exceeds 400 pixels, halve it along that side,
and repeat until every resulting sub-image is at most 400 pixels on its longer side; the
sub-images are collected into result_images and returned. Crop regions that would extend
beyond the image bounds are skipped, so no invalid black-padded pieces are produced.
"""
if result_images is None:
result_images = []
height, width = image.shape[:2]
long_side = max(width, height) # length of the longer side
if long_side <= 400:
result_images.append(image)
return result_images
new_long_side = long_side // 2
sub_images = []
if width >= height: # width is the longer side
for x in range(0, width, new_long_side):
# skip crop regions that would extend beyond the image bounds
if x + new_long_side > width:
continue
sub_image = image[0:height, x:x + new_long_side]
sub_images.append(sub_image)
else: # height is the longer side
for y in range(0, height, new_long_side):
# skip crop regions that would extend beyond the image bounds
if y + new_long_side > height:
continue
sub_image = image[y:y + new_long_side, 0:width]
sub_images.append(sub_image)
for sub_image in sub_images:
split_images(sub_image, result_images)
return result_images
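# Worked example of the recursion above: a 900x600 image is halved along its
# width into two 450x600 pieces, each of those along its height into two
# 450x300 pieces, and each of those along its width again into two 225x300
# pieces, so split_images returns eight 225x300 sub-images whose longer sides
# are all at most 400 pixels.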
def resize_images_to_224(image):
"""
If either dimension is smaller than 224, pad the image onto a black 224x224 background; otherwise resize it to 224x224.
Works directly with NumPy arrays.
"""
try:
height, width = image.shape[:2]
if width < 224 or height < 224:
# Create black background
new_image = np.zeros((224, 224, 3), dtype=np.uint8)
# Calculate paste position (ensure they're not negative)
paste_x = max(0, (224 - width) // 2)
paste_y = max(0, (224 - height) // 2)
# Make sure we don't exceed the boundaries of new_image
paste_width = min(width, 224)
paste_height = min(height, 224)
# Paste original image onto black background
new_image[paste_y:paste_y + paste_height, paste_x:paste_x + paste_width] = image[:paste_height, :paste_width]
image = new_image
else:
# Resize using cv2
image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)
return image
except Exception as e:
logger.exception(f"Error in resize_images_to_224: {e}")
return None
class YOLOv11LangDetModel(object):
def __init__(self, langdetect_model_weight, device):
self.model = YOLO(langdetect_model_weight)
if str(device).startswith("npu"):
self.device = torch.device(device)
else:
self.device = device
def do_detect(self, images: list):
all_images = []
for image in images:
height, width = image.shape[:2]
if width < 100 and height < 100:
continue
temp_images = split_images(image)
for temp_image in temp_images:
all_images.append(resize_images_to_224(temp_image))
# langdetect_start = time.time()
images_lang_res = self.batch_predict(all_images, batch_size=256)
# logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res)
language = max(count_dict, key=count_dict.get)
else:
language = None
return language
def predict(self, image):
results = self.model.predict(image, verbose=False, device=self.device)
predicted_class_id = int(results[0].probs.top1)
predicted_class_name = self.model.names[predicted_class_id]
return predicted_class_name
def batch_predict(self, images: list, batch_size: int) -> list:
images_lang_res = []
for index in range(0, len(images), batch_size):
lang_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index: index + batch_size],
verbose = False,
device=self.device,
)
]
for res in lang_res:
predicted_class_id = int(res.probs.top1)
predicted_class_name = self.model.names[predicted_class_id]
images_lang_res.append(predicted_class_name)
return images_lang_res
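# A minimal usage sketch, assuming "langdetect.pt" and "text_block.png" are
# hypothetical local paths; do_detect takes a list of NumPy images and returns
# the most frequent language code predicted over all split-and-resized crops.
if __name__ == "__main__":
    crop = cv2.imread("text_block.png")
    detector = YOLOv11LangDetModel("langdetect.pt", device="cpu")
    print(detector.do_detect([crop]))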
from doclayout_yolo import YOLOv10
from tqdm import tqdm
class DocLayoutYOLOModel(object):
def __init__(self, weight, device):
self.model = YOLOv10(weight)
self.device = device
def predict(self, image):
layout_res = []
doclayout_yolo_res = self.model.predict(
image,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False, device=self.device
)[0]
for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(),
doclayout_yolo_res.boxes.conf.cpu(),
doclayout_yolo_res.boxes.cls.cpu(),
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
"category_id": int(cla.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 3),
}
layout_res.append(new_item)
return layout_res
def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = []
# for index in range(0, len(images), batch_size):
for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
doclayout_yolo_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size],
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False,
device=self.device,
)
]
for image_res in doclayout_yolo_res:
layout_res = []
for xyxy, conf, cla in zip(
image_res.boxes.xyxy,
image_res.boxes.conf,
image_res.boxes.cls,
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
"category_id": int(cla.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 3),
}
layout_res.append(new_item)
images_layout_res.append(layout_res)
return images_layout_res
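# A minimal usage sketch, assuming "doclayout_yolo.pt" and "page_0.png" are
# hypothetical local paths; predict returns one dict per detected block with a
# category_id, an 8-value poly (four corner points), and a confidence score.
if __name__ == "__main__":
    import cv2

    page = cv2.imread("page_0.png")
    layout_model = DocLayoutYOLOModel("doclayout_yolo.pt", device="cpu")
    for block in layout_model.predict(page):
        print(block["category_id"], block["score"], block["poly"])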
# --------------------------------------------------------------------------------
# VIT: Multi-Path Vision Transformer for Dense Prediction
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
# All Rights Reserved.
# Written by Youngwan Lee
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# CoaT: https://github.com/mlpc-ucsd/CoaT
# --------------------------------------------------------------------------------
import torch
from detectron2.layers import (
ShapeSpec,
)
from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
from .deit import deit_base_patch16, mae_base_patch16
from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
from transformers import AutoConfig
__all__ = [
"build_vit_fpn_backbone",
]
class VIT_Backbone(Backbone):
"""
Implement VIT backbone.
"""
def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=None, image_only=False, cfg=None):
super().__init__()
self._out_features = out_features
if 'base' in name:
self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
else:
self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
if name == 'beit_base_patch16':
model_func = beit_base_patch16
elif name == 'dit_base_patch16':
model_func = dit_base_patch16
elif name == "deit_base_patch16":
model_func = deit_base_patch16
elif name == "mae_base_patch16":
model_func = mae_base_patch16
elif name == "dit_large_patch16":
model_func = dit_large_patch16
elif name == "beit_large_patch16":
model_func = beit_large_patch16
if 'beit' in name or 'dit' in name:
if pos_type == "abs":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_abs_pos_emb=True,
**model_kwargs)
elif pos_type == "shared_rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_shared_rel_pos_bias=True,
**model_kwargs)
elif pos_type == "rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_rel_pos_bias=True,
**model_kwargs)
else:
raise ValueError()
elif "layoutlmv3" in name:
config = AutoConfig.from_pretrained(config_path)
# disable relative bias as DiT
config.has_spatial_attention_bias = False
config.has_relative_attention_bias = False
self.backbone = LayoutLMv3Model(config, detection=True,
out_features=out_features, image_only=image_only)
else:
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
**model_kwargs)
self.name = name
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
if "layoutlmv3" in self.name:
return self.backbone.forward(
input_ids=x["input_ids"] if "input_ids" in x else None,
bbox=x["bbox"] if "bbox" in x else None,
images=x["images"] if "images" in x else None,
attention_mask=x["attention_mask"] if "attention_mask" in x else None,
# output_hidden_states=True,
)
assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
return self.backbone.forward_features(x)
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def build_VIT_backbone(cfg):
"""
Create a VIT instance from config.
Args:
cfg: a detectron2 CfgNode
Returns:
A VIT backbone instance.
"""
# fmt: off
name = cfg.MODEL.VIT.NAME
out_features = cfg.MODEL.VIT.OUT_FEATURES
drop_path = cfg.MODEL.VIT.DROP_PATH
img_size = cfg.MODEL.VIT.IMG_SIZE
pos_type = cfg.MODEL.VIT.POS_TYPE
model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
if 'layoutlmv3' in name:
if cfg.MODEL.CONFIG_PATH != '':
config_path = cfg.MODEL.CONFIG_PATH
else:
config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '') # layoutlmv3 pre-trained models
config_path = config_path.replace('model_final.pth', '') # detection fine-tuned models
else:
config_path = None
return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
Create a VIT w/ FPN backbone.
Args:
cfg: a detectron2 CfgNode
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
bottom_up = build_VIT_backbone(cfg)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelMaxPool(),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
""" Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
"""
import warnings
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commented out to match the original BERT implementation
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.0)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None, training_window_size=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
if training_window_size == self.window_size:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
else:
training_window_size = tuple(training_window_size.tolist())
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
# new_num_relative_distance counts all possible relative-position options, including cls-to-cls, token-to-cls, and cls-to-token
new_relative_position_bias_table = F.interpolate(
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
2 * self.window_size[0] - 1,
2 * self.window_size[1] - 1),
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
align_corners=False)
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
new_num_relative_distance - 3).permute(
1, 0)
new_relative_position_bias_table = torch.cat(
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(training_window_size[0])
coords_w = torch.arange(training_window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += training_window_size[1] - 1
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
relative_position_index = \
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = new_num_relative_distance - 3
relative_position_index[0:, 0] = new_num_relative_distance - 2
relative_position_index[0, 0] = new_num_relative_distance - 1
relative_position_bias = \
new_relative_position_bias_table[relative_position_index.view(-1)].view(
training_window_size[0] * training_window_size[1] + 1,
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values is not None:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None, training_window_size=None):
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
training_window_size=training_window_size))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w = self.patch_shape[0]
self.num_patches_h = self.patch_shape[1]
# the so-called patch_shape is the patch shape during pre-training
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, position_embedding=None, **kwargs):
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
Hp, Wp = x.shape[2], x.shape[3]
if position_embedding is not None:
# interpolate the position embedding to the corresponding size
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
1, 2)
position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
x = x + position_embedding
x = x.flatten(2).transpose(1, 2)
return x, (Hp, Wp)
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_heads = num_heads
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self, training_window_size):
if training_window_size == self.window_size:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
else:
training_window_size = tuple(training_window_size.tolist())
new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
# new_num_relative_distance counts all possible relative-position options, including cls-to-cls, token-to-cls, and cls-to-token
new_relative_position_bias_table = F.interpolate(
self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
2 * self.window_size[0] - 1,
2 * self.window_size[1] - 1),
size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
align_corners=False)
new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
new_num_relative_distance - 3).permute(
1, 0)
new_relative_position_bias_table = torch.cat(
[new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(training_window_size[0])
coords_w = torch.arange(training_window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += training_window_size[1] - 1
relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
relative_position_index = \
torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = new_num_relative_distance - 3
relative_position_index[0:, 0] = new_num_relative_distance - 2
relative_position_index[0, 0] = new_num_relative_distance - 1
relative_position_bias = \
new_relative_position_bias_table[relative_position_index.view(-1)].view(
training_window_size[0] * training_window_size[1] + 1,
training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias
class BEiT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=[224, 224],
patch_size=16,
in_chans=3,
num_classes=80,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=None,
init_values=None,
use_abs_pos_emb=False,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
use_checkpoint=True,
pretrained=None,
out_features=None,
):
super(BEiT, self).__init__()
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.use_checkpoint = use_checkpoint
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
for i in range(depth)])
# trunc_normal_(self.mask_token, std=.02)
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
# nn.SyncBatchNorm(embed_dim),
nn.BatchNorm2d(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
logger = get_root_logger()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self,
filename=self.init_cfg['checkpoint'],
strict=False,
logger=logger,
beit_spec_expand_rel_pos = self.use_rel_pos_bias,
)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
B, C, H, W = x.shape
x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
# Hp, Wp are HW for patches
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
if self.pos_embed is not None:
cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x)
features = []
training_window_size = torch.tensor([Hp, Wp])
rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
else:
x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
if i in self.out_indices:
xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def beit_base_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None,
**kwargs)
model.default_cfg = _cfg()
return model
def beit_large_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=None,
**kwargs)
model.default_cfg = _cfg()
return model
def dit_base_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=0.1,
**kwargs)
model.default_cfg = _cfg()
return model
def dit_large_patch16(pretrained=False, **kwargs):
model = BEiT(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_values=1e-5,
**kwargs)
model.default_cfg = _cfg()
return model
if __name__ == '__main__':
model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
model = model.to("cuda:0")
input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
output1 = model(input1)
output2 = model(input2)
output3 = model(input3)
print("all done")
"""
Mostly copy-paste from DINO and timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import warnings
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import trunc_normal_, drop_path, to_2tuple
from functools import partial
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w, self.num_patches_h = self.window_size
self.num_patches = self.window_size[0] * self.window_size[1]
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(
1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class ViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
model_name='vit_base_patch16_224',
img_size=384,
patch_size=16,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
num_classes=19,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.1,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
norm_cfg=None,
pos_embed_interp=False,
random_init=False,
align_corners=False,
use_checkpoint=False,
num_extra_tokens=1,
out_features=None,
**kwargs,
):
super(ViT, self).__init__()
self.model_name = model_name
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.depth = depth
self.num_heads = num_heads
self.num_classes = num_classes
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.hybrid_backbone = hybrid_backbone
self.norm_layer = norm_layer
self.norm_cfg = norm_cfg
self.pos_embed_interp = pos_embed_interp
self.random_init = random_init
self.align_corners = align_corners
self.use_checkpoint = use_checkpoint
self.num_extra_tokens = num_extra_tokens
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
# self.num_stages = self.depth
# self.out_indices = tuple(range(self.num_stages))
if self.hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
self.num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
if self.num_extra_tokens == 2:
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + self.num_extra_tokens, self.embed_dim))
self.pos_drop = nn.Dropout(p=self.drop_rate)
# self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
self.depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
for i in range(self.depth)])
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
# self.repr = nn.Linear(embed_dim, representation_size)
# self.repr_act = nn.Tanh()
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
if self.num_extra_tokens==2:
trunc_normal_(self.dist_token, std=0.2)
self.apply(self._init_weights)
# self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
logger = get_root_logger()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _conv_filter(self, state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
def to_2D(self, x):
n, hw, c = x.shape
h = w = int(math.sqrt(hw))
x = x.transpose(1, 2).reshape(n, c, h, w)
return x
def to_1D(self, x):
n, c, h, w = x.shape
x = x.reshape(n, c, -1).transpose(1, 2)
return x
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - self.num_extra_tokens
N = self.pos_embed.shape[1] - self.num_extra_tokens
if npatch == N and w == h:
return self.pos_embed
class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size[0]
h0 = h // self.patch_embed.patch_size[1]
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
def prepare_tokens(self, x, mask=None):
B, nc, w, h = x.shape
# patch linear embedding
x = self.patch_embed(x)
# mask image modeling
if mask is not None:
x = self.mask_model(x, mask)
x = x.flatten(2).transpose(1, 2)
# add the [CLS] token to the embed patch tokens
all_tokens = [self.cls_token.expand(B, -1, -1)]
if self.num_extra_tokens == 2:
dist_tokens = self.dist_token.expand(B, -1, -1)
all_tokens.append(dist_tokens)
all_tokens.append(x)
x = torch.cat(all_tokens, dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward_features(self, x):
# print(f"==========shape of x is {x.shape}==========")
B, _, H, W = x.shape
Hp, Wp = H // self.patch_size, W // self.patch_size
x = self.prepare_tokens(x)
features = []
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if i in self.out_indices:
xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def deit_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=2,
**kwargs)
model.default_cfg = _cfg()
return model
def mae_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=1,
**kwargs)
model.default_cfg = _cfg()
return model
from .models import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
# flake8: noqa
from .data_collator import DataCollatorForKeyValueExtraction
'''
Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
'''
import json
import os
from pathlib import Path
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{park2019cord,
title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk},
booktitle={Document Intelligence Workshop at Neural Information Processing Systems},
year={2019}
}
"""
_DESCRIPTION = """\
https://github.com/clovaai/cord/
"""
def quad_to_box(quad):
# test 87 is wrongly annotated
box = (
max(0, quad["x1"]),
max(0, quad["y1"]),
quad["x3"],
quad["y3"]
)
if box[3] < box[1]:
bbox = list(box)
tmp = bbox[3]
bbox[3] = bbox[1]
bbox[1] = tmp
box = tuple(bbox)
if box[2] < box[0]:
bbox = list(box)
tmp = bbox[2]
bbox[2] = bbox[0]
bbox[0] = tmp
box = tuple(bbox)
return box
def _get_drive_url(url):
base_url = 'https://drive.google.com/uc?id='
split_url = url.split('/')
return base_url + split_url[5]
_URLS = [
_get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
_get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
# If you failed to download the dataset through the automatic downloader,
# you can download it manually and modify the code to get the local dataset.
# Or you can use the following links. Please follow the original LICENSE of CORD for usage.
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
]
class CordConfig(datasets.BuilderConfig):
"""BuilderConfig for CORD"""
def __init__(self, **kwargs):
"""BuilderConfig for CORD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(CordConfig, self).__init__(**kwargs)
class Cord(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"words": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
citation=_CITATION,
homepage="https://github.com/clovaai/cord/",
)
    def _split_generators(self, dl_manager):
        """Returns SplitGenerators after merging the two downloaded archives into a single directory tree."""
downloaded_file = dl_manager.download_and_extract(_URLS)
# move files from the second URL together with files from the first one.
dest = Path(downloaded_file[0])/"CORD"
for split in ["train", "dev", "test"]:
for file_type in ["image", "json"]:
if split == "test" and file_type == "json":
continue
files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
for f in files:
os.rename(f, dest/split/file_type/f.name)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
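    # Note: get_line_bbox replaces every word-level box in a line with the line's
    # enclosing box, e.g. [[10, 10, 20, 20], [30, 12, 40, 22]] -> two copies of
    # [10, 10, 40, 22]. This produces the segment-level layout positions referred to
    # by the --segment_level_layout comment in _generate_examples below.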
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "json")
img_dir = os.path.join(filepath, "image")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
words = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["valid_line"]:
cur_line_bboxes = []
line_words, label = item["words"], item["category"]
line_words = [w for w in line_words if w["text"].strip() != ""]
if len(line_words) == 0:
continue
if label == "other":
for w in line_words:
words.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
else:
words.append(line_words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
for w in line_words[1:]:
words.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
                # by default: --segment_level_layout 1
                # comment out the following line if you do not want segment-level layout
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
bboxes.extend(cur_line_bboxes)
# yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.data.data_collator import (
DataCollatorMixin,
_torch_collate_batch,
)
from transformers.file_utils import PaddingStrategy
from typing import NewType
InputDataClass = NewType("InputDataClass", Any)
def pre_calc_rel_mat(segment_ids):
valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),
device=segment_ids.device, dtype=torch.bool)
for i in range(segment_ids.shape[0]):
for j in range(segment_ids.shape[1]):
valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]
return valid_span
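# pre_calc_rel_mat builds a boolean (batch, seq_len, seq_len) mask in which entry
# [i, j, k] is True iff tokens j and k of sample i share the same segment id.
# A loop-free equivalent (untested sketch) would be:
#   valid_span = segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)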
@dataclass
class DataCollatorForKeyValueExtraction(DataCollatorMixin):
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
images = None
if "images" in features[0]:
images = torch.stack([torch.tensor(d.pop("images")) for d in features])
IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1
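            # For 224x224 inputs with 16x16 patches this is 14 * 14 + 1 = 197:
            # 196 visual patch tokens plus one extra token (assumed here to be the
            # model's visual CLS token), all of which must be covered by the
            # attention mask and label padding added below.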
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)
if images is not None:
batch["images"] = images
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
for k, v in batch.items()}
visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)
batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)
if labels is None:
return batch
has_bbox_input = "bbox" in features[0]
has_position_input = "position_ids" in features[0]
        padding_idx = self.tokenizer.pad_token_id
sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
if has_bbox_input:
batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
for position_id in batch["position_ids"]]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
if has_bbox_input:
batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id))
+ position_id for position_id in batch["position_ids"]]
if 'segment_ids' in batch:
assert 'position_ids' in batch
for i in range(len(batch['segment_ids'])):
batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [
batch['segment_ids'][i][-1] + 2] * IMAGE_LEN
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
if 'segment_ids' in batch:
valid_span = pre_calc_rel_mat(
segment_ids=batch['segment_ids']
)
batch['valid_span'] = valid_span
del batch['segment_ids']
if images is not None:
visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100
batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)
return batch
# coding=utf-8
'''
Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
'''
import json
import os
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""
_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""
class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""
def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)
class Funsd(datasets.GeneratorBasedBuilder):
"""Conll2003 dataset."""
BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"tokens": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
tokens = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
cur_line_bboxes = []
words, label = item["words"], item["label"]
words = [w for w in words if w["text"].strip() != ""]
if len(words) == 0:
continue
if label == "other":
for w in words:
tokens.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(w["box"], size))
else:
tokens.append(words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
for w in words[1:]:
tokens.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(w["box"], size))
                # by default: --segment_level_layout 1
                # comment out the following line if you do not want segment-level layout
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
# box = normalize_bbox(item["box"], size)
# cur_line_bboxes = [box for _ in range(len(words))]
bboxes.extend(cur_line_bboxes)
yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torchvision.transforms.functional as F
import warnings
import math
import random
import numpy as np
from PIL import Image
import torch
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import ResizeTransform, TransformList
def normalize_bbox(bbox, size):
return [
int(1000 * bbox[0] / size[0]),
int(1000 * bbox[1] / size[1]),
int(1000 * bbox[2] / size[0]),
int(1000 * bbox[3] / size[1]),
]
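# normalize_bbox maps pixel coordinates into the 0-1000 coordinate space used by
# LayoutLM-style models, e.g. on a 640x480 page, [64, 48, 320, 240] -> [100, 100, 500, 500].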
def load_image(image_path):
image = read_image(image_path, format="BGR")
h = image.shape[0]
w = image.shape[1]
img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable
return image, (w, h)
def crop(image, i, j, h, w, boxes=None):
cropped_image = F.crop(image, i, j, h, w)
if boxes is not None:
        # This branch is currently unused: when a box falls outside the cropped image it
        # would be better to drop that box together with its text input rather than
        # clamping it, but that has not been implemented here.
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
boxes = cropped_boxes.reshape(-1, 4)
return cropped_image, boxes
def resize(image, size, interpolation, boxes=None):
    # Boxes do not need to be resized here: they are later normalized to the 1000x1000
    # coordinate space, which is compatible with the square 224x224 image size.
rescaled_image = F.resize(image, size, interpolation)
if boxes is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
# boxes = boxes.copy()
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
return rescaled_image, scaled_boxes
def clamp(num, min_value, max_value):
return max(min(num, max_value), min_value)
def get_bb(bb, page_size):
bbs = [float(j) for j in bb]
xs, ys = [], []
for i, b in enumerate(bbs):
if i % 2 == 0:
xs.append(b)
else:
ys.append(b)
(width, height) = page_size
return_bb = [
clamp(min(xs), 0, width - 1),
clamp(min(ys), 0, height - 1),
clamp(max(xs), 0, width - 1),
clamp(max(ys), 0, height - 1),
]
return_bb = [
int(1000 * return_bb[0] / width),
int(1000 * return_bb[1] / height),
int(1000 * return_bb[2] / width),
int(1000 * return_bb[3] / height),
]
return return_bb
class ToNumpy:
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
class ToTensor:
def __init__(self, dtype=torch.float32):
self.dtype = dtype
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return torch.from_numpy(np_img).to(dtype=self.dtype)
_pil_interpolation_to_str = {
F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',
F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',
F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',
F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',
F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING',
F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',
}
def _pil_interp(method):
if method == 'bicubic':
return F.InterpolationMode.BICUBIC
elif method == 'lanczos':
return F.InterpolationMode.LANCZOS
elif method == 'hamming':
return F.InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return F.InterpolationMode.BILINEAR
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
        Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor``
        and do not require ``lambda`` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, augmentation=False, box=None):
for t in self.transforms:
img = t(img, augmentation, box)
return img
class RandomResizedCropAndInterpolationWithTwoPic:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
    A crop of random size (default: 0.08 to 1.0 of the original area) and of a random
    aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.
Args:
size: expected output size of each edge
scale: range of size of the origin size cropped
ratio: range of aspect ratio of the origin aspect ratio cropped
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear', second_interpolation='lanczos'):
if isinstance(size, tuple):
self.size = size
else:
self.size = (size, size)
if second_size is not None:
if isinstance(second_size, tuple):
self.second_size = second_size
else:
self.second_size = (second_size, second_size)
else:
self.second_size = None
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
self.interpolation = _pil_interp(interpolation)
self.second_interpolation = _pil_interp(second_interpolation)
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img, scale, ratio):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image): Image to be cropped.
scale (tuple): range of size of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
def __call__(self, img, augmentation=False, box=None):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Randomly cropped and resized image.
"""
if augmentation:
i, j, h, w = self.get_params(img, self.scale, self.ratio)
img = F.crop(img, i, j, h, w)
# img, box = crop(img, i, j, h, w, box)
img = F.resize(img, self.size, self.interpolation)
second_img = F.resize(img, self.second_size, self.second_interpolation) \
if self.second_size is not None else None
return img, second_img
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
else:
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0}'.format(interpolate_str)
if self.second_size is not None:
format_string += ', second_size={0}'.format(self.second_size)
format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
format_string += ')'
return format_string
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
import os
import json
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic
XFund_label2ids = {
"O":0,
'B-HEADER':1,
'I-HEADER':2,
'B-QUESTION':3,
'I-QUESTION':4,
'B-ANSWER':5,
'I-ANSWER':6,
}
class xfund_dataset(Dataset):
def box_norm(self, box, width, height):
def clip(min_num, num, max_num):
return min(max(num, min_num), max_num)
x0, y0, x1, y1 = box
x0 = clip(0, int((x0 / width) * 1000), 1000)
y0 = clip(0, int((y0 / height) * 1000), 1000)
x1 = clip(0, int((x1 / width) * 1000), 1000)
y1 = clip(0, int((y1 / height) * 1000), 1000)
assert x1 >= x0
assert y1 >= y0
return [x0, y0, x1, y1]
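    # box_norm mirrors image_utils.normalize_bbox but additionally clips the result to
    # [0, 1000], e.g. on a 640x480 page the box [64, 48, 700, 500] becomes
    # [100, 100, 1000, 1000].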
def get_segment_ids(self, bboxs):
segment_ids = []
for i in range(len(bboxs)):
if i == 0:
segment_ids.append(0)
else:
if bboxs[i - 1] == bboxs[i]:
segment_ids.append(segment_ids[-1])
else:
segment_ids.append(segment_ids[-1] + 1)
return segment_ids
def get_position_ids(self, segment_ids):
position_ids = []
for i in range(len(segment_ids)):
if i == 0:
position_ids.append(2)
else:
if segment_ids[i] == segment_ids[i - 1]:
position_ids.append(position_ids[-1] + 1)
else:
position_ids.append(2)
return position_ids
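    # get_segment_ids gives consecutive tokens with identical boxes the same segment id;
    # get_position_ids then restarts 1D positions at 2 within each segment (0 and 1
    # appear to be reserved for padding/special positions, matching the RoBERTa-style
    # padding id used by the data collator).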
def load_data(
self,
data_file,
):
# re-org data format
total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
for i in range(len(data_file['documents'])):
width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][
'height']
cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []
for j in range(len(data_file['documents'][i]['document'])):
cur_item = data_file['documents'][i]['document'][j]
cur_doc_lines.append(cur_item['text'])
cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
cur_doc_ner_tags.append(cur_item['label'])
total_data['id'] += [len(total_data['id'])]
total_data['lines'] += [cur_doc_lines]
total_data['bboxes'] += [cur_doc_bboxes]
total_data['ner_tags'] += [cur_doc_ner_tags]
total_data['image_path'] += [data_file['documents'][i]['img']['fname']]
# tokenize text and get bbox/label
total_input_ids, total_bboxs, total_label_ids = [], [], []
for i in range(len(total_data['lines'])):
cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
for j in range(len(total_data['lines'][i])):
cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']
if len(cur_input_ids) == 0: continue
cur_label = total_data['ner_tags'][i][j].upper()
if cur_label == 'OTHER':
cur_labels = ["O"] * len(cur_input_ids)
for k in range(len(cur_labels)):
cur_labels[k] = self.label2ids[cur_labels[k]]
else:
cur_labels = [cur_label] * len(cur_input_ids)
cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
for k in range(1, len(cur_labels)):
cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)
cur_doc_input_ids += cur_input_ids
cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
cur_doc_labels += cur_labels
assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
assert len(cur_doc_input_ids) > 0
total_input_ids.append(cur_doc_input_ids)
total_bboxs.append(cur_doc_bboxs)
total_label_ids.append(cur_doc_labels)
assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)
# split text to several slices because of over-length
input_ids, bboxs, labels = [], [], []
segment_ids, position_ids = [], []
image_path = []
for i in range(len(total_input_ids)):
start = 0
cur_iter = 0
while start < len(total_input_ids[i]):
end = min(start + 510, len(total_input_ids[i]))
input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])
bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])
labels.append([-100] + total_label_ids[i][start: end] + [-100])
cur_segment_ids = self.get_segment_ids(bboxs[-1])
cur_position_ids = self.get_position_ids(cur_segment_ids)
segment_ids.append(cur_segment_ids)
position_ids.append(cur_position_ids)
image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
start = end
cur_iter += 1
assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
assert len(segment_ids) == len(image_path)
res = {
'input_ids': input_ids,
'bbox': bboxs,
'labels': labels,
'segment_ids': segment_ids,
'position_ids': position_ids,
'image_path': image_path,
}
return res
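    # load_data flattens each document into token ids, then slices it into windows of at
    # most 510 tokens so that, together with the [CLS]/[SEP] ids added around each slice,
    # no input exceeds the usual 512-token limit; bboxes, labels, segment ids and
    # position ids are sliced in lockstep.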
def __init__(
self,
args,
tokenizer,
mode
):
self.args = args
self.mode = mode
self.cur_la = args.language
self.tokenizer = tokenizer
self.label2ids = XFund_label2ids
self.common_transform = Compose([
RandomResizedCropAndInterpolationWithTwoPic(
size=args.input_size, interpolation=args.train_interpolation,
),
])
self.patch_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor((0.5, 0.5, 0.5)),
std=torch.tensor((0.5, 0.5, 0.5)))
])
        data_path = os.path.join(
            args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val'))
        with open(data_path, 'r', encoding='utf-8') as f:
            data_file = json.load(f)
self.feature = self.load_data(data_file)
def __len__(self):
return len(self.feature['input_ids'])
def __getitem__(self, index):
input_ids = self.feature["input_ids"][index]
# attention_mask = self.feature["attention_mask"][index]
attention_mask = [1] * len(input_ids)
labels = self.feature["labels"][index]
bbox = self.feature["bbox"][index]
segment_ids = self.feature['segment_ids'][index]
position_ids = self.feature['position_ids'][index]
img = pil_loader(self.feature['image_path'][index])
for_patches, _ = self.common_transform(img, augmentation=False)
patch = self.patch_transform(for_patches)
assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
res = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"bbox": bbox,
"segment_ids": segment_ids,
"position_ids": position_ids,
"images": patch,
}
return res
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
from .layoutlmv3 import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \
AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter
from .configuration_layoutlmv3 import LayoutLMv3Config
from .modeling_layoutlmv3 import (
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Model,
)
from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
#AutoConfig.register("layoutlmv3", LayoutLMv3Config)
#AutoModel.register(LayoutLMv3Config, LayoutLMv3Model)
#AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification)
#AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering)
#AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification)
#AutoTokenizer.register(
# LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast
#)
SLOW_TO_FAST_CONVERTERS.update({"LayoutLMv3Tokenizer": RobertaConverter})
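# Mapping "LayoutLMv3Tokenizer" to RobertaConverter tells transformers how to build the
# fast tokenizer from the slow one; this relies on LayoutLMv3's tokenizer reusing
# RoBERTa's byte-level BPE vocabulary and merges (an assumption based on the
# RoBERTa-style tokenizer classes imported above).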