Commit 055b6aa1 authored by chenych's avatar chenych
Browse files

First init

parent 6370e125
Pipeline #2993 failed with stages
in 0 seconds
# Contributors
None
\ No newline at end of file
import torch
import os
import argparse
from transformers import AutoModel, AutoTokenizer
os.environ["HIP_VISIBLE_DEVICES"] = '0'
parse = argparse.ArgumentParser()
parse.add_argument('--model_name_or_path', type=str, default='deepseek-ai/DeepSeek-OCR')
parse.add_argument('--image_file', type=str, default='./doc/test.jpg')
parse.add_argument('--output_path', type=str, default='./output/')
args = parse.parse_args()
if __name__ == '__main__':
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model_name_or_path, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
model = model.eval().cuda().to(torch.bfloat16)
# prompt = "<image>\nFree OCR. "
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
res = model.infer(tokenizer, prompt=prompt, image_file=args.image_file, output_path=args.output_path, base_size=1024, image_size=640, crop_mode=True, save_results=True, test_compress=True)
print("process end, result saved to ", args.output_path)
\ No newline at end of file
# TODO: change modes
# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
MIN_CROPS= 2
MAX_CROPS= 6 # max:9; If your GPU memory is small, it is recommended to set it to 6.
MAX_CONCURRENCY = 100 # If you have limited GPU memory, lower the concurrency count.
NUM_WORKERS = 64 # image pre-process (resize/padding) workers
PRINT_NUM_VIS_TOKENS = False
SKIP_REPEAT = True
MODEL_PATH = 'deepseek-ai/DeepSeek-OCR' # change to your model path
# TODO: change INPUT_PATH
# .pdf: run_dpsk_ocr_pdf.py;
# .jpg, .png, .jpeg: run_dpsk_ocr_image.py;
# Omnidocbench images path: run_dpsk_ocr_eval_batch.py
INPUT_PATH = ''
OUTPUT_PATH = ''
PROMPT = '<image>\n<|grounding|>Convert the document to markdown.'
# PROMPT = '<image>\nFree OCR.'
# TODO commonly used prompts
# document: <image>\n<|grounding|>Convert the document to markdown.
# other image: <image>\n<|grounding|>OCR this image.
# without layouts: <image>\nFree OCR.
# figures in document: <image>\nParse the figure.
# general: <image>\nDescribe this image in detail.
# rec: <image>\nLocate <|ref|>xxxx<|/ref|> in the image.
# '先天下之忧而忧'
# .......
from transformers import AutoTokenizer
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
import torch.nn as nn
import torch
import torch.nn.functional as F
import copy
class MlpProjector(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
if cfg.projector_type == "identity":
modules = nn.Identity()
elif cfg.projector_type == "linear":
modules = nn.Linear(cfg.input_dim, cfg.n_embed)
elif cfg.projector_type == "mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
mlp_ratio = cfg.get("mlp_ratio", 1)
modules = [
nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "downsample_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
mlp_ratio = cfg.get("mlp_ratio", 1)
modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
channel_div = cfg.get("channel_div", 0.5)
self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "low_high_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
modules = nn.Sequential(*modules)
self.high_layers = nn.Sequential(*modules)
self.low_layers = copy.deepcopy(modules)
else:
raise ValueError(f"Unknown projector type: {cfg.projector_type}")
if cfg.get("token_pooling", False):
self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
if cfg.get("conv_fusion_high_low_features", False):
self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
self.layers = modules
def forward(self, x):
if self.cfg.get("token_pooling", False):
batch_size, wxh, channels = x.shape
w = h = int(wxh**0.5)
x = x.view(batch_size, w, h, channels)
x = x.permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
batch_size, channels, h_patches, w_patches, _, _ = patches.size()
# 在通道维度上拼接
patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
# 通过线性层
patches = patches.permute(0, 2, 1, 3).contiguous()
patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
x = self.token_pooling_layer(patches)
if self.cfg.get("conv_fusion_high_low_features", False):
x = self.fusion_layer(x[:, 0]) + x[:, 1]
if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
high_x, low_x = x[0], x[1]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
high_x = x[...,:self.cfg.input_dim[0]]
low_x = x[...,self.cfg.input_dim[0]:]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.cfg.projector_type == 'low_high_split_mlp_gelu':
high_x, low_x = x[0], x[1]
high_x = self.high_layers(high_x)
low_x = self.low_layers(low_x)
x = torch.concat([high_x, low_x], dim=-1)
return x
if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
bs, hw, input_dim = x.shape
h = w = int((hw) ** 0.5)
"""compute padding"""
if h % self.cfg.downsample_ratio:
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
else:
pad = 0
x = x.reshape(bs, h, w, input_dim)
if pad > 0:
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
return self.layers(x)
@staticmethod
def get_flops_per_sample(cfg):
if cfg.projector_type == "linear":
fwd = 2 * cfg.input_dim * cfg.n_embed
elif "mlp_gelu" in cfg.projector_type :
mlp_depth = cfg.get("depth", 1)
downsample_ratio = cfg.get("downsample_ratio", 1)
input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
input_dim = input_dim * downsample_ratio * downsample_ratio
fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
else:
fwd = 0
return fwd * 3
from contextlib import nullcontext
import math
from typing import Optional, Tuple
# from megatron.model import LayerNorm
from easydict import EasyDict as adict
import torch
from torch.nn import functional as F
from torch import nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# from optimus import flash_attn_func
# from megatron.core import tensor_parallel
# from megatron.core import parallel_state as mpu
# from megatron.core.utils import make_viewless_tensor, divide
# from megatron.model.fused_rms_norm import RMSNorm
# from megatron.model.transformer import (
# FlashSelfAttention,
# NoopTransformerLayer,
# _cfg_to_kwargs,
# )
# from megatron.model.enums import AttnMaskType, AttnType
# from megatron.model.fused_softmax import FusedScaleMaskSoftmax
# from megatron.model.utils import attention_mask_func
# from megatron.model.module import MegatronModule
# try:
# from einops import rearrange
# except ImportError:
# rearrange = None
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# try:
# # flash attention 2.x
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# except ImportError:
# try:
# # flash attention 1.x
# from flash_attn.flash_attn_interface import flash_attn_unpadded_func
# except ImportError:
# flash_attn_unpadded_func = None
# try:
# from flash_attn.flash_attn_interface import flash_attn_unpadded_relative_attention_bias_func
# except ImportError:
# flash_attn_unpadded_relative_attention_bias_func = None
# try:
# from flash_attn.flash_attn_interface import mask_flash_attn_unpadded_func
# except ImportError:
# mask_flash_attn_unpadded_func = None
class LayerNormfp32(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
# print(tgt_size)
# print(abs_pos.shape)
# exit()
dim = abs_pos.size(-1)
# print(dim)
abs_pos_new = abs_pos.squeeze(0)
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1,
2).contiguous()
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode='bicubic',
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
return vision_pos_embed
else:
return abs_pos
@torch.jit.script
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
super().__init__()
self.embed_dim = hidden_size
self.image_size = image_size
self.patch_size = patch_size
self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = torch.nn.Conv2d(
in_channels=num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids", torch.arange(self.num_positions).expand((1, -1))
)
def forward(self, pixel_values, patch_embeds):
batch_size = pixel_values.shape[0]
# patch_embeds = self.patch_embedding(
# pixel_values
# ) # shape = [*, width, grid, grid]
if patch_embeds is not None:
patch_embeds = patch_embeds
# print(patch_embeds.shape)
else:
patch_embeds = self.patch_embedding(pixel_values)
# print(111111)
# shape = [*, width, grid, grid]
# patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
# x = torch.cat([cls_token, x], dim=1)
embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
# embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class NoTPFeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
):
super().__init__()
self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
def forward(self, x):
output = self.fc2(quick_gelu(self.fc1(x)))
return output
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
# class NoTPAttention(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# self.num_heads = cfg.num_attention_heads
# self.n_local_heads = cfg.num_attention_heads
# self.head_dim = cfg.hidden_size // cfg.num_attention_heads
# self.max_seq_len = cfg.seq_length
# self.use_flash_attention = cfg.use_flash_attn
# self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
# self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
# # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
# self.attn_drop = cfg.attention_dropout
# def forward(
# self,
# x: torch.Tensor,
# ):
# bsz, seqlen, _ = x.shape
# xqkv = self.qkv_proj(x)
# xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
# if self.use_flash_attention:
# output = flash_attn_qkvpacked_func(xqkv)
# output = output.view(bsz, seqlen, -1)
# else:
# xq, xk, xv = torch.split(xqkv, 1, dim=2)
# xq = xq.squeeze(2)
# xk = xk.squeeze(2)
# xv = xv.squeeze(2)
# # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# # (B, num_head, S, head_size)
# xq = xq.permute(0, 2, 1, 3)
# xk = xk.permute(0, 2, 1, 3)
# xv = xv.permute(0, 2, 1, 3)
# output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
# utput = output.permute(0, 2, 1, 3).view(bsz, seqlen, -1)
# output = self.out_proj(output)
# return output
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
class NoTPAttention(torch.nn.Module):
def __init__(self, cfg):
super().__init__()
self.num_heads = cfg.num_attention_heads
self.n_local_heads = cfg.num_attention_heads
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.max_seq_len = cfg.seq_length
self.use_flash_attention = cfg.use_flash_attn
self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
# self.core_attention = CoreAttention(cfg, AttnType.self_attn)
self.attn_drop = cfg.attention_dropout
def forward(
self,
x: torch.Tensor,
):
bsz, seqlen, _ = x.shape
xqkv = self.qkv_proj(x)
xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
if self.use_flash_attention:
output = flash_attn_qkvpacked_func(xqkv)
output = output.view(bsz, seqlen, -1)
# xq, xk, xv = torch.split(xqkv, 1, dim=2)
# xq = xq.squeeze(2)
# xk = xk.squeeze(2)
# xv = xv.squeeze(2)
# # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# # (B, num_head, S, head_size)
# xq = xq.permute(0, 2, 1, 3)
# xk = xk.permute(0, 2, 1, 3)
# xv = xv.permute(0, 2, 1, 3)
# # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
# output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
# output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
# output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
else:
# output = flash_attn_qkvpacked_func(xqkv)
xq, xk, xv = torch.split(xqkv, 1, dim=2)
xq = xq.squeeze(2)
xk = xk.squeeze(2)
xv = xv.squeeze(2)
# xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# (B, num_head, S, head_size)
xq = xq.permute(0, 2, 1, 3)
xk = xk.permute(0, 2, 1, 3)
xv = xv.permute(0, 2, 1, 3)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
output = self.out_proj(output)
return output
class NoTPTransformerBlock(nn.Module):
def __init__(self, cfg, layer_id: int, multiple_of=256):
super().__init__()
self.n_heads = cfg.num_attention_heads
self.dim = cfg.hidden_size
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.self_attn = NoTPAttention(cfg)
self.mlp = NoTPFeedForward(
cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
)
self.layer_id = layer_id
self.layer_norm1 = torch.nn.LayerNorm(
cfg.hidden_size, eps=cfg.layernorm_epsilon
)
self.layer_norm2 = torch.nn.LayerNorm(
cfg.hidden_size, eps=cfg.layernorm_epsilon
)
def forward(self, x: torch.Tensor):
residual = self.self_attn.forward(self.layer_norm1(x))
h = x + residual
out = h + self.mlp.forward(self.layer_norm2(h))
return out
class NoTPTransformer(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
# self.recompute_list = self.cfg.get("recompute_list", [])
self.num_layers = cfg.num_layers # _get_num_layers(cfg)
self.layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
self.layers.append(
NoTPTransformerBlock(
cfg,
layer_id + 1,
)
)
def forward(
self,
hidden_states,
):
for lid, layer in enumerate(self.layers):
# if lid in self.recompute_list:
# def custom(layer_id):
# def custom_forward(*args, **kwargs):
# x_ = self.layers[layer_id](*args, **kwargs)
# return x_
# return custom_forward
# assert hidden_states.requires_grad == True, logger.warning(
# "When using recalculation, the input must have grad fn"
# )
# hidden_states = tensor_parallel.checkpoint(
# custom(lid),
# False,
# hidden_states.contiguous()
# )
# else:
hidden_states = layer(hidden_states)
return hidden_states
# from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
class VitModel(nn.Module):
def __init__(
self,
cfg,
freeze_embed=False,
freeze_pre_norm=False
) -> None:
super().__init__()
self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)
if freeze_embed:
for name, param in self.embeddings.named_parameters():
param.requires_grad = False
self.transformer = NoTPTransformer(cfg=cfg)
if cfg.get("fp32norm", False):
logger.info("Load fp32 layernorm for ViT.")
self.pre_layrnorm = LayerNormfp32(
cfg.hidden_size,
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
else:
self.pre_layrnorm = torch.nn.LayerNorm(
cfg.hidden_size,
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
# self.pre_layrnorm = RMSNorm(
# cfg.hidden_size,
# eps=cfg.get("pre_layernorm_epsilon", 1e-5),
# sequence_parallel=False,
# use_fp32=True,
# use_optimus=True,
# )
if freeze_pre_norm:
for name, param in self.pre_layrnorm.named_parameters():
param.requires_grad = False
for p in self.parameters():
p.micro_dp = True
def set_input_tensor(self, input_tensor):
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.transformer.set_input_tensor(input_tensor[0])
def __str__(self) -> str:
return "open_clip"
def forward(
self,
x,
patch_embeds
):
x = self.embeddings(x, patch_embeds)
hidden_states = self.pre_layrnorm(x)
# hidden_states, dis = local_dp_scatter(hidden_states)
output = self.transformer(hidden_states)
# output = local_dp_reduce(output, dis)
return output
vit_model_cfg = adict(
num_layers=24,
hidden_size=1024,
num_heads = 16,
num_attention_heads=16,
ffn_hidden_size=4096,
seq_length=256,
max_position_embeddings=256,
use_flash_attn=False,
understand_projector_stride=2,
hidden_dropout = 0.0,
attention_dropout = 0.0,
no_persist_layer_norm = False,
layernorm_epsilon = 1e-5,
pre_layernorm_epsilon = 1e-5,
image_size = 224,
patch_size = 14,
recompute_list = []
)
def build_clip_l():
return VitModel(
cfg=vit_model_cfg,
freeze_embed=False,
freeze_pre_norm=False,
)
if __name__ == '__main__':
from mmgpt.model.vision_encoder.sam_b import build_sam_vit_b
vit_model_cfg = adict(
num_layers=24,
hidden_size=1024,
num_attention_heads=16,
ffn_hidden_size=4096,
seq_length=256,
max_position_embeddings=256,
use_flash_attn=False,
understand_projector_stride=2,
hidden_dropout = 0.0,
attention_dropout = 0.0,
no_persist_layer_norm = False,
layernorm_epsilon = 1e-5,
pre_layernorm_epsilon = 1e-5,
image_size = 224,
patch_size = 14,
recompute_list = []
)
sam_model = build_sam_vit_b()
vision_model = VitModel(
cfg=vit_model_cfg,
freeze_embed=False,
freeze_pre_norm=False,
)
# model = VitModel(1344)
# x = torch.zeros(2, 3, 224, 224)
x = torch.zeros(2, 3, 1024, 1024)
with torch.no_grad():
# y = vision_model(x)
patch_embed = sam_model(x)
print(patch_embed.shape)
y = vision_model(x, patch_embed)
print(y.shape)
image_feature = torch.add(y[:, 1:], patch_embed.flatten(2).permute(0, 2, 1))
print(image_feature.shape)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
from flash_attn import flash_attn_qkvpacked_func
# from .common import LayerNorm2d, MLPBlock
# from mmgpt.model.vision_encoder.flash_4 import _attention_rel_h_rel_w
def get_abs_pos(abs_pos, tgt_size):
dtype = abs_pos.dtype
src_size = abs_pos.size(1)
if src_size != tgt_size:
old_pos_embed = abs_pos.permute(0, 3, 1, 2)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode='bicubic',
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
return new_pos_embed
else:
return abs_pos
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
# x = x + self.pos_embed
x = x + get_abs_pos(self.pos_embed, x.size(1))
for blk in self.blocks:
x = blk(x)
neck_output = self.neck(x.permute(0, 3, 1, 2))
conv2_output = self.net_2(neck_output)
# print(f"conv2_output shape: {conv2_output.shape}")
conv3_output = self.net_3(conv2_output)
return conv3_output
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
rel_h, rel_w = None, None
if self.use_rel_pos:
rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
q = q.view(B, self.num_heads, H * W, -1)
k = k.view(B, self.num_heads, H * W, -1)
v = v.view(B, self.num_heads, H * W, -1)
if self.use_rel_pos:
rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
# x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
# qkv = torch.stack([q, k, v], dim=1).transpose(1, 3).reshape(B, H * W, 3, self.num_heads, -1)
# x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False).transpose(1, 2)
x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
dtype = rel_pos.dtype
rel_pos = rel_pos.to(torch.float32)
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
).to(dtype)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
return rel_h, rel_w
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
if checkpoint is not None:
# with open(checkpoint, "rb") as f:
state_dict = torch.load(checkpoint)
# print(state_dict.keys())
# for key in state_dict:
# image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
# ocr-anyting
# image_encoder.load_state_dict(state_dict, strict=True)
# tob
image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
print(checkpoint)
return image_encoder
\ No newline at end of file
"""Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
MlpProjectorConfig,
VisionEncoderConfig)
from process.image_process import (
DeepseekOCRProcessor, count_tiles)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
# from vllm.utils import is_list_of
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from deepencoder.sam_vary_sdpa import build_sam_vit_b
from deepencoder.clip_sdpa import build_clip_l
from deepencoder.build_linear import MlpProjector
from addict import Dict
# import time
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, PRINT_NUM_VIS_TOKENS, PROMPT
# The image token id may be various
_IMAGE_TOKEN = "<image>"
class DeepseekOCRProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(DeepseekVLV2Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_num_image_tokens(self,
*,
image_width: int,
image_height: int,
cropping: bool = True) -> int:
hf_processor = self.get_hf_processor()
# image_size = hf_processor.image_size
# patch_size = hf_processor.patch_size
# downsample_ratio = hf_processor.downsample_ratio
image_size = IMAGE_SIZE
base_size = BASE_SIZE
patch_size = 16
downsample_ratio = 4
if CROP_MODE:
if image_width <= 640 and image_height <= 640:
crop_ratio = [1, 1]
else:
# images_crop_raw, crop_ratio = hf_processor.dynamic_preprocess(image)
# find the closest aspect ratio to the target
crop_ratio = count_tiles(image_width, image_height, image_size=IMAGE_SIZE)
# print('===========')
# print('crop_ratio ', crop_ratio)
# print('============')
num_width_tiles, num_height_tiles = crop_ratio
else:
num_width_tiles = num_height_tiles = 1
h = w = math.ceil((base_size // patch_size) / downsample_ratio)
h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)
global_views_tokens = h * (w + 1)
if num_width_tiles >1 or num_height_tiles>1:
local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)
else:
local_views_tokens = 0
return global_views_tokens + local_views_tokens + 1
def get_image_size_with_most_features(self) -> ImageSize:
if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
return ImageSize(width=1024*2, height=1024*2)
return ImageSize(width=640*2, height=640*2)
class DeepseekOCRDummyInputsBuilder(
BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
max_image_size = self.info.get_image_size_with_most_features()
if '<image>' in PROMPT:
return {
"image":
DeepseekOCRProcessor().tokenize_with_images(images = self._get_dummy_images(width=max_image_size.width,
height=max_image_size.height,
num_images=num_images), bos=True, eos=True, cropping=CROP_MODE)
}
else:
return {
"image": []
}
class DeepseekOCRMultiModalProcessor(
BaseMultiModalProcessor[DeepseekOCRProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# print(mm_data)
if mm_data:
processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
else:
tokenizer = self.info.get_tokenizer()
processed_outputs = tokenizer(prompt,
add_special_tokens=True,
return_tensors="pt")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
# image_embeds=MultiModalFieldConfig.batched("image2"),
images_crop=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = hf_processor.image_token_id
assert isinstance(image_token_id, int)
def get_replacement_deepseek_vl2(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
width = images[0][-1][0][0]
height = images[0][-1][0][1]
num_image_tokens = self.info.get_num_image_tokens(
image_width=width,
image_height=height,
# flag = True,
cropping=CROP_MODE,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_deepseek_vl2,
)
]
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
info=DeepseekOCRProcessingInfo,
dummy_inputs=DeepseekOCRDummyInputsBuilder)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"language.": "language_model.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# config.model_type ='deepseek_vl_v2'
self.config = config
self.multimodal_config = multimodal_config
self.vision_config = config.vision_config
self.projector_config = config.projector_config
self.text_config = config.text_config
model_config = vllm_config.model_config
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
self.sam_model = build_sam_vit_b()
self.vision_model = build_clip_l()
n_embed = 1280
self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# self.sam_model = torch.compile(self.sam_model, mode="reduce-overhead")
# self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead")
# self.projector = torch.compile(self.projector, mode="max-autotune")
# special token for image token sequence format
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
if self.tile_tag == "2D":
# <|view_separator|>, <|\n|>
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
)
if self.text_config.topk_method == "noaux_tc":
architectures = ["DeepseekV3ForCausalLM"]
elif not self.text_config.use_mla:
architectures = ["DeepseekForCausalLM"]
else:
architectures = ["DeepseekV2ForCausalLM"]
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language"),
architectures=architectures,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _parse_and_validate_image_input(
self, **kwargs: object):
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
if pixel_values is None or torch.sum(pixel_values).item() == 0:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(images_spatial_crop)}")
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError("Incorrect type of image crop. "
f"Got type: {type(images_crop)}")
return [pixel_values, images_crop, images_spatial_crop]
raise AssertionError("This line should be unreachable.")
def _pixel_values_to_embedding(
self,
pixel_values: torch.Tensor,
images_crop: torch.Tensor,
images_spatial_crop: torch.Tensor,
) -> NestedTensors:
# Pixel_values (global view): [n_image, batch_size, 3, height, width]
# images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]]
# images_crop (local view): [n_image, batch_size, num_pathes, 3, h, w]
# split the pixel and image_crop, all batch_size = 1
images_in_this_batch = []
# print(type(images_crop))
# print(pixel_values.shape)
with torch.no_grad():
for jdx in range(images_spatial_crop.size(0)):
# with torch.set_grad_enabled(False):
patches = images_crop[jdx][0].to(torch.bfloat16) # batch_size = 1
image_ori = pixel_values[jdx]
crop_shape = images_spatial_crop[jdx][0]
if torch.sum(patches).item() != 0: # if all values = 0, no crop
# P, C, H, W = patches.shape
# crop_flag = 1
local_features_1 = self.sam_model(patches)
#TODO del patches
# torch.compiler.cudagraph_mark_step_begin()
local_features_2 = self.vision_model(patches, local_features_1)
local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
local_features = self.projector(local_features)
global_features_1 = self.sam_model(image_ori)
global_features_2 = self.vision_model(image_ori, global_features_1)
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
global_features = self.projector(global_features)
if PRINT_NUM_VIS_TOKENS:
print('=====================')
print('BASE: ', global_features.shape)
print('PATCHES: ', local_features.shape)
print('=====================')
_, hw, n_dim = global_features.shape
h = w = int(hw ** 0.5)
_2, hw2, n_dim2 = local_features.shape
h2 = w2 = int(hw2 ** 0.5)
width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]
global_features = global_features.view(h, w, n_dim)
global_features = torch.cat(
[global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
)
global_features = global_features.view(-1, n_dim)
local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
local_features = torch.cat(
[local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
)
local_features = local_features.view(-1, n_dim2)
global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
else:
global_features_1 = self.sam_model(image_ori)
global_features_2 = self.vision_model(image_ori, global_features_1)
global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
global_features = self.projector(global_features)
if PRINT_NUM_VIS_TOKENS:
print('=====================')
print('BASE: ', global_features.shape)
print('NO PATCHES')
print('=====================')
_, hw, n_dim = global_features.shape
h = w = int(hw ** 0.5)
global_features = global_features.view(h, w, n_dim)
global_features = torch.cat(
[global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
)
global_features = global_features.view(-1, n_dim)
global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
images_in_this_batch.append(global_local_features)
return images_in_this_batch
def _process_image_input(
self, image_input) -> torch.Tensor:
# image_input: [pixel_values, images_crop, images_spatial_crop]
pixel_values = image_input[0].to(torch.bfloat16)
# print(image_input[1][0].shape)
# print(type(image_input[1]))
# exit()
# images_crop = image_input[1].to(torch.bfloat16)
images_crop = image_input[1]
# images_crop = image_input[1]
images_spatial_crop = image_input[2].to(dtype=torch.long)
# local_start = time.time()
vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values, images_crop = images_crop, images_spatial_crop=images_spatial_crop)
# local_total_time = time.time() - local_start
# print('encoder_time: ', local_total_time)
# exit()
return vision_features
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
# print(len(multimodal_embeddings))
# print(input_ids.shape)
# print(type(inputs_embeds))
# print(inputs_embeds.shape)
return inputs_embeds
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object):
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.language_model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
processed_weights = []
for name, tensor in weights:
if 'sam_model' in name or 'vision_model' in name or 'projector' in name or 'image_newline' in name or 'view_seperator' in name:
new_name = name.replace('model.', '', 1)
else:
new_name = 'language.' + name
processed_weights.append((new_name, tensor))
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(processed_weights, mapper=self.hf_to_vllm_mapper)
return autoloaded_weights
import math
from typing import List, Tuple
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps
from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, MIN_CROPS, MAX_CROPS, PROMPT, TOKENIZER
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
def count_tiles(orig_width, orig_height, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
# print(target_ratios)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
return target_aspect_ratio
def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
# print(target_ratios)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# print(target_aspect_ratio)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
class ImageTransform:
def __init__(self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True):
self.mean = mean
self.std = std
self.normalize = normalize
transform_pipelines = [T.ToTensor()]
if normalize:
transform_pipelines.append(T.Normalize(mean, std))
self.transform = T.Compose(transform_pipelines)
def __call__(self, pil_img: Image.Image):
x = self.transform(pil_img)
return x
class DeepseekOCRProcessor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
def __init__(
self,
tokenizer: LlamaTokenizerFast = TOKENIZER,
candidate_resolutions: Tuple[Tuple[int, int]] = [[1024, 1024]],
patch_size: int = 16,
downsample_ratio: int = 4,
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<|▁pad▁|>",
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
# self.candidate_resolutions = candidate_resolutions # placeholder no use
self.image_size = IMAGE_SIZE
self.base_size = BASE_SIZE
# self.patch_size = patch_size
self.patch_size = 16
self.image_mean = image_mean
self.image_std = image_std
self.normalize = normalize
# self.downsample_ratio = downsample_ratio
self.downsample_ratio = 4
self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
self.tokenizer = tokenizer
# self.tokenizer = add_special_token(tokenizer)
self.tokenizer.padding_side = 'left' # must set this,padding side with make a difference in batch inference
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({'pad_token': pad_token})
# add image token
# image_token_id = self.tokenizer.vocab.get(image_token)
# if image_token_id is None:
# special_tokens = [image_token]
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token_id = self.tokenizer.vocab.get(image_token)
# add five special tokens for grounding-related tasks
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
# special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>']
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# special_tokens = ['<image>','<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>', '<td>', '</td>', '<tr>', '</tr>']
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
# # add special tokens for SFT data
# special_tokens = ["<|User|>", "<|Assistant|>"]
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token = image_token
self.pad_token = pad_token
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(
tokenizer,
**kwargs,
)
# def select_best_resolution(self, image_size):
# # used for cropping
# original_width, original_height = image_size
# best_fit = None
# max_effective_resolution = 0
# min_wasted_resolution = float("inf")
# for width, height in self.candidate_resolutions:
# scale = min(width / original_width, height / original_height)
# downscaled_width, downscaled_height = int(
# original_width * scale), int(original_height * scale)
# effective_resolution = min(downscaled_width * downscaled_height,
# original_width * original_height)
# wasted_resolution = (width * height) - effective_resolution
# if effective_resolution > max_effective_resolution or (
# effective_resolution == max_effective_resolution
# and wasted_resolution < min_wasted_resolution):
# max_effective_resolution = effective_resolution
# min_wasted_resolution = wasted_resolution
# best_fit = (width, height)
# return best_fit
@property
def bos_id(self):
return self.tokenizer.bos_token_id
@property
def eos_id(self):
return self.tokenizer.eos_token_id
@property
def pad_id(self):
return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str,
images: List,
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
assert (prompt is not None and images is not None
), "prompt and images must be used at the same time."
sft_format = prompt
input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, _ = images[0]
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"images_crop": images_crop,
"images_seq_mask": images_seq_mask,
"images_spatial_crop": images_spatial_crop,
"num_image_tokens": num_image_tokens,
}
# prepare = BatchFeature(
# data=dict(
# input_ids=input_ids,
# pixel_values=pixel_values,
# images_crop = images_crop,
# images_seq_mask=images_seq_mask,
# images_spatial_crop=images_spatial_crop,
# num_image_tokens=num_image_tokens,
# ),
# tensor_type="pt",
# )
# return prepare
def __call__(
self,
*,
prompt: str,
images: List,
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prepare = self.process_one(
prompt=prompt,
images=images,
inference_mode=inference_mode,
)
return prepare
def tokenize_with_images(
self,
# conversation: str,
images: List[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,
):
"""Tokenize text with <image> tags."""
# print(conversation)
conversation = PROMPT
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = [], [], [], []
image_shapes = []
num_image_tokens = []
tokenized_str = []
# print('image: ', len(images))
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""select best resolution for anyres"""
# if cropping:
# best_width, best_height = self.select_best_resolution(image.size)
# else:
# best_width, best_height = self.image_size, self.image_size
image_shapes.append(image.size)
if image.size[0] <= 640 and image.size[1] <= 640:
crop_ratio = [1, 1]
else:
if cropping:
# print('image-size: ', image.size)
# best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
# print('image ', image.size)
# print('open_size:', image.size)
images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=IMAGE_SIZE)
# print('crop_ratio: ', crop_ratio)
else:
# best_width, best_height = self.image_size, self.image_size
crop_ratio = [1, 1]
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func
# print(crop_ratio)
"""process the global view"""
# if cropping
if self.image_size <= 640 and not cropping:
# print('directly resize')
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(image, (self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean))
images_list.append(self.image_transform(global_view))
"""record height / width crop num"""
# width_crop_num, height_crop_num = best_width // self.image_size, best_height // self.image_size
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
"""process the local views"""
# local_view = ImageOps.pad(image, (best_width, best_height),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# for i in range(0, best_height, self.image_size):
# for j in range(0, best_width, self.image_size):
# images_crop_list.append(
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
# """process the global view"""
# global_view = ImageOps.pad(image, (self.image_size, self.image_size),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# images_list.append(self.image_transform(global_view))
# """process the local views"""
# local_view = ImageOps.pad(image, (best_width, best_height),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# for i in range(0, best_height, self.image_size):
# for j in range(0, best_width, self.image_size):
# images_list.append(
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
# """add image tokens"""
"""add image tokens"""
num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += ([self.image_token_id] * (num_queries * num_width_tiles) + [self.image_token_id]) * (
num_queries * num_height_tiles)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""add the bos and eos tokens"""
if bos:
tokenized_str = [self.bos_id] + tokenized_str
images_seq_mask = [False] + images_seq_mask
if eos:
tokenized_str = tokenized_str + [self.eos_id]
images_seq_mask = images_seq_mask + [False]
assert len(tokenized_str) == len(
images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
(f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) |
(input_ids == self.image_token_id)] = self.ignore_id
input_ids[input_ids < 0] = self.pad_id
inference_mode = True
if inference_mode:
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return [[input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, image_shapes]]
AutoProcessor.register("DeepseekVLV2Processor", DeepseekOCRProcessor)
import torch
from transformers import LogitsProcessor
from transformers.generation.logits_process import _calc_banned_ngram_tokens
from typing import List, Set
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(f"`window_size` has to be a strictly positive integer, but is {window_size}")
self.ngram_size = ngram_size
self.window_size = window_size
self.whitelist_token_ids = whitelist_token_ids or set()
def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
if len(input_ids) < self.ngram_size:
return scores
current_prefix = tuple(input_ids[-(self.ngram_size - 1):])
search_start = max(0, len(input_ids) - self.window_size)
search_end = len(input_ids) - self.ngram_size + 1
banned_tokens = set()
for i in range(search_start, search_end):
ngram = tuple(input_ids[i:i + self.ngram_size])
if ngram[:-1] == current_prefix:
banned_tokens.add(ngram[-1])
banned_tokens = banned_tokens - self.whitelist_token_ids
if banned_tokens:
scores = scores.clone()
for token in banned_tokens:
scores[token] = -float("inf")
return scores
\ No newline at end of file
import os
import re
from tqdm import tqdm
import torch
if torch.version.cuda == '11.8':
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, MAX_CONCURRENCY, CROP_MODE, NUM_WORKERS
from concurrent.futures import ThreadPoolExecutor
import glob
from PIL import Image
from deepseek_ocr import DeepseekOCRForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
llm = LLM(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=256,
enforce_eager=False,
trust_remote_code=True,
max_model_len=8192,
swap_space=0,
max_num_seqs = MAX_CONCURRENCY,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=40, window_size=90, whitelist_token_ids= {128821, 128822})] #window for fast;whitelist_token_ids: <td>,</td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
)
class Colors:
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
RESET = '\033[0m'
def clean_formula(text):
formula_pattern = r'\\\[(.*?)\\\]'
def process_formula(match):
formula = match.group(1)
formula = re.sub(r'\\quad\s*\([^)]*\)', '', formula)
formula = formula.strip()
return r'\[' + formula + r'\]'
cleaned_text = re.sub(formula_pattern, process_formula, text)
return cleaned_text
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
# mathes_image = []
mathes_other = []
for a_match in matches:
mathes_other.append(a_match[0])
return matches, mathes_other
def process_single_image(image):
"""single image"""
prompt_in = prompt
cache_item = {
"prompt": prompt_in,
"multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
}
return cache_item
if __name__ == "__main__":
# INPUT_PATH = OmniDocBench images path
os.makedirs(OUTPUT_PATH, exist_ok=True)
# print('image processing until processing prompts.....')
print(f'{Colors.RED}glob images.....{Colors.RESET}')
images_path = glob.glob(f'{INPUT_PATH}/*')
images = []
for image_path in images_path:
image = Image.open(image_path).convert('RGB')
images.append(image)
prompt = PROMPT
# batch_inputs = []
# for image in tqdm(images):
# prompt_in = prompt
# cache_list = [
# {
# "prompt": prompt_in,
# "multi_modal_data": {"image": Image.open(image).convert('RGB')},
# }
# ]
# batch_inputs.extend(cache_list)
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(tqdm(
executor.map(process_single_image, images),
total=len(images),
desc="Pre-processed images"
))
outputs_list = llm.generate(
batch_inputs,
sampling_params=sampling_params
)
output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True)
for output, image in zip(outputs_list, images_path):
content = output.outputs[0].text
mmd_det_path = output_path + image.split('/')[-1].replace('.jpg', '_det.md')
with open(mmd_det_path, 'w', encoding='utf-8') as afile:
afile.write(content)
content = clean_formula(content)
matches_ref, mathes_other = re_match(content)
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
content = content.replace(a_match_other, '').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n').replace('<center>', '').replace('</center>', '')
mmd_path = output_path + image.split('/')[-1].replace('.jpg', '.md')
with open(mmd_path, 'w', encoding='utf-8') as afile:
afile.write(content)
import asyncio
import re
import os
import torch
if torch.version.cuda == '11.8':
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.registry import ModelRegistry
import time
from deepseek_ocr import DeepseekOCRForCausalLM
from PIL import Image, ImageDraw, ImageFont, ImageOps
import numpy as np
from tqdm import tqdm
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
def load_image(image_path):
try:
image = Image.open(image_path)
corrected_image = ImageOps.exif_transpose(image)
return corrected_image
except Exception as e:
print(f"error: {e}")
try:
return Image.open(image_path)
except:
return None
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
if '<|ref|>image<|/ref|>' in a_match[0]:
mathes_image.append(a_match[0])
else:
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
except Exception as e:
print(e)
return None
return (label_type, cor_list)
def draw_bounding_boxes(image, refs):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
# except IOError:
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
try:
result = extract_coordinates_and_label(ref, image_width, image_height)
if result:
label_type, points_list = result
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
color_a = color + (20, )
for points in points_list:
x1, y1, x2, y2 = points
x1 = int(x1 / 999 * image_width)
y1 = int(y1 / 999 * image_height)
x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height)
if label_type == 'image':
try:
cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
except Exception as e:
print(e)
pass
img_idx += 1
try:
if label_type == 'title':
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
text_x = x1
text_y = max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30))
draw.text((text_x, text_y), label_type, font=font, fill=color)
except:
pass
except:
continue
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts):
result_image = draw_bounding_boxes(image, ref_texts)
return result_image
async def stream_generate(image=None, prompt=''):
engine_args = AsyncEngineArgs(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=256,
max_model_len=8192,
enforce_eager=False,
trust_remote_code=True,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90, whitelist_token_ids= {128821, 128822})] #whitelist: <td>, </td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
# ignore_eos=False,
)
request_id = f"request-{int(time.time())}"
printed_length = 0
if image and '<image>' in prompt:
request = {
"prompt": prompt,
"multi_modal_data": {"image": image}
}
elif prompt:
request = {
"prompt": prompt
}
else:
assert False, f'prompt is none!!!'
async for request_output in engine.generate(
request, sampling_params, request_id
):
if request_output.outputs:
full_text = request_output.outputs[0].text
new_text = full_text[printed_length:]
print(new_text, end='', flush=True)
printed_length = len(full_text)
final_output = full_text
print('\n')
return final_output
if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
image = load_image(INPUT_PATH).convert('RGB')
if '<image>' in PROMPT:
image_features = DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)
else:
image_features = ''
prompt = PROMPT
result_out = asyncio.run(stream_generate(image_features, prompt))
save_results = 1
if save_results and '<image>' in prompt:
print('='*15 + 'save results:' + '='*15)
image_draw = image.copy()
outputs = result_out
with open(f'{OUTPUT_PATH}/result_ori.mmd', 'w', encoding = 'utf-8') as afile:
afile.write(outputs)
matches_ref, matches_images, mathes_other = re_match(outputs)
# print(matches_ref)
result = process_image_with_refs(image_draw, matches_ref)
for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
outputs = outputs.replace(a_match_image, f'![](images/' + str(idx) + '.jpg)\n')
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
# if 'structural formula' in conversation[0]['content']:
# outputs = '<smiles>' + outputs + '</smiles>'
with open(f'{OUTPUT_PATH}/result.mmd', 'w', encoding = 'utf-8') as afile:
afile.write(outputs)
if 'line_type' in outputs:
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
lines = eval(outputs)['Line']['line']
line_type = eval(outputs)['Line']['line_type']
# print(lines)
endpoints = eval(outputs)['Line']['line_endpoint']
fig, ax = plt.subplots(figsize=(3,3), dpi=200)
ax.set_xlim(-15, 15)
ax.set_ylim(-15, 15)
for idx, line in enumerate(lines):
try:
p0 = eval(line.split(' -- ')[0])
p1 = eval(line.split(' -- ')[-1])
if line_type[idx] == '--':
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
else:
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
ax.scatter(p0[0], p0[1], s=5, color = 'k')
ax.scatter(p1[0], p1[1], s=5, color = 'k')
except:
pass
for endpoint in endpoints:
label = endpoint.split(': ')[0]
(x, y) = eval(endpoint.split(': ')[1])
ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
fontsize=5, fontweight='light')
try:
if 'Circle' in eval(outputs).keys():
circle_centers = eval(outputs)['Circle']['circle_center']
radius = eval(outputs)['Circle']['radius']
for center, r in zip(circle_centers, radius):
center = eval(center.split(': ')[1])
circle = Circle(center, radius=r, fill=False, edgecolor='black', linewidth=0.8)
ax.add_patch(circle)
except:
pass
plt.savefig(f'{OUTPUT_PATH}/geo.jpg')
plt.close()
result.save(f'{OUTPUT_PATH}/result_with_boxes.jpg')
import os
import fitz
import img2pdf
import io
import re
from tqdm import tqdm
import torch
from concurrent.futures import ThreadPoolExecutor
if torch.version.cuda == '11.8':
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from deepseek_ocr import DeepseekOCRForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
llm = LLM(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=256,
enforce_eager=False,
trust_remote_code=True,
max_model_len=8192,
swap_space=0,
max_num_seqs=MAX_CONCURRENCY,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
disable_mm_preprocessor_cache=True
)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids= {128821, 128822})] #window for fast;whitelist_token_ids: <td>,</td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
include_stop_str_in_output=True,
)
class Colors:
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
RESET = '\033[0m'
def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
"""
pdf2images
"""
images = []
pdf_document = fitz.open(pdf_path)
zoom = dpi / 72.0
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
Image.MAX_IMAGE_PIXELS = None
if image_format.upper() == "PNG":
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
else:
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
if img.mode in ('RGBA', 'LA'):
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
img = background
images.append(img)
pdf_document.close()
return images
def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images:
return
image_bytes_list = []
for img in pil_images:
if img.mode != 'RGB':
img = img.convert('RGB')
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=95)
img_bytes = img_buffer.getvalue()
image_bytes_list.append(img_bytes)
try:
pdf_bytes = img2pdf.convert(image_bytes_list)
with open(output_path, "wb") as f:
f.write(pdf_bytes)
except Exception as e:
print(f"error: {e}")
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
if '<|ref|>image<|/ref|>' in a_match[0]:
mathes_image.append(a_match[0])
else:
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
except Exception as e:
print(e)
return None
return (label_type, cor_list)
def draw_bounding_boxes(image, refs, jdx):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
# except IOError:
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
try:
result = extract_coordinates_and_label(ref, image_width, image_height)
if result:
label_type, points_list = result
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
color_a = color + (20, )
for points in points_list:
x1, y1, x2, y2 = points
x1 = int(x1 / 999 * image_width)
y1 = int(y1 / 999 * image_height)
x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height)
if label_type == 'image':
try:
cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
except Exception as e:
print(e)
pass
img_idx += 1
try:
if label_type == 'title':
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
text_x = x1
text_y = max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30))
draw.text((text_x, text_y), label_type, font=font, fill=color)
except:
pass
except:
continue
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts, jdx):
result_image = draw_bounding_boxes(image, ref_texts, jdx)
return result_image
def process_single_image(image):
"""single image"""
prompt_in = prompt
cache_item = {
"prompt": prompt_in,
"multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
}
return cache_item
if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
print(f'{Colors.RED}PDF loading .....{Colors.RESET}')
images = pdf_to_images_high_quality(INPUT_PATH)
prompt = PROMPT
# batch_inputs = []
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(tqdm(
executor.map(process_single_image, images),
total=len(images),
desc="Pre-processed images"
))
# for image in tqdm(images):
# prompt_in = prompt
# cache_list = [
# {
# "prompt": prompt_in,
# "multi_modal_data": {"image": DeepseekOCRProcessor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
# }
# ]
# batch_inputs.extend(cache_list)
outputs_list = llm.generate(
batch_inputs,
sampling_params=sampling_params
)
output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True)
mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd')
pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
contents_det = ''
contents = ''
draw_images = []
jdx = 0
for output, img in zip(outputs_list, images):
content = output.outputs[0].text
if '<|end▁of▁sentence|>' in content: # repeat no eos
content = content.replace('<|end▁of▁sentence|>', '')
else:
if SKIP_REPEAT:
continue
page_num = f'\n<--- Page Split --->'
contents_det += content + f'\n{page_num}\n'
image_draw = img.copy()
matches_ref, matches_images, mathes_other = re_match(content)
# print(matches_ref)
result_image = process_image_with_refs(image_draw, matches_ref, jdx)
draw_images.append(result_image)
for idx, a_match_image in enumerate(matches_images):
content = content.replace(a_match_image, f'![](images/' + str(jdx) + '_' + str(idx) + '.jpg)\n')
for idx, a_match_other in enumerate(mathes_other):
content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
contents += content + f'\n{page_num}\n'
jdx += 1
with open(mmd_det_path, 'w', encoding='utf-8') as afile:
afile.write(contents_det)
with open(mmd_path, 'w', encoding='utf-8') as afile:
afile.write(contents)
pil_to_pdf_img2pdf(draw_images, pdf_out_path)
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# DeepSeek-OCR_pytorch # DeepSeek-OCR
## 论文
[DeepSeek_OCR](./DeepSeek_OCR_paper.pdf)
DeepSeek 推出了全新的视觉文本压缩模型 DeepSeek-OCR ## 模型结构
\ No newline at end of file DeepSeek-OCR 由一个深度编码器和一个 DeepSeek-3B-MoE 解码器组成。深度编码器是 DeepSeek-OCR 的核心,它由三个部分构成:一个以窗口注意力为主的感知组件 SAM、一个具有密集全局注意力的用于知识的 CLIP 组件,以及一个连接它们的 16 倍token压缩器。
<div align=center>
<img src="./doc/model.png"/>
</div>
## 算法原理
DeepSeek 推出了全新的视觉文本压缩模型 DeepSeek-OCR。DeepSeek-OCR 基于 DeepSeek-MoE-VLM 架构,采用了混合专家(MoE)设计,在保持模型小巧的同时实现了强大的功能。
DeepSeek-OCR 的能力范围包括:
- 复杂图表解析(折线图、柱状图等数据可视化)
- 文档格式保留(标题、段落、列表等结构信息)
- 多语言处理(中英文混合识别)
- 物体定位(grounding 功能支持)
<div align=center>
<img src="./doc/xxx.png"/>
</div>
## 环境配置
### 硬件需求
DCU型号:K100AI,节点数量:1台,卡数:1张。
`-v 路径``docker_name``imageID`根据实际情况修改
### Docker(方法一)
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas104381-0915-das1.6-py3.10-20250916-rc2-ds3.2
docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash
cd /your_code_path/deepseek-ocr_pytorch
```
### Dockerfile(方法二)
```bash
cd docker
docker build --no-cache -t deepseek-ocr:latest .
docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash
cd /your_code_path/deepseek-ocr_pytorch
```
### Anaconda(方法三)
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.sourcefind.cn/tool/)开发者社区下载安装。
```bash
DTK: 25.04.1
python: 3.10.12
torch: 2.5.1+das.opt1.dtk25041
transformers: 4.46.3
vllm: 0.9.2
```
`Tips:以上dtk驱动、pytorch等DCU相关工具版本需要严格一一对应`, 其它非深度学习库参照requirements.txt安装:
```bash
pip install -r requirements.txt
```
## 数据集
暂无
## 训练
暂无
## 推理
### transformers
模型地址,测试图片路径,输出路径根据实际情况修改。
```bash
cd DeepSeek-OCR-master/DeepSeek-OCR-hf
python run_dpsk_ocr.py --model_name_or_path=deepseek-ai/DeepSeek-OCR --image_path=./doc/test.jpg --output_path=./output
```
### vllm
```bash
cd DeepSeek-OCR-master/DeepSeek-OCR-vllm
# image: streaming output
python run_dpsk_ocr_image.py
# pdf
python run_dpsk_ocr_pdf.py
```
## result
<div align=center>
<img src="./doc/xxx.png"/>
</div>
### 精度
DCU与GPU精度一致,推理框架:vllm。
## 应用场景
### 算法类别
OCR
### 热点应用行业
`制造,金融,交通,教育,医疗`
## 预训练权重
- [DeepSeek-OCR](https://huggingface.co/deepseek-ai/DeepSeek-OCR)
## 源码仓库及问题反馈
- https://developer.sourcefind.cn/codes/modelzoo/deepseek-ocr_pytorch
## 参考资料
- https://github.com/deepseek-ai/DeepSeek-OCR
FROM image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas104381-0915-das1.6-py3.10-20250916-rc2-ds3.2
\ No newline at end of file
icon.png

61 KB

Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment