Commit 4a823359 authored by quyuan

Merge branch 'master' of https://github.com/opendatalab/MinerU

parents 611e2f59 b6df9b18
# --------------------------------------------------------------------------------
# VIT: Multi-Path Vision Transformer for Dense Prediction
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
# All Rights Reserved.
# Written by Youngwan Lee
# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# CoaT: https://github.com/mlpc-ucsd/CoaT
# --------------------------------------------------------------------------------
import torch
from detectron2.layers import (
ShapeSpec,
)
from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
from .deit import deit_base_patch16, mae_base_patch16
from .layoutlmft.models.layoutlmv3 import LayoutLMv3Model
from transformers import AutoConfig
__all__ = [
"build_vit_fpn_backbone",
]
class VIT_Backbone(Backbone):
"""
Implement VIT backbone.
"""
def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=None, image_only=False, cfg=None):
super().__init__()
self._out_features = out_features
if 'base' in name:
self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
else:
self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
if name == 'beit_base_patch16':
model_func = beit_base_patch16
elif name == 'dit_base_patch16':
model_func = dit_base_patch16
elif name == "deit_base_patch16":
model_func = deit_base_patch16
elif name == "mae_base_patch16":
model_func = mae_base_patch16
elif name == "dit_large_patch16":
model_func = dit_large_patch16
elif name == "beit_large_patch16":
model_func = beit_large_patch16
if 'beit' in name or 'dit' in name:
if pos_type == "abs":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_abs_pos_emb=True,
**model_kwargs)
elif pos_type == "shared_rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_shared_rel_pos_bias=True,
**model_kwargs)
elif pos_type == "rel":
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
use_rel_pos_bias=True,
**model_kwargs)
            else:
                raise ValueError(f"Unsupported pos_type: {pos_type}")
elif "layoutlmv3" in name:
config = AutoConfig.from_pretrained(config_path)
# disable relative bias as DiT
config.has_spatial_attention_bias = False
config.has_relative_attention_bias = False
self.backbone = LayoutLMv3Model(config, detection=True,
out_features=out_features, image_only=image_only)
else:
self.backbone = model_func(img_size=img_size,
out_features=out_features,
drop_path_rate=drop_path,
**model_kwargs)
self.name = name
def forward(self, x):
"""
Args:
            x: Tensor of shape (N, C, H, W). H and W must be multiples of ``self.size_divisibility``.
Returns:
dict[str->Tensor]: names and the corresponding features
"""
if "layoutlmv3" in self.name:
return self.backbone.forward(
input_ids=x["input_ids"] if "input_ids" in x else None,
bbox=x["bbox"] if "bbox" in x else None,
images=x["images"] if "images" in x else None,
attention_mask=x["attention_mask"] if "attention_mask" in x else None,
# output_hidden_states=True,
)
assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
return self.backbone.forward_features(x)
def output_shape(self):
return {
name: ShapeSpec(
channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
)
for name in self._out_features
}
def build_VIT_backbone(cfg):
"""
Create a VIT instance from config.
Args:
cfg: a detectron2 CfgNode
Returns:
A VIT backbone instance.
"""
# fmt: off
name = cfg.MODEL.VIT.NAME
out_features = cfg.MODEL.VIT.OUT_FEATURES
drop_path = cfg.MODEL.VIT.DROP_PATH
img_size = cfg.MODEL.VIT.IMG_SIZE
pos_type = cfg.MODEL.VIT.POS_TYPE
model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
if 'layoutlmv3' in name:
if cfg.MODEL.CONFIG_PATH != '':
config_path = cfg.MODEL.CONFIG_PATH
else:
config_path = cfg.MODEL.WEIGHTS.replace('pytorch_model.bin', '') # layoutlmv3 pre-trained models
config_path = config_path.replace('model_final.pth', '') # detection fine-tuned models
else:
config_path = None
return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs,
config_path=config_path, image_only=cfg.MODEL.IMAGE_ONLY, cfg=cfg)
@BACKBONE_REGISTRY.register()
def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
"""
Create a VIT w/ FPN backbone.
Args:
cfg: a detectron2 CfgNode
Returns:
backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
"""
bottom_up = build_VIT_backbone(cfg)
in_features = cfg.MODEL.FPN.IN_FEATURES
out_channels = cfg.MODEL.FPN.OUT_CHANNELS
backbone = FPN(
bottom_up=bottom_up,
in_features=in_features,
out_channels=out_channels,
norm=cfg.MODEL.FPN.NORM,
top_block=LastLevelMaxPool(),
fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
)
return backbone
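
# Usage sketch (illustrative, not part of the original commit): once this module is
# imported, detectron2's BACKBONE_REGISTRY can dispatch to build_vit_fpn_backbone via
# the standard build_backbone() entry point. The cfg is assumed to have been prepared
# with add_vit_config() (defined elsewhere in this repo), merged with a LayoutLMv3
# detection YAML, and to set MODEL.BACKBONE.NAME = "build_vit_fpn_backbone".
def _example_build_backbone(cfg):
    from detectron2.modeling import build_backbone

    backbone = build_backbone(cfg)      # dispatches to build_vit_fpn_backbone(cfg, ...)
    return backbone.output_shape()      # FPN levels with their strides and channels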
"""
Mostly copy-paste from DINO and timm library:
https://github.com/facebookresearch/dino
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import warnings
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import trunc_normal_, drop_path, to_2tuple
from functools import partial
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches_w, self.num_patches_h = self.window_size
self.num_patches = self.window_size[0] * self.window_size[1]
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(
1, in_chans, img_size[0], img_size[1]))[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
feature_dim = self.backbone.feature_info.channels()[-1]
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Linear(feature_dim, embed_dim)
def forward(self, x):
x = self.backbone(x)[-1]
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class ViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
model_name='vit_base_patch16_224',
img_size=384,
patch_size=16,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
num_classes=19,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.1,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
norm_cfg=None,
pos_embed_interp=False,
random_init=False,
align_corners=False,
use_checkpoint=False,
num_extra_tokens=1,
out_features=None,
**kwargs,
):
super(ViT, self).__init__()
self.model_name = model_name
self.img_size = img_size
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.depth = depth
self.num_heads = num_heads
self.num_classes = num_classes
self.mlp_ratio = mlp_ratio
self.qkv_bias = qkv_bias
self.qk_scale = qk_scale
self.drop_rate = drop_rate
self.attn_drop_rate = attn_drop_rate
self.drop_path_rate = drop_path_rate
self.hybrid_backbone = hybrid_backbone
self.norm_layer = norm_layer
self.norm_cfg = norm_cfg
self.pos_embed_interp = pos_embed_interp
self.random_init = random_init
self.align_corners = align_corners
self.use_checkpoint = use_checkpoint
self.num_extra_tokens = num_extra_tokens
self.out_features = out_features
self.out_indices = [int(name[5:]) for name in out_features]
# self.num_stages = self.depth
# self.out_indices = tuple(range(self.num_stages))
if self.hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
self.num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
if self.num_extra_tokens == 2:
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(
1, self.num_patches + self.num_extra_tokens, self.embed_dim))
self.pos_drop = nn.Dropout(p=self.drop_rate)
# self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
self.depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
for i in range(self.depth)])
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
# self.repr = nn.Linear(embed_dim, representation_size)
# self.repr_act = nn.Tanh()
if patch_size == 16:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
nn.SyncBatchNorm(embed_dim),
nn.GELU(),
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
elif patch_size == 8:
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
)
self.fpn2 = nn.Identity()
self.fpn3 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
        if self.num_extra_tokens == 2:
            trunc_normal_(self.dist_token, std=.02)
self.apply(self._init_weights)
# self.fix_init_weight()
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
'''
def init_weights(self):
logger = get_root_logger()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
if self.init_cfg is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
'''
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _conv_filter(self, state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
def to_2D(self, x):
n, hw, c = x.shape
h = w = int(math.sqrt(hw))
x = x.transpose(1, 2).reshape(n, c, h, w)
return x
def to_1D(self, x):
n, c, h, w = x.shape
x = x.reshape(n, c, -1).transpose(1, 2)
return x
def interpolate_pos_encoding(self, x, w, h):
npatch = x.shape[1] - self.num_extra_tokens
N = self.pos_embed.shape[1] - self.num_extra_tokens
if npatch == N and w == h:
return self.pos_embed
class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
dim = x.shape[-1]
w0 = w // self.patch_embed.patch_size[0]
h0 = h // self.patch_embed.patch_size[1]
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
def prepare_tokens(self, x, mask=None):
B, nc, w, h = x.shape
# patch linear embedding
x = self.patch_embed(x)
# mask image modeling
if mask is not None:
x = self.mask_model(x, mask)
x = x.flatten(2).transpose(1, 2)
# add the [CLS] token to the embed patch tokens
all_tokens = [self.cls_token.expand(B, -1, -1)]
if self.num_extra_tokens == 2:
dist_tokens = self.dist_token.expand(B, -1, -1)
all_tokens.append(dist_tokens)
all_tokens.append(x)
x = torch.cat(all_tokens, dim=1)
# add positional encoding to each token
x = x + self.interpolate_pos_encoding(x, w, h)
return self.pos_drop(x)
def forward_features(self, x):
# print(f"==========shape of x is {x.shape}==========")
B, _, H, W = x.shape
Hp, Wp = H // self.patch_size, W // self.patch_size
x = self.prepare_tokens(x)
features = []
for i, blk in enumerate(self.blocks):
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if i in self.out_indices:
xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
features.append(xp.contiguous())
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
feat_out = {}
for name, value in zip(self.out_features, features):
feat_out[name] = value
return feat_out
def forward(self, x):
x = self.forward_features(x)
return x
def deit_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=2,
**kwargs)
model.default_cfg = _cfg()
return model
def mae_base_patch16(pretrained=False, **kwargs):
model = ViT(
patch_size=16,
drop_rate=0.,
embed_dim=768,
depth=12,
num_heads=12,
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
use_checkpoint=True,
num_extra_tokens=1,
**kwargs)
model.default_cfg = _cfg()
return model
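
# Minimal sketch (illustrative only, not part of the original commit): instantiating the
# ViT above and checking that forward_features() returns one feature map per requested
# block, rescaled by fpn1..fpn4 to strides 4/8/16/32. fpn1 uses nn.SyncBatchNorm, which
# needs a CUDA/distributed setup, so this single-process sketch swaps in BatchNorm2d.
def _example_vit_pyramid():
    model = deit_base_patch16(
        img_size=(224, 224),
        out_features=["layer3", "layer5", "layer7", "layer11"],
        drop_path_rate=0.1,
    )
    model.fpn1[1] = nn.BatchNorm2d(model.embed_dim)   # CPU-friendly stand-in for SyncBatchNorm
    model.eval()
    with torch.no_grad():
        feats = model.forward_features(torch.zeros(1, 3, 224, 224))
    shapes = {name: tuple(f.shape) for name, f in feats.items()}
    return shapes   # e.g. "layer3": (1, 768, 56, 56) ... "layer11": (1, 768, 7, 7)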
from .models import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
# flake8: noqa
from .data_collator import DataCollatorForKeyValueExtraction
'''
Reference: https://huggingface.co/datasets/pierresi/cord/blob/main/cord.py
'''
import json
import os
from pathlib import Path
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{park2019cord,
  title={CORD: A Consolidated Receipt Dataset for Post-OCR Parsing},
  author={Park, Seunghyun and Shin, Seung and Lee, Bado and Lee, Junyeop and Surh, Jaeheung and Seo, Minjoon and Lee, Hwalsuk},
  booktitle={Document Intelligence Workshop at Neural Information Processing Systems},
  year={2019}
}
"""
_DESCRIPTION = """\
https://github.com/clovaai/cord/
"""
def quad_to_box(quad):
# test 87 is wrongly annotated
box = (
max(0, quad["x1"]),
max(0, quad["y1"]),
quad["x3"],
quad["y3"]
)
if box[3] < box[1]:
bbox = list(box)
tmp = bbox[3]
bbox[3] = bbox[1]
bbox[1] = tmp
box = tuple(bbox)
if box[2] < box[0]:
bbox = list(box)
tmp = bbox[2]
bbox[2] = bbox[0]
bbox[0] = tmp
box = tuple(bbox)
return box
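# Illustrative example (added for clarity): a CORD quad with flipped y coordinates,
#   {"x1": 10, "y1": 40, "x2": 90, "y2": 40, "x3": 90, "y3": 20, "x4": 10, "y4": 20}
# becomes the axis-aligned box (10, 20, 90, 40) after the swaps above.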
def _get_drive_url(url):
base_url = 'https://drive.google.com/uc?id='
split_url = url.split('/')
return base_url + split_url[5]
_URLS = [
_get_drive_url("https://drive.google.com/file/d/1MqhTbcj-AHXOqYoeoh12aRUwIprzTJYI/"),
_get_drive_url("https://drive.google.com/file/d/1wYdp5nC9LnHQZ2FcmOoC0eClyWvcuARU/")
# If you failed to download the dataset through the automatic downloader,
# you can download it manually and modify the code to get the local dataset.
# Or you can use the following links. Please follow the original LICENSE of CORD for usage.
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-001.zip",
# "https://layoutlm.blob.core.windows.net/cord/CORD-1k-002.zip"
]
class CordConfig(datasets.BuilderConfig):
"""BuilderConfig for CORD"""
def __init__(self, **kwargs):
"""BuilderConfig for CORD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(CordConfig, self).__init__(**kwargs)
class Cord(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CordConfig(name="cord", version=datasets.Version("1.0.0"), description="CORD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"words": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O","B-MENU.NM","B-MENU.NUM","B-MENU.UNITPRICE","B-MENU.CNT","B-MENU.DISCOUNTPRICE","B-MENU.PRICE","B-MENU.ITEMSUBTOTAL","B-MENU.VATYN","B-MENU.ETC","B-MENU.SUB_NM","B-MENU.SUB_UNITPRICE","B-MENU.SUB_CNT","B-MENU.SUB_PRICE","B-MENU.SUB_ETC","B-VOID_MENU.NM","B-VOID_MENU.PRICE","B-SUB_TOTAL.SUBTOTAL_PRICE","B-SUB_TOTAL.DISCOUNT_PRICE","B-SUB_TOTAL.SERVICE_PRICE","B-SUB_TOTAL.OTHERSVC_PRICE","B-SUB_TOTAL.TAX_PRICE","B-SUB_TOTAL.ETC","B-TOTAL.TOTAL_PRICE","B-TOTAL.TOTAL_ETC","B-TOTAL.CASHPRICE","B-TOTAL.CHANGEPRICE","B-TOTAL.CREDITCARDPRICE","B-TOTAL.EMONEYPRICE","B-TOTAL.MENUTYPE_CNT","B-TOTAL.MENUQTY_CNT","I-MENU.NM","I-MENU.NUM","I-MENU.UNITPRICE","I-MENU.CNT","I-MENU.DISCOUNTPRICE","I-MENU.PRICE","I-MENU.ITEMSUBTOTAL","I-MENU.VATYN","I-MENU.ETC","I-MENU.SUB_NM","I-MENU.SUB_UNITPRICE","I-MENU.SUB_CNT","I-MENU.SUB_PRICE","I-MENU.SUB_ETC","I-VOID_MENU.NM","I-VOID_MENU.PRICE","I-SUB_TOTAL.SUBTOTAL_PRICE","I-SUB_TOTAL.DISCOUNT_PRICE","I-SUB_TOTAL.SERVICE_PRICE","I-SUB_TOTAL.OTHERSVC_PRICE","I-SUB_TOTAL.TAX_PRICE","I-SUB_TOTAL.ETC","I-TOTAL.TOTAL_PRICE","I-TOTAL.TOTAL_ETC","I-TOTAL.CASHPRICE","I-TOTAL.CHANGEPRICE","I-TOTAL.CREDITCARDPRICE","I-TOTAL.EMONEYPRICE","I-TOTAL.MENUTYPE_CNT","I-TOTAL.MENUQTY_CNT"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
citation=_CITATION,
homepage="https://github.com/clovaai/cord/",
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
"""Uses local files located with data_dir"""
downloaded_file = dl_manager.download_and_extract(_URLS)
# move files from the second URL together with files from the first one.
dest = Path(downloaded_file[0])/"CORD"
for split in ["train", "dev", "test"]:
for file_type in ["image", "json"]:
if split == "test" and file_type == "json":
continue
files = (Path(downloaded_file[1])/"CORD"/split/file_type).iterdir()
for f in files:
os.rename(f, dest/split/file_type/f.name)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": dest/"train"}
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION, gen_kwargs={"filepath": dest/"dev"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": dest/"test"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "json")
img_dir = os.path.join(filepath, "image")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
words = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["valid_line"]:
cur_line_bboxes = []
line_words, label = item["words"], item["category"]
line_words = [w for w in line_words if w["text"].strip() != ""]
if len(line_words) == 0:
continue
if label == "other":
for w in line_words:
words.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
else:
words.append(line_words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(line_words[0]["quad"]), size))
for w in line_words[1:]:
words.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(quad_to_box(w["quad"]), size))
# by default: --segment_level_layout 1
# if do not want to use segment_level_layout, comment the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
bboxes.extend(cur_line_bboxes)
# yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
yield guid, {"id": str(guid), "words": words, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizerBase
from transformers.data.data_collator import (
DataCollatorMixin,
_torch_collate_batch,
)
from transformers.file_utils import PaddingStrategy
from typing import NewType
InputDataClass = NewType("InputDataClass", Any)
def pre_calc_rel_mat(segment_ids):
valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),
device=segment_ids.device, dtype=torch.bool)
for i in range(segment_ids.shape[0]):
for j in range(segment_ids.shape[1]):
valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]
return valid_span
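
# Illustrative example (not in the original commit): valid_span[b, i, j] is True iff
# tokens i and j of example b share a segment id, marking within-segment token pairs
# for the model. For instance:
#   pre_calc_rel_mat(torch.tensor([[0, 0, 1]]))
#   -> [[[ True,  True, False],
#        [ True,  True, False],
#        [False, False,  True]]]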
@dataclass
class DataCollatorForKeyValueExtraction(DataCollatorMixin):
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
images = None
if "images" in features[0]:
images = torch.stack([torch.tensor(d.pop("images")) for d in features])
IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)
if images is not None:
batch["images"] = images
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
for k, v in batch.items()}
visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)
batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)
if labels is None:
return batch
has_bbox_input = "bbox" in features[0]
has_position_input = "position_ids" in features[0]
padding_idx=self.tokenizer.pad_token_id
sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
if has_bbox_input:
batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
for position_id in batch["position_ids"]]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
if has_bbox_input:
batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
if has_position_input:
batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id))
+ position_id for position_id in batch["position_ids"]]
if 'segment_ids' in batch:
assert 'position_ids' in batch
for i in range(len(batch['segment_ids'])):
batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [
batch['segment_ids'][i][-1] + 2] * IMAGE_LEN
batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
if 'segment_ids' in batch:
valid_span = pre_calc_rel_mat(
segment_ids=batch['segment_ids']
)
batch['valid_span'] = valid_span
del batch['segment_ids']
if images is not None:
visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100
batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)
return batch
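
# Usage sketch (illustrative; the checkpoint name is an assumption): padding a tiny
# batch of token-classification features the way a Trainer would, e.g. with
# tokenizer = AutoTokenizer.from_pretrained("roberta-base"). Each feature carries
# equally long "input_ids", "bbox" and "labels"; "images" and "position_ids" are optional.
def _example_collate(tokenizer):
    collator = DataCollatorForKeyValueExtraction(
        tokenizer=tokenizer, padding="longest", label_pad_token_id=-100
    )
    features = [
        {"input_ids": [0, 100, 2], "bbox": [[1, 1, 2, 2]] * 3, "labels": [-100, 1, -100]},
        {"input_ids": [0, 2], "bbox": [[1, 1, 2, 2]] * 2, "labels": [-100, -100]},
    ]
    batch = collator(features)
    # input_ids, attention_mask, bbox and labels are now rectangular int64 tensors
    return batch["input_ids"].shape, batch["bbox"].shape, batch["labels"].shape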
# coding=utf-8
'''
Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
'''
import json
import os
import datasets
from .image_utils import load_image, normalize_bbox
logger = datasets.logging.get_logger(__name__)
_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""
_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""
class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""
def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)
class Funsd(datasets.GeneratorBasedBuilder):
"""Conll2003 dataset."""
BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"tokens": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
"image_path": datasets.Value("string"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)
def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]
def get_line_bbox(self, bboxs):
x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
assert x1 >= x0 and y1 >= y0
bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
return bbox
def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
tokens = []
bboxes = []
ner_tags = []
file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
cur_line_bboxes = []
words, label = item["words"], item["label"]
words = [w for w in words if w["text"].strip() != ""]
if len(words) == 0:
continue
if label == "other":
for w in words:
tokens.append(w["text"])
ner_tags.append("O")
cur_line_bboxes.append(normalize_bbox(w["box"], size))
else:
tokens.append(words[0]["text"])
ner_tags.append("B-" + label.upper())
cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
for w in words[1:]:
tokens.append(w["text"])
ner_tags.append("I-" + label.upper())
cur_line_bboxes.append(normalize_bbox(w["box"], size))
# by default: --segment_level_layout 1
# if do not want to use segment_level_layout, comment the following line
cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
# box = normalize_bbox(item["box"], size)
# cur_line_bboxes = [box for _ in range(len(words))]
bboxes.extend(cur_line_bboxes)
yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
"image": image, "image_path": image_path}
import torchvision.transforms.functional as F
import warnings
import math
import random
import numpy as np
from PIL import Image
import torch
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import ResizeTransform, TransformList
def normalize_bbox(bbox, size):
return [
int(1000 * bbox[0] / size[0]),
int(1000 * bbox[1] / size[1]),
int(1000 * bbox[2] / size[0]),
int(1000 * bbox[3] / size[1]),
]
def load_image(image_path):
image = read_image(image_path, format="BGR")
h = image.shape[0]
w = image.shape[1]
img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable
return image, (w, h)
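
# Illustrative example (paths and box values are placeholders): load_image() resizes
# every page to 224x224 for the visual branch but returns the original (w, h), which
# normalize_bbox() needs to map OCR boxes from pixel space onto the 0-1000 layout grid.
def _example_prepare_inputs(image_path="page.png", word_box=(50, 100, 250, 130)):
    image, size = load_image(image_path)               # image: (3, 224, 224), size: (w, h)
    return image.shape, normalize_bbox(list(word_box), size)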
def crop(image, i, j, h, w, boxes=None):
cropped_image = F.crop(image, i, j, h, w)
if boxes is not None:
        # This branch is currently unused: when some boxes fall outside the cropped image,
        # it would be better to drop those boxes together with their text input (rather
        # than min/clamp them), which has not been implemented here.
max_size = torch.as_tensor([w, h], dtype=torch.float32)
cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
cropped_boxes = cropped_boxes.clamp(min=0)
boxes = cropped_boxes.reshape(-1, 4)
return cropped_image, boxes
def resize(image, size, interpolation, boxes=None):
    # Boxes do not need to be resized here: they are later normalized to the 0-1000 grid,
    # which is consistent with the image itself being resized to a 224x224 square.
rescaled_image = F.resize(image, size, interpolation)
if boxes is None:
return rescaled_image, None
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
ratio_width, ratio_height = ratios
# boxes = boxes.copy()
scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
return rescaled_image, scaled_boxes
def clamp(num, min_value, max_value):
return max(min(num, max_value), min_value)
def get_bb(bb, page_size):
bbs = [float(j) for j in bb]
xs, ys = [], []
for i, b in enumerate(bbs):
if i % 2 == 0:
xs.append(b)
else:
ys.append(b)
(width, height) = page_size
return_bb = [
clamp(min(xs), 0, width - 1),
clamp(min(ys), 0, height - 1),
clamp(max(xs), 0, width - 1),
clamp(max(ys), 0, height - 1),
]
return_bb = [
int(1000 * return_bb[0] / width),
int(1000 * return_bb[1] / height),
int(1000 * return_bb[2] / width),
int(1000 * return_bb[3] / height),
]
return return_bb
class ToNumpy:
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return np_img
class ToTensor:
def __init__(self, dtype=torch.float32):
self.dtype = dtype
def __call__(self, pil_img):
np_img = np.array(pil_img, dtype=np.uint8)
if np_img.ndim < 3:
np_img = np.expand_dims(np_img, axis=-1)
np_img = np.rollaxis(np_img, 2) # HWC to CHW
return torch.from_numpy(np_img).to(dtype=self.dtype)
_pil_interpolation_to_str = {
F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',
F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',
F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',
F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',
F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING',
F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',
}
def _pil_interp(method):
if method == 'bicubic':
return F.InterpolationMode.BICUBIC
elif method == 'lanczos':
return F.InterpolationMode.LANCZOS
elif method == 'hamming':
return F.InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return F.InterpolationMode.BILINEAR
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
    Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
    require ``lambda`` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, augmentation=False, box=None):
for t in self.transforms:
img = t(img, augmentation, box)
return img
class RandomResizedCropAndInterpolationWithTwoPic:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
    A crop of random size (default: 0.08 to 1.0 of the original area) and a random
    aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to the given size.
This is popularly used to train the Inception networks.
Args:
size: expected output size of each edge
scale: range of size of the origin size cropped
ratio: range of aspect ratio of the origin aspect ratio cropped
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear', second_interpolation='lanczos'):
if isinstance(size, tuple):
self.size = size
else:
self.size = (size, size)
if second_size is not None:
if isinstance(second_size, tuple):
self.second_size = second_size
else:
self.second_size = (second_size, second_size)
else:
self.second_size = None
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
self.interpolation = _pil_interp(interpolation)
self.second_interpolation = _pil_interp(second_interpolation)
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img, scale, ratio):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image): Image to be cropped.
scale (tuple): range of size of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, h, w
def __call__(self, img, augmentation=False, box=None):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Randomly cropped and resized image.
"""
if augmentation:
i, j, h, w = self.get_params(img, self.scale, self.ratio)
img = F.crop(img, i, j, h, w)
# img, box = crop(img, i, j, h, w, box)
img = F.resize(img, self.size, self.interpolation)
second_img = F.resize(img, self.second_size, self.second_interpolation) \
if self.second_size is not None else None
return img, second_img
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
else:
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0}'.format(interpolate_str)
if self.second_size is not None:
format_string += ', second_size={0}'.format(self.second_size)
format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
format_string += ')'
return format_string
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
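
# Usage sketch (illustrative; the file name is a placeholder): this Compose /
# RandomResizedCropAndInterpolationWithTwoPic pair is how document pages are prepared
# for the visual branch. With augmentation=False the transform is a plain resize, and
# second_img is None because no second_size was given.
def _example_page_transform(path="page.png"):
    transform = Compose([
        RandomResizedCropAndInterpolationWithTwoPic(size=224, interpolation="bicubic"),
    ])
    img = pil_loader(path)
    resized, second_img = transform(img, augmentation=False)
    return resized.size, second_img                     # ((224, 224), None)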
import os
import json
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic
XFund_label2ids = {
"O":0,
'B-HEADER':1,
'I-HEADER':2,
'B-QUESTION':3,
'I-QUESTION':4,
'B-ANSWER':5,
'I-ANSWER':6,
}
class xfund_dataset(Dataset):
def box_norm(self, box, width, height):
def clip(min_num, num, max_num):
return min(max(num, min_num), max_num)
x0, y0, x1, y1 = box
x0 = clip(0, int((x0 / width) * 1000), 1000)
y0 = clip(0, int((y0 / height) * 1000), 1000)
x1 = clip(0, int((x1 / width) * 1000), 1000)
y1 = clip(0, int((y1 / height) * 1000), 1000)
assert x1 >= x0
assert y1 >= y0
return [x0, y0, x1, y1]
def get_segment_ids(self, bboxs):
segment_ids = []
for i in range(len(bboxs)):
if i == 0:
segment_ids.append(0)
else:
if bboxs[i - 1] == bboxs[i]:
segment_ids.append(segment_ids[-1])
else:
segment_ids.append(segment_ids[-1] + 1)
return segment_ids
def get_position_ids(self, segment_ids):
position_ids = []
for i in range(len(segment_ids)):
if i == 0:
position_ids.append(2)
else:
if segment_ids[i] == segment_ids[i - 1]:
position_ids.append(position_ids[-1] + 1)
else:
position_ids.append(2)
return position_ids
def load_data(
self,
data_file,
):
# re-org data format
total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
for i in range(len(data_file['documents'])):
width, height = data_file['documents'][i]['img']['width'], data_file['documents'][i]['img'][
'height']
cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags, cur_doc_image_path = [], [], [], []
for j in range(len(data_file['documents'][i]['document'])):
cur_item = data_file['documents'][i]['document'][j]
cur_doc_lines.append(cur_item['text'])
cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
cur_doc_ner_tags.append(cur_item['label'])
total_data['id'] += [len(total_data['id'])]
total_data['lines'] += [cur_doc_lines]
total_data['bboxes'] += [cur_doc_bboxes]
total_data['ner_tags'] += [cur_doc_ner_tags]
total_data['image_path'] += [data_file['documents'][i]['img']['fname']]
# tokenize text and get bbox/label
total_input_ids, total_bboxs, total_label_ids = [], [], []
for i in range(len(total_data['lines'])):
cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
for j in range(len(total_data['lines'][i])):
cur_input_ids = self.tokenizer(total_data['lines'][i][j], truncation=False, add_special_tokens=False, return_attention_mask=False)['input_ids']
if len(cur_input_ids) == 0: continue
cur_label = total_data['ner_tags'][i][j].upper()
if cur_label == 'OTHER':
cur_labels = ["O"] * len(cur_input_ids)
for k in range(len(cur_labels)):
cur_labels[k] = self.label2ids[cur_labels[k]]
else:
cur_labels = [cur_label] * len(cur_input_ids)
cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
for k in range(1, len(cur_labels)):
cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
assert len(cur_input_ids) == len([total_data['bboxes'][i][j]] * len(cur_input_ids)) == len(cur_labels)
cur_doc_input_ids += cur_input_ids
cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
cur_doc_labels += cur_labels
assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
assert len(cur_doc_input_ids) > 0
total_input_ids.append(cur_doc_input_ids)
total_bboxs.append(cur_doc_bboxs)
total_label_ids.append(cur_doc_labels)
assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)
# split text to several slices because of over-length
input_ids, bboxs, labels = [], [], []
segment_ids, position_ids = [], []
image_path = []
for i in range(len(total_input_ids)):
start = 0
cur_iter = 0
while start < len(total_input_ids[i]):
end = min(start + 510, len(total_input_ids[i]))
input_ids.append([self.tokenizer.cls_token_id] + total_input_ids[i][start: end] + [self.tokenizer.sep_token_id])
bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start: end] + [[1000, 1000, 1000, 1000]])
labels.append([-100] + total_label_ids[i][start: end] + [-100])
cur_segment_ids = self.get_segment_ids(bboxs[-1])
cur_position_ids = self.get_position_ids(cur_segment_ids)
segment_ids.append(cur_segment_ids)
position_ids.append(cur_position_ids)
image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
start = end
cur_iter += 1
assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
assert len(segment_ids) == len(image_path)
res = {
'input_ids': input_ids,
'bbox': bboxs,
'labels': labels,
'segment_ids': segment_ids,
'position_ids': position_ids,
'image_path': image_path,
}
return res
def __init__(
self,
args,
tokenizer,
mode
):
self.args = args
self.mode = mode
self.cur_la = args.language
self.tokenizer = tokenizer
self.label2ids = XFund_label2ids
self.common_transform = Compose([
RandomResizedCropAndInterpolationWithTwoPic(
size=args.input_size, interpolation=args.train_interpolation,
),
])
self.patch_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor((0.5, 0.5, 0.5)),
std=torch.tensor((0.5, 0.5, 0.5)))
])
data_file = json.load(
open(os.path.join(args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val')),
'r'))
self.feature = self.load_data(data_file)
def __len__(self):
return len(self.feature['input_ids'])
def __getitem__(self, index):
input_ids = self.feature["input_ids"][index]
# attention_mask = self.feature["attention_mask"][index]
attention_mask = [1] * len(input_ids)
labels = self.feature["labels"][index]
bbox = self.feature["bbox"][index]
segment_ids = self.feature['segment_ids'][index]
position_ids = self.feature['position_ids'][index]
img = pil_loader(self.feature['image_path'][index])
for_patches, _ = self.common_transform(img, augmentation=False)
patch = self.patch_transform(for_patches)
assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
res = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"bbox": bbox,
"segment_ids": segment_ids,
"position_ids": position_ids,
"images": patch,
}
return res
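
# Illustrative note (added for clarity): get_segment_ids() above groups consecutive
# tokens that share a bbox into one segment, and get_position_ids() restarts the 1-D
# position counter at 2 for every new segment. For bboxes
#   [[0, 0, 5, 5], [0, 0, 5, 5], [9, 9, 12, 12]]
# the methods yield segment_ids [0, 0, 1] and position_ids [2, 3, 2].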
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
from .layoutlmv3 import (
LayoutLMv3Config,
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Tokenizer,
)
from transformers import AutoConfig, AutoModel, AutoModelForTokenClassification, \
AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, RobertaConverter
from .configuration_layoutlmv3 import LayoutLMv3Config
from .modeling_layoutlmv3 import (
LayoutLMv3ForTokenClassification,
LayoutLMv3ForQuestionAnswering,
LayoutLMv3ForSequenceClassification,
LayoutLMv3Model,
)
from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
#AutoConfig.register("layoutlmv3", LayoutLMv3Config)
#AutoModel.register(LayoutLMv3Config, LayoutLMv3Model)
#AutoModelForTokenClassification.register(LayoutLMv3Config, LayoutLMv3ForTokenClassification)
#AutoModelForQuestionAnswering.register(LayoutLMv3Config, LayoutLMv3ForQuestionAnswering)
#AutoModelForSequenceClassification.register(LayoutLMv3Config, LayoutLMv3ForSequenceClassification)
#AutoTokenizer.register(
# LayoutLMv3Config, slow_tokenizer_class=LayoutLMv3Tokenizer, fast_tokenizer_class=LayoutLMv3TokenizerFast
#)
SLOW_TO_FAST_CONVERTERS.update({"LayoutLMv3Tokenizer": RobertaConverter})
# coding=utf-8
from transformers.models.bert.configuration_bert import BertConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
"layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/resolve/main/config.json",
# See all LayoutLMv3 models at https://huggingface.co/models?filter=layoutlmv3
}
class LayoutLMv3Config(BertConfig):
model_type = "layoutlmv3"
def __init__(
self,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
max_2d_position_embeddings=1024,
coordinate_size=None,
shape_size=None,
has_relative_attention_bias=False,
rel_pos_bins=32,
max_rel_pos=128,
has_spatial_attention_bias=False,
rel_2d_pos_bins=64,
max_rel_2d_pos=256,
visual_embed=True,
mim=False,
wpa_task=False,
discrete_vae_weight_path='',
discrete_vae_type='dall-e',
input_size=224,
second_input_size=112,
device='cuda',
**kwargs
):
"""Constructs RobertaConfig."""
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.max_2d_position_embeddings = max_2d_position_embeddings
self.coordinate_size = coordinate_size
self.shape_size = shape_size
self.has_relative_attention_bias = has_relative_attention_bias
self.rel_pos_bins = rel_pos_bins
self.max_rel_pos = max_rel_pos
self.has_spatial_attention_bias = has_spatial_attention_bias
self.rel_2d_pos_bins = rel_2d_pos_bins
self.max_rel_2d_pos = max_rel_2d_pos
self.visual_embed = visual_embed
self.mim = mim
self.wpa_task = wpa_task
self.discrete_vae_weight_path = discrete_vae_weight_path
self.discrete_vae_type = discrete_vae_type
self.input_size = input_size
self.second_input_size = second_input_size
self.device = device
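
# Minimal sketch (illustrative; the checkpoint name is an assumption): loading a
# pretrained config and disabling both relative-bias flags, mirroring what the
# detection backbone in this commit does before building LayoutLMv3Model.
def _example_detection_config():
    config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base")
    config.has_spatial_attention_bias = False
    config.has_relative_attention_bias = False
    return config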
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LayoutLMv3, refer to RoBERTa."""
from transformers.models.roberta import RobertaTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
class LayoutLMv3Tokenizer(RobertaTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
# pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization classes for LayoutLMv3, refer to RoBERTa."""
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
from transformers.utils import logging
from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
class LayoutLMv3TokenizerFast(RobertaTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
# pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
# max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
slow_tokenizer_class = LayoutLMv3Tokenizer
from .visualizer import Visualizer
from .rcnn_vl import *
from .backbone import *
from detectron2.config import get_cfg
from detectron2.config import CfgNode as CN
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor
def add_vit_config(cfg):
"""
Add config for VIT.
"""
_C = cfg
_C.MODEL.VIT = CN()
    # ViT model name.
    _C.MODEL.VIT.NAME = ""
    # Output features from the ViT backbone.
_C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
_C.MODEL.VIT.IMG_SIZE = [224, 224]
_C.MODEL.VIT.POS_TYPE = "shared_rel"
_C.MODEL.VIT.DROP_PATH = 0.
_C.MODEL.VIT.MODEL_KWARGS = "{}"
_C.SOLVER.OPTIMIZER = "ADAMW"
_C.SOLVER.BACKBONE_MULTIPLIER = 1.0
_C.AUG = CN()
_C.AUG.DETR = False
_C.MODEL.IMAGE_ONLY = True
_C.PUBLAYNET_DATA_DIR_TRAIN = ""
_C.PUBLAYNET_DATA_DIR_TEST = ""
_C.FOOTNOTE_DATA_DIR_TRAIN = ""
_C.FOOTNOTE_DATA_DIR_VAL = ""
_C.SCIHUB_DATA_DIR_TRAIN = ""
_C.SCIHUB_DATA_DIR_TEST = ""
_C.JIAOCAI_DATA_DIR_TRAIN = ""
_C.JIAOCAI_DATA_DIR_TEST = ""
_C.ICDAR_DATA_DIR_TRAIN = ""
_C.ICDAR_DATA_DIR_TEST = ""
_C.M6DOC_DATA_DIR_TEST = ""
_C.DOCSTRUCTBENCH_DATA_DIR_TEST = ""
_C.DOCSTRUCTBENCHv2_DATA_DIR_TEST = ""
_C.CACHE_DIR = ""
_C.MODEL.CONFIG_PATH = ""
# effective update steps would be MAX_ITER/GRADIENT_ACCUMULATION_STEPS
# maybe need to set MAX_ITER *= GRADIENT_ACCUMULATION_STEPS
_C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1
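# Illustrative sketch, not part of the original file: the keys registered in
# add_vit_config() are normally supplied by the YAML passed via --config-file.
# For the LayoutLMv3 backbone the relevant fragment is assumed to look roughly
# like the following (names and values are examples, not taken verbatim from
# the shipped configs):
#
#   MODEL:
#     VIT:
#       NAME: "layoutlmv3_base"
#       OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
#     IMAGE_ONLY: True
#     CONFIG_PATH: "path/to/layoutlmv3-base"   # HuggingFace config directory
#   SOLVER:
#     OPTIMIZER: "ADAMW"
#     GRADIENT_ACCUMULATION_STEPS: 1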
def setup(args, device):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
# add_coat_config(cfg)
add_vit_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 # set threshold for this model
cfg.merge_from_list(args.opts)
    # Use the unified device configuration
cfg.MODEL.DEVICE = device
cfg.freeze()
default_setup(cfg, args)
register_coco_instances(
"scihub_train",
{},
cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
cfg.SCIHUB_DATA_DIR_TRAIN
)
return cfg
class DotDict(dict):
    """Dict subclass with attribute-style access; missing keys resolve to None."""

    def __init__(self, *args, **kwargs):
        super(DotDict, self).__init__(*args, **kwargs)
def __getattr__(self, key):
if key not in self.keys():
return None
value = self[key]
if isinstance(value, dict):
value = DotDict(value)
return value
def __setattr__(self, key, value):
self[key] = value
class Layoutlmv3_Predictor(object):
def __init__(self, weights, config_file, device):
layout_args = {
"config_file": config_file,
"resume": False,
"eval_only": False,
"num_gpus": 1,
"num_machines": 1,
"machine_rank": 0,
"dist_url": "tcp://127.0.0.1:57823",
"opts": ["MODEL.WEIGHTS", weights],
}
layout_args = DotDict(layout_args)
cfg = setup(layout_args, device)
self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
"table_footnote", "isolate_formula", "formula_caption"]
MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
self.predictor = DefaultPredictor(cfg)
    def __call__(self, image, ignore_catids=None):
        # Avoid a mutable default argument; treat None as "keep all categories".
        ignore_catids = ignore_catids or []
# page_layout_result = {
# "layout_dets": []
# }
layout_dets = []
outputs = self.predictor(image)
boxes = outputs["instances"].to("cpu")._fields["pred_boxes"].tensor.tolist()
labels = outputs["instances"].to("cpu")._fields["pred_classes"].tolist()
scores = outputs["instances"].to("cpu")._fields["scores"].tolist()
for bbox_idx in range(len(boxes)):
if labels[bbox_idx] in ignore_catids:
continue
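            # "poly" stores the box as an 8-value quadrilateral
            # (x0, y0, x1, y0, x1, y1, x0, y1): corners listed clockwise from
            # the top-left, derived from the (x0, y0, x1, y1) detectron2 box.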
layout_dets.append({
"category_id": labels[bbox_idx],
"poly": [
boxes[bbox_idx][0], boxes[bbox_idx][1],
boxes[bbox_idx][2], boxes[bbox_idx][1],
boxes[bbox_idx][2], boxes[bbox_idx][3],
boxes[bbox_idx][0], boxes[bbox_idx][3],
],
"score": scores[bbox_idx]
})
return layout_dets
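# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original file. The weight, config
# and image paths are placeholders; in practice they come from the MinerU
# model-download step and the caller's page images.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import cv2

    predictor = Layoutlmv3_Predictor(
        weights="model_final.pth",                      # placeholder path
        config_file="layoutlmv3_base_inference.yaml",   # placeholder path
        device="cpu",
    )
    img = cv2.imread("page_0.png")                      # placeholder BGR page image
    dets = predictor(img, ignore_catids=[2])            # e.g. drop the "abandon" class
    for det in dets:
        print(det["category_id"], round(det["score"], 3), det["poly"])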
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from detectron2.config import configurable
from detectron2.structures import ImageList, Instances
from detectron2.utils.events import get_event_storage
from detectron2.modeling.backbone import Backbone, build_backbone
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image
from contextlib import contextmanager
from itertools import count
@META_ARCH_REGISTRY.register()
class VLGeneralizedRCNN(GeneralizedRCNN):
"""
    Generalized R-CNN. Any model that contains the following three components:
1. Per-image feature extraction (aka backbone)
2. Region proposal generation
3. Per-region feature extraction and prediction
"""
def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
"""
Args:
batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
* image: Tensor, image in (C, H, W) format.
* instances (optional): groundtruth :class:`Instances`
* proposals (optional): :class:`Instances`, precomputed proposals.
Other information that's included in the original dicts, such as:
* "height", "width" (int): the output resolution of the model, used in inference.
See :meth:`postprocess` for details.
Returns:
list[dict]:
Each dict is the output for one input image.
The dict contains one key "instances" whose value is a :class:`Instances`.
The :class:`Instances` object has the following keys:
"pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
"""
if not self.training:
return self.inference(batched_inputs)
images = self.preprocess_image(batched_inputs)
if "instances" in batched_inputs[0]:
gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
else:
gt_instances = None
# features = self.backbone(images.tensor)
input = self.get_batch(batched_inputs, images)
features = self.backbone(input)
if self.proposal_generator is not None:
proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
else:
assert "proposals" in batched_inputs[0]
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
proposal_losses = {}
_, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
if self.vis_period > 0:
storage = get_event_storage()
if storage.iter % self.vis_period == 0:
self.visualize_training(batched_inputs, proposals)
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
return losses
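    # For reference: with the image-only LayoutLMv3 backbone, each element of
    # `batched_inputs` carries an "image" tensor in (C, H, W) format (plus
    # optional "height"/"width" output sizes), and get_batch() wraps the padded
    # batch tensor as {"images": ...} before it reaches the backbone.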
def inference(
self,
batched_inputs: List[Dict[str, torch.Tensor]],
detected_instances: Optional[List[Instances]] = None,
do_postprocess: bool = True,
):
"""
Run inference on the given inputs.
Args:
batched_inputs (list[dict]): same as in :meth:`forward`
detected_instances (None or list[Instances]): if not None, it
contains an `Instances` object per image. The `Instances`
object contains "pred_boxes" and "pred_classes" which are
known boxes in the image.
The inference will then skip the detection of bounding boxes,
and only predict other per-ROI outputs.
do_postprocess (bool): whether to apply post-processing on the outputs.
Returns:
When do_postprocess=True, same as in :meth:`forward`.
Otherwise, a list[Instances] containing raw network outputs.
"""
assert not self.training
images = self.preprocess_image(batched_inputs)
# features = self.backbone(images.tensor)
input = self.get_batch(batched_inputs, images)
features = self.backbone(input)
if detected_instances is None:
if self.proposal_generator is not None:
proposals, _ = self.proposal_generator(images, features, None)
else:
assert "proposals" in batched_inputs[0]
proposals = [x["proposals"].to(self.device) for x in batched_inputs]
results, _ = self.roi_heads(images, features, proposals, None)
else:
detected_instances = [x.to(self.device) for x in detected_instances]
results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
if do_postprocess:
assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
else:
return results
    def get_batch(self, examples, images):
        # Image-only detection: wrap the padded batch tensor for the backbone.
        if len(examples) >= 1 and "bbox" not in examples[0]:  # image_only
            return {"images": images.tensor}
        # Text/layout inputs are not handled in this detection-only setup.
        raise NotImplementedError("get_batch only supports image-only inputs here")
def _batch_inference(self, batched_inputs, detected_instances=None):
"""
Execute inference on a list of inputs,
using batch size = self.batch_size (e.g., 2), instead of the length of the list.
Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference`
"""
if detected_instances is None:
detected_instances = [None] * len(batched_inputs)
outputs = []
inputs, instances = [], []
for idx, input, instance in zip(count(), batched_inputs, detected_instances):
inputs.append(input)
instances.append(instance)
if len(inputs) == 2 or idx == len(batched_inputs) - 1:
outputs.extend(
self.inference(
inputs,
instances if instances[0] is not None else None,
do_postprocess=True, # False
)
)
inputs, instances = [], []
return outputs
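# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file. VLGeneralizedRCNN is not
# instantiated directly: the project YAML is assumed to select it through
# MODEL.META_ARCHITECTURE, after which detectron2 builds it via the registry.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from detectron2.config import get_cfg
    from detectron2.modeling import build_model

    cfg = get_cfg()
    cfg.MODEL.META_ARCHITECTURE = "VLGeneralizedRCNN"  # registered by the decorator above
    # Building a usable model additionally needs the backbone/ROI-head keys
    # from the project config (see add_vit_config), so the call is left as a sketch:
    # model = build_model(cfg)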