Commit 1fac6aa7 authored by myhloli's avatar myhloli
Browse files

update:Integrate the PDF-Extract-Kit inside

parent 4703503b
...@@ -85,6 +85,7 @@ def do_parse( ...@@ -85,6 +85,7 @@ def do_parse(
orig_model_list = copy.deepcopy(model_list) orig_model_list = copy.deepcopy(model_list)
local_image_dir, local_md_dir = prepare_env(pdf_file_name, parse_method) local_image_dir, local_md_dir = prepare_env(pdf_file_name, parse_method)
logger.info(f"local output dir is {local_md_dir}")
image_writer, md_writer = DiskReaderWriter(local_image_dir), DiskReaderWriter(local_md_dir) image_writer, md_writer = DiskReaderWriter(local_image_dir), DiskReaderWriter(local_md_dir)
image_dir = str(os.path.basename(local_image_dir)) image_dir = str(os.path.basename(local_image_dir))
......
import fitz import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL, MODEL_TYPE
import magic_pdf.model as model_config import magic_pdf.model as model_config
...@@ -34,8 +34,8 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list: ...@@ -34,8 +34,8 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
pm = page.get_pixmap(matrix=mat, alpha=False) pm = page.get_pixmap(matrix=mat, alpha=False)
# if width or height > 3000 pixels, don't enlarge the image # if width or height > 3000 pixels, don't enlarge the image
if pix.width > 3000 or pix.height > 3000: if pm.width > 3000 or pm.height > 3000:
pix = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples) img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
img = np.array(img) img = np.array(img)
...@@ -44,31 +44,36 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list: ...@@ -44,31 +44,36 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
return images return images
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.Paddle): def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, model=MODEL.Paddle,
model_type=MODEL_TYPE.SINGLE_PAGE):
custom_model = None
if model_config.__use_inside_model__: if model_config.__use_inside_model__:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
custom_model = CustomPEKModel(ocr=ocr, show_log=show_log)
else:
logger.error("Not allow model_name!")
exit(1)
else: else:
logger.error("use_inside_model is False, not allow to use inside model") logger.error("use_inside_model is False, not allow to use inside model")
exit(1) exit(1)
images = load_images_from_pdf(pdf_bytes) images = load_images_from_pdf(pdf_bytes)
custom_model = None
if model == MODEL.Paddle:
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
else:
pass
model_json = []
for index, img_dict in enumerate(images):
img = img_dict["img"]
page_width = img_dict["width"]
page_height = img_dict["height"]
result = custom_model(img)
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
# @todo 把公式识别放在后置位置,待整本全部模型结果出来之后再补公式数据 model_json = []
if model_type == MODEL_TYPE.SINGLE_PAGE:
for index, img_dict in enumerate(images):
img = img_dict["img"]
page_width = img_dict["width"]
page_height = img_dict["height"]
result = custom_model(img)
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict)
elif model_type == MODEL_TYPE.MULTI_PAGE:
model_json = custom_model(images)
return model_json return model_json
class MODEL: class MODEL:
Paddle = "pp_structure_v2" Paddle = "pp_structure_v2"
PEK = "pdf_extract_kit"
class MODEL_TYPE:
# 单页解析
SINGLE_PAGE = 1
# 多页解析
MULTI_PAGE = 2
import os import os
import time
import cv2
import fitz
import numpy as np import numpy as np
import torch
import unimernet.tasks as tasks
import yaml import yaml
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from ultralytics import YOLO from ultralytics import YOLO
from loguru import logger
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from unimernet.common.config import Config from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor from unimernet.processors import load_processor
import argparse
from torchvision import transforms
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
class CustomPEKModel: def layout_model_init(weight, config_file):
def __init__(self, ocr: bool = False, show_log: bool = False): model = Layoutlmv3_Predictor(weight, config_file)
## ======== model init ========## return model
with open('configs/model_configs.yaml') as f:
model_configs = yaml.load(f, Loader=yaml.FullLoader)
img_size = model_configs['model_args']['img_size']
conf_thres = model_configs['model_args']['conf_thres']
iou_thres = model_configs['model_args']['iou_thres']
device = model_configs['model_args']['device']
dpi = model_configs['model_args']['pdf_dpi']
mfd_model = mfd_model_init(model_configs['model_args']['mfd_weight'])
mfr_model, mfr_vis_processors = mfr_model_init(model_configs['model_args']['mfr_weight'], device=device)
mfr_transform = transforms.Compose([mfr_vis_processors, ])
layout_model = layout_model_init(model_configs['model_args']['layout_weight'])
ocr_model = ModifiedPaddleOCR(show_log=True)
print(now.strftime('%Y-%m-%d %H:%M:%S'))
print('Model init done!')
## ======== model init ========##
def __call__(self, image):
# layout检测 + 公式检测 def mfr_model_init(weight_dir, cfg_path, device='cpu'):
doc_layout_result = [] args = argparse.Namespace(cfg_path=cfg_path, options=None)
latex_filling_list = [] cfg = Config(args)
mf_image_list = [] cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
cfg.config.model.model_config.model_name = weight_dir
cfg.config.model.tokenizer_config.path = weight_dir
task = tasks.setup_task(cfg)
model = task.build_model(cfg)
model = model.to(device)
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
return model, vis_processor
img_H, img_W = image.shape[0], image.shape[1]
layout_res = layout_model(image, ignore_catids=[])
# 公式检测
mfd_res = mfd_model.predict(image, imgsz=img_size, conf=conf_thres, iou=iou_thres, verbose=True)[0]
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': 13 + int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 2),
'latex': '',
}
layout_res['layout_dets'].append(new_item)
latex_filling_list.append(new_item)
bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
mf_image_list.append(bbox_img)
layout_res['page_info'] = dict( class CustomPEKModel:
page_no=idx, def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
height=img_H, """
width=img_W ======== model init ========
"""
# 获取当前文件(即 pdf_extract_kit.py)的绝对路径
current_file_path = os.path.abspath(__file__)
# 获取当前文件所在的目录(model)
current_dir = os.path.dirname(current_file_path)
# 上一级目录(magic_pdf)
root_dir = os.path.dirname(current_dir)
# model_config目录
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
# 构建 model_configs.yaml 文件的完整路径
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, "r") as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader)
# 初始化解析配置
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
self.apply_ocr = ocr
logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
self.apply_layout, self.apply_formula, self.apply_ocr
) )
doc_layout_result.append(layout_res) )
assert self.apply_layout, "DocAnalysis must contain layout model."
# 初始化解析方案
self.device = self.configs["config"]["device"]
logger.info("using device: {}".format(self.device))
# 初始化layout模型
self.layout_model = layout_model_init(
os.path.join(root_dir, self.configs['weights']['layout']),
os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")
)
# 初始化公式识别
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = YOLO(model=str(os.path.join(root_dir, self.configs["weights"]["mfd"])))
# 初始化公式解析模型
mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
self.mfr_model, mfr_vis_processors = mfr_model_init(
os.path.join(root_dir, self.configs["weights"]["mfr"]), mfr_config_path,
device=self.device)
self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
# 初始化ocr
if self.apply_ocr:
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
# 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。 logger.info('DocAnalysis init done!')
a = time.time()
dataset = MathDataset(mf_image_list, transform=mfr_transform)
dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
mfr_res = []
gpu_total_cost = 0
for imgs in dataloader:
imgs = imgs.to(device)
gpu_start = time.time()
output = mfr_model.generate({'image': imgs})
gpu_cost = time.time() - gpu_start
gpu_total_cost += gpu_cost
print(f"gpu_cost: {gpu_cost}")
mfr_res.extend(output['pred_str'])
print(f"gpu_total_cost: {gpu_total_cost}")
for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
b = time.time()
print("formula nums:", len(mf_image_list), "mfr time:", round(b - a, 2))
# ocr识别
for idx, image in enumerate(img_list):
pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
single_page_res = doc_layout_result[idx]['layout_dets']
single_page_mfdetrec_res = []
for res in single_page_res:
if int(res['category_id']) in [13, 14]:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
single_page_mfdetrec_res.append({
"bbox": [xmin, ymin, xmax, ymax],
})
for res in single_page_res:
if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
crop_box = [xmin, ymin, xmax, ymax]
cropped_img = Image.new('RGB', pil_img.size, 'white')
cropped_img.paste(pil_img.crop(crop_box), crop_box)
cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
ocr_res = ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
doc_layout_result[idx]['layout_dets'].append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
output_dir = args.output def __call__(self, image):
os.makedirs(output_dir, exist_ok=True) pass
basename = os.path.basename(single_pdf)[0:-4]
with open(os.path.join(output_dir, f'{basename}.json'), 'w') as f:
json.dump(doc_layout_result, f)
\ No newline at end of file
# --------------------------------------------------------------------------------
# VIT: Multi-Path Vision Transformer for Dense Prediction
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
# All Rights Reserved.
# Written by Youngwan Lee
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# CoaT: https://github.com/mlpc-ucsd/CoaT
# --------------------------------------------------------------------------------
import torch
from detectron2.layers import (
ShapeSpec,
)
from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
from .deit import deit_base_patch16, mae_base_patch16
from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
from transformers import AutoConfig
__all__ = [
"build_vit_fpn_backbone",
]
class VIT_Backbone(Backbone):
"""
Implement VIT backbone.
"""
def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=None, image_only=False, cfg=None):
super().__init__()
self._out_features = out_features
if 'base' in name:
self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
else:
self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
if name == 'beit_base_patch16':
model_func = beit_base_patch16
elif name == 'dit_base_patch16':
model_func = dit_base_patch16
elif name == "deit_base_patch16":
model_func = deit_base_patch16
elif name == "mae_base_patch16":
model_func = mae_base_patch16
elif name == "dit_large_patch16":
model_func = dit_large_patch16
elif name == "beit_large_patch16":
model_func = beit_large_patch16
if 'beit' in name or 'dit' in name:
if pos_type == "abs":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_abs_pos_emb=True,
**model_kwargs)
elif pos_type == "shared_rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_shared_rel_pos_bias=True,
**model_kwargs)
elif pos_type == "rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_rel_pos_bias=True,
**model_kwargs)
else:
raise ValueError()
elif "layoutlmv3" in name:
config = AutoConfig.from_pretrained(config_path)
# disable relative bias as DiT
config.has_spatial_attention_bias = False
config.has_relative_attention_bias = False
self.backbone = LayoutLMv3Model(config, detection=True,
out_features=out_features, image_only=image_only)
else:
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
**model_kwargs)
self.name = name
def forward(self, x):
"""
Args:
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
if "layoutlmv3" in self.name:
return self.backbone.forward(
input_ids=x["input_ids"] if "input_ids" in x else None,
bbox=x["bbox"] if "bbox" in x else None,
images=x["images"] if "images" in x else None,
attention_mask=x["attention_mask"] if "attention_mask" in x else None,
# output_hidden_states=True,
)
assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
return self.backbone.forward_features(x)
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def build_VIT_backbone(cfg):
"""
Create a VIT instance from config.
Args:
cfg: a detectron2 CfgNode
Returns:
A VIT backbone instance.
"""
# fmt: off
name = cfg.MODEL.VIT.NAME
out_features = cfg.MODEL.VIT.OUT_FEATURES
drop_path = cfg.MODEL.VIT.DROP_PATH
img_size = cfg.MODEL.VIT.IMG_SIZE
pos_type = cfg.MODEL.VIT.POS_TYPE
model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
if 'layoutlmv3' in name:
if cfg.MODEL.CONFIG_PATH != '':
config_path = cfg.MODEL.CONFIG_PATH
else:
config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '') # layoutlmv3 pre-trained models
config_path = config_path.replace('model_final.pth', '') # detection fine-tuned models
else:
config_path = None
return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
Create a VIT w/ FPN backbone.
Args:
cfg: a detectron2 CfgNode
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
bottom_up = build_VIT_backbone(cfg)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelMaxPool(),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
This diff is collapsed.
"""
Mostly copy-paste from DINO and timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import warnings
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import trunc_normal_, drop_path, to_2tuple
from functools import partial
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w, self.num_patches_h = self.window_size
self.num_patches = self.window_size[0] * self.window_size[1]
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(
1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class ViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
model_name='vit_base_patch16_224',
img_size=384,
patch_size=16,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
num_classes=19,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.1,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
norm_cfg=None,
pos_embed_interp=False,
random_init=False,
align_corners=False,
use_checkpoint=False,
num_extra_tokens=1,
out_features=None,
**kwargs,
):
super(ViT, self).__init__()
self.model_name = model_name
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.depth = depth
self.num_heads = num_heads
self.num_classes = num_classes
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.hybrid_backbone = hybrid_backbone
self.norm_layer = norm_layer
self.norm_cfg = norm_cfg
self.pos_embed_interp = pos_embed_interp
self.random_init = random_init
self.align_corners = align_corners
self.use_checkpoint = use_checkpoint
self.num_extra_tokens = num_extra_tokens
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
# self.num_stages = self.depth
# self.out_indices = tuple(range(self.num_stages))
if self.hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
self.num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
if self.num_extra_tokens == 2:
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + self.num_extra_tokens, self.embed_dim))
self.pos_drop = nn.Dropout(p=self.drop_rate)
# self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
self.depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
for i in range(self.depth)])
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
# self.repr = nn.Linear(embed_dim, representation_size)
# self.repr_act = nn.Tanh()
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
if self.num_extra_tokens==2:
trunc_normal_(self.dist_token, std=0.2)
self.apply(self._init_weights)
# self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
logger = get_root_logger()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _conv_filter(self, state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
def to_2D(self, x):
n, hw, c = x.shape
h = w = int(math.sqrt(hw))
x = x.transpose(1, 2).reshape(n, c, h, w)
return x
def to_1D(self, x):
n, c, h, w = x.shape
x = x.reshape(n, c, -1).transpose(1, 2)
return x
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - self.num_extra_tokens
N = self.pos_embed.shape[1] - self.num_extra_tokens
if npatch == N and w == h:
return self.pos_embed
class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size[0]
h0 = h // self.patch_embed.patch_size[1]
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
def prepare_tokens(self, x, mask=None):
B, nc, w, h = x.shape
# patch linear embedding
x = self.patch_embed(x)
# mask image modeling
if mask is not None:
x = self.mask_model(x, mask)
x = x.flatten(2).transpose(1, 2)
# add the [CLS] token to the embed patch tokens
all_tokens = [self.cls_token.expand(B, -1, -1)]
if self.num_extra_tokens == 2:
dist_tokens = self.dist_token.expand(B, -1, -1)
all_tokens.append(dist_tokens)
all_tokens.append(x)
x = torch.cat(all_tokens, dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward_features(self, x):
# print(f"==========shape of x is {x.shape}==========")
B, _, H, W = x.shape
Hp, Wp = H // self.patch_size, W // self.patch_size
x = self.prepare_tokens(x)
features = []
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if i in self.out_indices:
xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def deit_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=2,
**kwargs)
model.default_cfg = _cfg()
return model
def mae_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=1,
**kwargs)
model.default_cfg = _cfg()
return model
\ No newline at end of file
from .models import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
# flake8: noqa
from .data_collator import DataCollatorForKeyValueExtraction
'''
Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
'''
import json
import os
from pathlib import Path
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{park2019cord,
title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk}
booktitle={Document Intelligence Workshop at Neural Information Processing Systems}
year={2019}
}
"""
_DESCRIPTION = """\
https://github.com/clovaai/cord/
"""
def quad_to_box(quad):
# test 87 is wrongly annotated
box = (
max(0, quad["x1"]),
max(0, quad["y1"]),
quad["x3"],
quad["y3"]
)
if box[3] < box[1]:
bbox = list(box)
tmp = bbox[3]
bbox[3] = bbox[1]
bbox[1] = tmp
box = tuple(bbox)
if box[2] < box[0]:
bbox = list(box)
tmp = bbox[2]
bbox[2] = bbox[0]
bbox[0] = tmp
box = tuple(bbox)
return box
def _get_drive_url(url):
base_url = 'https://drive.google.com/uc?id='
split_url = url.split('/')
return base_url + split_url[5]
_URLS = [
_get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
_get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
# If you failed to download the dataset through the automatic downloader,
# you can download it manually and modify the code to get the local dataset.
# Or you can use the following links. Please follow the original LICENSE of CORD for usage.
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
]
class CordConfig(datasets.BuilderConfig):
"""BuilderConfig for CORD"""
def __init__(self, **kwargs):
"""BuilderConfig for CORD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(CordConfig, self).__init__(**kwargs)
class Cord(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"words": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
citation=_CITATION,
homepage="https://github.com/clovaai/cord/",
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
"""Uses local files located with data_dir"""
downloaded_file = dl_manager.download_and_extract(_URLS)
# move files from the second URL together with files from the first one.
dest = Path(downloaded_file[0])/"CORD"
for split in ["train", "dev", "test"]:
for file_type in ["image", "json"]:
if split == "test" and file_type == "json":
continue
files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
for f in files:
os.rename(f, dest/split/file_type/f.name)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "json")
img_dir = os.path.join(filepath, "image")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
words = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["valid_line"]:
cur_line_bboxes = []
line_words, label = item["words"], item["category"]
line_words = [w for w in line_words if w["text"].strip() != ""]
if len(line_words) == 0:
continue
if label == "other":
for w in line_words:
words.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
else:
words.append(line_words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
for w in line_words[1:]:
words.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
# by default: --segment_level_layout 1
# if do not want to use segment_level_layout, comment the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
bboxes.extend(cur_line_bboxes)
# yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.data.data_collator import (
DataCollatorMixin,
_torch_collate_batch,
)
from transformers.file_utils import PaddingStrategy
from typing import NewType
InputDataClass = NewType("InputDataClass", Any)
def pre_calc_rel_mat(segment_ids):
valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),
device=segment_ids.device, dtype=torch.bool)
for i in range(segment_ids.shape[0]):
for j in range(segment_ids.shape[1]):
valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]
return valid_span
@dataclass
class DataCollatorForKeyValueExtraction(DataCollatorMixin):
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
images = None
if "images" in features[0]:
images = torch.stack([torch.tensor(d.pop("images")) for d in features])
IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)
if images is not None:
batch["images"] = images
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
for k, v in batch.items()}
visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)
batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)
if labels is None:
return batch
has_bbox_input = "bbox" in features[0]
has_position_input = "position_ids" in features[0]
padding_idx=self.tokenizer.pad_token_id
sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
if has_bbox_input:
batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
for position_id in batch["position_ids"]]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
if has_bbox_input:
batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id))
+ position_id for position_id in batch["position_ids"]]
if 'segment_ids' in batch:
assert 'position_ids' in batch
for i in range(len(batch['segment_ids'])):
batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [
batch['segment_ids'][i][-1] + 2] * IMAGE_LEN
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
if 'segment_ids' in batch:
valid_span = pre_calc_rel_mat(
segment_ids=batch['segment_ids']
)
batch['valid_span'] = valid_span
del batch['segment_ids']
if images is not None:
visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100
batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)
return batch
# coding=utf-8
'''
Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
'''
import json
import os
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""
_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""
class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""
def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)
class Funsd(datasets.GeneratorBasedBuilder):
"""Conll2003 dataset."""
BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"tokens": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
tokens = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
cur_line_bboxes = []
words, label = item["words"], item["label"]
words = [w for w in words if w["text"].strip() != ""]
if len(words) == 0:
continue
if label == "other":
for w in words:
tokens.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(w["box"], size))
else:
tokens.append(words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
for w in words[1:]:
tokens.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(w["box"], size))
# by default: --segment_level_layout 1
# if do not want to use segment_level_layout, comment the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
# box = normalize_bbox(item["box"], size)
# cur_line_bboxes = [box for _ in range(len(words))]
bboxes.extend(cur_line_bboxes)
yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
\ No newline at end of file
import torchvision.transforms.functional as F
import warnings
import math
import random
import numpy as np
from PIL import Image
import torch
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import ResizeTransform, TransformList
def normalize_bbox(bbox, size):
return [
int(1000 * bbox[0] / size[0]),
int(1000 * bbox[1] / size[1]),
int(1000 * bbox[2] / size[0]),
int(1000 * bbox[3] / size[1]),
]
def load_image(image_path):
image = read_image(image_path, format="BGR")
h = image.shape[0]
w = image.shape[1]
img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable
return image, (w, h)
def crop(image, i, j, h, w, boxes=None):
cropped_image = F.crop(image, i, j, h, w)
if boxes is not None:
# Currently we cannot use this case since when some boxes is out of the cropped image,
# it may be better to drop out these boxes along with their text input (instead of min or clamp)
# which haven't been implemented here
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
boxes = cropped_boxes.reshape(-1, 4)
return cropped_image, boxes
def resize(image, size, interpolation, boxes=None):
# It seems that we do not need to resize boxes here, since the boxes will be resized to 1000x1000 finally,
# which is compatible with a square image size of 224x224
rescaled_image = F.resize(image, size, interpolation)
if boxes is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
# boxes = boxes.copy()
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
return rescaled_image, scaled_boxes
def clamp(num, min_value, max_value):
return max(min(num, max_value), min_value)
def get_bb(bb, page_size):
bbs = [float(j) for j in bb]
xs, ys = [], []
for i, b in enumerate(bbs):
if i % 2 == 0:
xs.append(b)
else:
ys.append(b)
(width, height) = page_size
return_bb = [
clamp(min(xs), 0, width - 1),
clamp(min(ys), 0, height - 1),
clamp(max(xs), 0, width - 1),
clamp(max(ys), 0, height - 1),
]
return_bb = [
int(1000 * return_bb[0] / width),
int(1000 * return_bb[1] / height),
int(1000 * return_bb[2] / width),
int(1000 * return_bb[3] / height),
]
return return_bb
class ToNumpy:
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
class ToTensor:
def __init__(self, dtype=torch.float32):
self.dtype = dtype
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return torch.from_numpy(np_img).to(dtype=self.dtype)
_pil_interpolation_to_str = {
F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',
F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',
F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',
F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',
F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING',
F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',
}
def _pil_interp(method):
if method == 'bicubic':
return F.InterpolationMode.BICUBIC
elif method == 'lanczos':
return F.InterpolationMode.LANCZOS
elif method == 'hamming':
return F.InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return F.InterpolationMode.BILINEAR
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, augmentation=False, box=None):
for t in self.transforms:
img = t(img, augmentation, box)
return img
class RandomResizedCropAndInterpolationWithTwoPic:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
is finally resized to given size.
This is popularly used to train the Inception networks.
Args:
size: expected output size of each edge
scale: range of size of the origin size cropped
ratio: range of aspect ratio of the origin aspect ratio cropped
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear', second_interpolation='lanczos'):
if isinstance(size, tuple):
self.size = size
else:
self.size = (size, size)
if second_size is not None:
if isinstance(second_size, tuple):
self.second_size = second_size
else:
self.second_size = (second_size, second_size)
else:
self.second_size = None
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
self.interpolation = _pil_interp(interpolation)
self.second_interpolation = _pil_interp(second_interpolation)
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img, scale, ratio):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image): Image to be cropped.
scale (tuple): range of size of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
def __call__(self, img, augmentation=False, box=None):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Randomly cropped and resized image.
"""
if augmentation:
i, j, h, w = self.get_params(img, self.scale, self.ratio)
img = F.crop(img, i, j, h, w)
# img, box = crop(img, i, j, h, w, box)
img = F.resize(img, self.size, self.interpolation)
second_img = F.resize(img, self.second_size, self.second_interpolation) \
if self.second_size is not None else None
return img, second_img
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
else:
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0}'.format(interpolate_str)
if self.second_size is not None:
format_string += ', second_size={0}'.format(self.second_size)
format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
format_string += ')'
return format_string
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
import os
import json
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic
XFund_label2ids = {
"O":0,
'B-HEADER':1,
'I-HEADER':2,
'B-QUESTION':3,
'I-QUESTION':4,
'B-ANSWER':5,
'I-ANSWER':6,
}
class xfund_dataset(Dataset):
def box_norm(self, box, width, height):
def clip(min_num, num, max_num):
return min(max(num, min_num), max_num)
x0, y0, x1, y1 = box
x0 = clip(0, int((x0 / width) * 1000), 1000)
y0 = clip(0, int((y0 / height) * 1000), 1000)
x1 = clip(0, int((x1 / width) * 1000), 1000)
y1 = clip(0, int((y1 / height) * 1000), 1000)
assert x1 >= x0
assert y1 >= y0
return [x0, y0, x1, y1]
def get_segment_ids(self, bboxs):
segment_ids = []
for i in range(len(bboxs)):
if i == 0:
segment_ids.append(0)
else:
if bboxs[i - 1] == bboxs[i]:
segment_ids.append(segment_ids[-1])
else:
segment_ids.append(segment_ids[-1] + 1)
return segment_ids
def get_position_ids(self, segment_ids):
position_ids = []
for i in range(len(segment_ids)):
if i == 0:
position_ids.append(2)
else:
if segment_ids[i] == segment_ids[i - 1]:
position_ids.append(position_ids[-1] + 1)
else:
position_ids.append(2)
return position_ids
def load_data(
self,
data_file,
):
# re-org data format
total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
for i in range(len(data_file['documents'])):
width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][
'height']
cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []
for j in range(len(data_file['documents'][i]['document'])):
cur_item = data_file['documents'][i]['document'][j]
cur_doc_lines.append(cur_item['text'])
cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
cur_doc_ner_tags.append(cur_item['label'])
total_data['id'] += [len(total_data['id'])]
total_data['lines'] += [cur_doc_lines]
total_data['bboxes'] += [cur_doc_bboxes]
total_data['ner_tags'] += [cur_doc_ner_tags]
total_data['image_path'] += [data_file['documents'][i]['img']['fname']]
# tokenize text and get bbox/label
total_input_ids, total_bboxs, total_label_ids = [], [], []
for i in range(len(total_data['lines'])):
cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
for j in range(len(total_data['lines'][i])):
cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']
if len(cur_input_ids) == 0: continue
cur_label = total_data['ner_tags'][i][j].upper()
if cur_label == 'OTHER':
cur_labels = ["O"] * len(cur_input_ids)
for k in range(len(cur_labels)):
cur_labels[k] = self.label2ids[cur_labels[k]]
else:
cur_labels = [cur_label] * len(cur_input_ids)
cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
for k in range(1, len(cur_labels)):
cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)
cur_doc_input_ids += cur_input_ids
cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
cur_doc_labels += cur_labels
assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
assert len(cur_doc_input_ids) > 0
total_input_ids.append(cur_doc_input_ids)
total_bboxs.append(cur_doc_bboxs)
total_label_ids.append(cur_doc_labels)
assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)
# split text to several slices because of over-length
input_ids, bboxs, labels = [], [], []
segment_ids, position_ids = [], []
image_path = []
for i in range(len(total_input_ids)):
start = 0
cur_iter = 0
while start < len(total_input_ids[i]):
end = min(start + 510, len(total_input_ids[i]))
input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])
bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])
labels.append([-100] + total_label_ids[i][start: end] + [-100])
cur_segment_ids = self.get_segment_ids(bboxs[-1])
cur_position_ids = self.get_position_ids(cur_segment_ids)
segment_ids.append(cur_segment_ids)
position_ids.append(cur_position_ids)
image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
start = end
cur_iter += 1
assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
assert len(segment_ids) == len(image_path)
res = {
'input_ids': input_ids,
'bbox': bboxs,
'labels': labels,
'segment_ids': segment_ids,
'position_ids': position_ids,
'image_path': image_path,
}
return res
def __init__(
self,
args,
tokenizer,
mode
):
self.args = args
self.mode = mode
self.cur_la = args.language
self.tokenizer = tokenizer
self.label2ids = XFund_label2ids
self.common_transform = Compose([
RandomResizedCropAndInterpolationWithTwoPic(
size=args.input_size, interpolation=args.train_interpolation,
),
])
self.patch_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor((0.5, 0.5, 0.5)),
std=torch.tensor((0.5, 0.5, 0.5)))
])
data_file = json.load(
open(os.path.join(args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val')),
'r'))
self.feature = self.load_data(data_file)
def __len__(self):
return len(self.feature['input_ids'])
def __getitem__(self, index):
input_ids = self.feature["input_ids"][index]
# attention_mask = self.feature["attention_mask"][index]
attention_mask = [1] * len(input_ids)
labels = self.feature["labels"][index]
bbox = self.feature["bbox"][index]
segment_ids = self.feature['segment_ids'][index]
position_ids = self.feature['position_ids'][index]
img = pil_loader(self.feature['image_path'][index])
for_patches, _ = self.common_transform(img, augmentation=False)
patch = self.patch_transform(for_patches)
assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
res = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"bbox": bbox,
"segment_ids": segment_ids,
"position_ids": position_ids,
"images": patch,
}
return res
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
\ No newline at end of file
from .layoutlmv3 import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \
AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter
from .configuration_layoutlmv3 import LayoutLMv3Config
from .modeling_layoutlmv3 import (
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Model,
)
from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
#AutoConfig.register("layoutlmv3", LayoutLMv3Config)
#AutoModel.register(LayoutLMv3Config, LayoutLMv3Model)
#AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification)
#AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering)
#AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification)
#AutoTokenizer.register(
# LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast
#)
SLOW_TO_FAST_CONVERTERS.update({"LayoutLMv3Tokenizer": RobertaConverter})
# coding=utf-8
from transformers.models.bert.configuration_bert import BertConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
"layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/resolve/main/config.json",
# See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3
}
class LayoutLMv3Config(BertConfig):
model_type = "layoutlmv3"
def __init__(
self,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
max_2d_position_embeddings=1024,
coordinate_size=None,
shape_size=None,
has_relative_attention_bias=False,
rel_pos_bins=32,
max_rel_pos=128,
has_spatial_attention_bias=False,
rel_2d_pos_bins=64,
max_rel_2d_pos=256,
visual_embed=True,
mim=False,
wpa_task=False,
discrete_vae_weight_path='',
discrete_vae_type='dall-e',
input_size=224,
second_input_size=112,
device='cuda',
**kwargs
):
"""Constructs RobertaConfig."""
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.max_2d_position_embeddings = max_2d_position_embeddings
self.coordinate_size = coordinate_size
self.shape_size = shape_size
self.has_relative_attention_bias = has_relative_attention_bias
self.rel_pos_bins = rel_pos_bins
self.max_rel_pos = max_rel_pos
self.has_spatial_attention_bias = has_spatial_attention_bias
self.rel_2d_pos_bins = rel_2d_pos_bins
self.max_rel_2d_pos = max_rel_2d_pos
self.visual_embed = visual_embed
self.mim = mim
self.wpa_task = wpa_task
self.discrete_vae_weight_path = discrete_vae_weight_path
self.discrete_vae_type = discrete_vae_type
self.input_size = input_size
self.second_input_size = second_input_size
self.device = device
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment