"vscode:/vscode.git/clone" did not exist on "faf97b39b32e3b03dc0d2de6a3be8946df98bd34"
Commit 6ad287f7 authored by liuxu3's avatar liuxu3
Browse files

added DeepSeek OCR API by liushengtong

parent 80c11a03
# DeepSeek OCR 服务配置
# 请根据你的实际情况修改以下配置
# 模型路径(必需)
MODEL_PATH=/home/lst/deepseek_ocr
# GPU 配置
GPU_ID=0
# 服务端口
PORT=8002
# CPU 工作线程数
CPU_WORKERS=2
# Conda 环境配置
# 方式1: 使用 conda 环境名称(需要 conda 命令)
#CONDA_ENV_NAME=dso
# 方式2: 直接使用 miniconda 环境路径(推荐用于 miniconda)
PYTHON_PATH=/usr/bin/python3
# PYTHON 版本(仅在使用 conda 环境名称时需要)
PYTHON_VERSION=3.10
# 服务配置
HOST=0.0.0.0
# DeepSeek OCR 服务配置示例
# 复制此文件为 .env 并修改以下配置
# ============================================================================
# 核心配置(必须修改)
# ============================================================================
# 模型路径 - 请修改为你的实际模型路径
MODEL_PATH=/path/to/your/DeepSeek-OCR
# ============================================================================
# GPU 和服务配置(可选)
# ============================================================================
# GPU ID(默认使用 GPU 3)
GPU_ID=3
# 服务端口(默认 8708)
PORT=8708
# 服务监听地址(默认 0.0.0.0,允许外部访问)
HOST=0.0.0.0
# CPU 工作线程数(默认 2,用于图像预处理)
CPU_WORKERS=2
# ============================================================================
# Python 环境配置(二选一)
# ============================================================================
# 方式1: 使用 conda 环境名称(需要 conda 命令)
# CONDA_ENV_NAME=deepseek-ocr-vllm
# 方式2: 直接使用 miniconda 环境路径(推荐用于 miniconda)
# PYTHON_PATH=/home/data/nongwa/miniconda3/envs/your_env/bin/python
# Python 版本(仅在使用 conda 环境名称时需要)
PYTHON_VERSION=3.10
\ No newline at end of file
# TODO: change modes
# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False
# Gundam: base_size = 1024, image_size = 640, crop_mode = True
BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
MIN_CROPS= 2
MAX_CROPS= 6 # max:9; If your GPU memory is small, it is recommended to set it to 6.
MAX_CONCURRENCY = 100 # If you have limited GPU memory, lower the concurrency count.
NUM_WORKERS = 64 # image pre-process (resize/padding) workers
PRINT_NUM_VIS_TOKENS = False
SKIP_REPEAT = True
MODEL_PATH = '/home/lst/deepseek_ocr' # change to your model path
# TODO: change INPUT_PATH
# .pdf: run_dpsk_ocr_pdf.py;
# .jpg, .png, .jpeg: run_dpsk_ocr_image.py;
# Omnidocbench images path: run_dpsk_ocr_eval_batch.py
INPUT_PATH = './use.pdf'
OUTPUT_PATH = './output'
PROMPT = '<image>\n<|grounding|>Convert the document to markdown.'
# PROMPT = '<image>\nFree OCR.'
# TODO commonly used prompts
# document: <image>\n<|grounding|>Convert the document to markdown.
# other image: <image>\n<|grounding|>OCR this image.
# without layouts: <image>\nFree OCR.
# figures in document: <image>\nParse the figure.
# general: <image>\nDescribe this image in detail.
# rec: <image>\nLocate <|ref|>xxxx<|/ref|> in the image.
# '先天下之忧而忧'
# .......
from transformers import AutoTokenizer
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
import torch.nn as nn
import torch
import torch.nn.functional as F
import copy
class MlpProjector(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
if cfg.projector_type == "identity":
modules = nn.Identity()
elif cfg.projector_type == "linear":
modules = nn.Linear(cfg.input_dim, cfg.n_embed)
elif cfg.projector_type == "mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
mlp_ratio = cfg.get("mlp_ratio", 1)
modules = [
nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "downsample_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
mlp_ratio = cfg.get("mlp_ratio", 1)
modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
channel_div = cfg.get("channel_div", 0.5)
self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "low_high_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
modules = nn.Sequential(*modules)
self.high_layers = nn.Sequential(*modules)
self.low_layers = copy.deepcopy(modules)
else:
raise ValueError(f"Unknown projector type: {cfg.projector_type}")
if cfg.get("token_pooling", False):
self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
if cfg.get("conv_fusion_high_low_features", False):
self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
self.layers = modules
def forward(self, x):
if self.cfg.get("token_pooling", False):
batch_size, wxh, channels = x.shape
w = h = int(wxh**0.5)
x = x.view(batch_size, w, h, channels)
x = x.permute(0, 3, 1, 2)
# import ipdb; ipdb.set_trace()
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
batch_size, channels, h_patches, w_patches, _, _ = patches.size()
# 在通道维度上拼接
patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
# 通过线性层
patches = patches.permute(0, 2, 1, 3).contiguous()
patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
x = self.token_pooling_layer(patches)
if self.cfg.get("conv_fusion_high_low_features", False):
x = self.fusion_layer(x[:, 0]) + x[:, 1]
if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
high_x, low_x = x[0], x[1]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
high_x = x[...,:self.cfg.input_dim[0]]
low_x = x[...,self.cfg.input_dim[0]:]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.cfg.projector_type == 'low_high_split_mlp_gelu':
high_x, low_x = x[0], x[1]
high_x = self.high_layers(high_x)
low_x = self.low_layers(low_x)
x = torch.concat([high_x, low_x], dim=-1)
return x
if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
bs, hw, input_dim = x.shape
h = w = int((hw) ** 0.5)
"""compute padding"""
if h % self.cfg.downsample_ratio:
pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
else:
pad = 0
x = x.reshape(bs, h, w, input_dim)
if pad > 0:
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
return self.layers(x)
@staticmethod
def get_flops_per_sample(cfg):
if cfg.projector_type == "linear":
fwd = 2 * cfg.input_dim * cfg.n_embed
elif "mlp_gelu" in cfg.projector_type :
mlp_depth = cfg.get("depth", 1)
downsample_ratio = cfg.get("downsample_ratio", 1)
input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
input_dim = input_dim * downsample_ratio * downsample_ratio
fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
else:
fwd = 0
return fwd * 3
from contextlib import nullcontext
import math
from typing import Optional, Tuple
# from megatron.model import LayerNorm
from easydict import EasyDict as adict
import torch
from torch.nn import functional as F
from torch import nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# from optimus import flash_attn_func
# from megatron.core import tensor_parallel
# from megatron.core import parallel_state as mpu
# from megatron.core.utils import make_viewless_tensor, divide
# from megatron.model.fused_rms_norm import RMSNorm
# from megatron.model.transformer import (
# FlashSelfAttention,
# NoopTransformerLayer,
# _cfg_to_kwargs,
# )
# from megatron.model.enums import AttnMaskType, AttnType
# from megatron.model.fused_softmax import FusedScaleMaskSoftmax
# from megatron.model.utils import attention_mask_func
# from megatron.model.module import MegatronModule
# try:
# from einops import rearrange
# except ImportError:
# rearrange = None
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# try:
# # flash attention 2.x
# from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
# except ImportError:
# try:
# # flash attention 1.x
# from flash_attn.flash_attn_interface import flash_attn_unpadded_func
# except ImportError:
# flash_attn_unpadded_func = None
# try:
# from flash_attn.flash_attn_interface import flash_attn_unpadded_relative_attention_bias_func
# except ImportError:
# flash_attn_unpadded_relative_attention_bias_func = None
# try:
# from flash_attn.flash_attn_interface import mask_flash_attn_unpadded_func
# except ImportError:
# mask_flash_attn_unpadded_func = None
class LayerNormfp32(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
# print(tgt_size)
# print(abs_pos.shape)
# exit()
dim = abs_pos.size(-1)
# print(dim)
abs_pos_new = abs_pos.squeeze(0)
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1,
2).contiguous()
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode='bicubic',
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
return vision_pos_embed
else:
return abs_pos
@torch.jit.script
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
super().__init__()
self.embed_dim = hidden_size
self.image_size = image_size
self.patch_size = patch_size
self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = torch.nn.Conv2d(
in_channels=num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids", torch.arange(self.num_positions).expand((1, -1))
)
def forward(self, pixel_values, patch_embeds):
batch_size = pixel_values.shape[0]
# patch_embeds = self.patch_embedding(
# pixel_values
# ) # shape = [*, width, grid, grid]
if patch_embeds is not None:
patch_embeds = patch_embeds
# print(patch_embeds.shape)
else:
patch_embeds = self.patch_embedding(pixel_values)
# print(111111)
# shape = [*, width, grid, grid]
# patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
# x = torch.cat([cls_token, x], dim=1)
embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))
# embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class NoTPFeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
):
super().__init__()
self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
def forward(self, x):
output = self.fc2(quick_gelu(self.fc1(x)))
return output
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
# class NoTPAttention(nn.Module):
# def __init__(self, cfg):
# super().__init__()
# self.num_heads = cfg.num_attention_heads
# self.n_local_heads = cfg.num_attention_heads
# self.head_dim = cfg.hidden_size // cfg.num_attention_heads
# self.max_seq_len = cfg.seq_length
# self.use_flash_attention = cfg.use_flash_attn
# self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
# self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
# # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
# self.attn_drop = cfg.attention_dropout
# def forward(
# self,
# x: torch.Tensor,
# ):
# bsz, seqlen, _ = x.shape
# xqkv = self.qkv_proj(x)
# xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
# if self.use_flash_attention:
# output = flash_attn_qkvpacked_func(xqkv)
# output = output.view(bsz, seqlen, -1)
# else:
# xq, xk, xv = torch.split(xqkv, 1, dim=2)
# xq = xq.squeeze(2)
# xk = xk.squeeze(2)
# xv = xv.squeeze(2)
# # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# # (B, num_head, S, head_size)
# xq = xq.permute(0, 2, 1, 3)
# xk = xk.permute(0, 2, 1, 3)
# xv = xv.permute(0, 2, 1, 3)
# output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
# utput = output.permute(0, 2, 1, 3).view(bsz, seqlen, -1)
# output = self.out_proj(output)
# return output
# from optimus.flash_attn_interface import flash_attn_qkvpacked_func
class NoTPAttention(torch.nn.Module):
def __init__(self, cfg):
super().__init__()
self.num_heads = cfg.num_attention_heads
self.n_local_heads = cfg.num_attention_heads
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.max_seq_len = cfg.seq_length
self.use_flash_attention = cfg.use_flash_attn
self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)
# self.core_attention = CoreAttention(cfg, AttnType.self_attn)
self.attn_drop = cfg.attention_dropout
def forward(
self,
x: torch.Tensor,
):
bsz, seqlen, _ = x.shape
xqkv = self.qkv_proj(x)
xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
if self.use_flash_attention:
output = flash_attn_qkvpacked_func(xqkv)
output = output.view(bsz, seqlen, -1)
# xq, xk, xv = torch.split(xqkv, 1, dim=2)
# xq = xq.squeeze(2)
# xk = xk.squeeze(2)
# xv = xv.squeeze(2)
# # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# # (B, num_head, S, head_size)
# xq = xq.permute(0, 2, 1, 3)
# xk = xk.permute(0, 2, 1, 3)
# xv = xv.permute(0, 2, 1, 3)
# # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
# output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
# output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
# output = output.permute(0, 2, 1, 3).contiguous().view(bsz, seqlen, -1)
else:
# output = flash_attn_qkvpacked_func(xqkv)
xq, xk, xv = torch.split(xqkv, 1, dim=2)
xq = xq.squeeze(2)
xk = xk.squeeze(2)
xv = xv.squeeze(2)
# xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# (B, num_head, S, head_size)
xq = xq.permute(0, 2, 1, 3)
xk = xk.permute(0, 2, 1, 3)
xv = xv.permute(0, 2, 1, 3)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
output = self.out_proj(output)
return output
class NoTPTransformerBlock(nn.Module):
def __init__(self, cfg, layer_id: int, multiple_of=256):
super().__init__()
self.n_heads = cfg.num_attention_heads
self.dim = cfg.hidden_size
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.self_attn = NoTPAttention(cfg)
self.mlp = NoTPFeedForward(
cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
)
self.layer_id = layer_id
self.layer_norm1 = torch.nn.LayerNorm(
cfg.hidden_size, eps=cfg.layernorm_epsilon
)
self.layer_norm2 = torch.nn.LayerNorm(
cfg.hidden_size, eps=cfg.layernorm_epsilon
)
def forward(self, x: torch.Tensor):
residual = self.self_attn.forward(self.layer_norm1(x))
h = x + residual
out = h + self.mlp.forward(self.layer_norm2(h))
return out
class NoTPTransformer(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
# self.recompute_list = self.cfg.get("recompute_list", [])
self.num_layers = cfg.num_layers # _get_num_layers(cfg)
self.layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
self.layers.append(
NoTPTransformerBlock(
cfg,
layer_id + 1,
)
)
def forward(
self,
hidden_states,
):
for lid, layer in enumerate(self.layers):
# if lid in self.recompute_list:
# def custom(layer_id):
# def custom_forward(*args, **kwargs):
# x_ = self.layers[layer_id](*args, **kwargs)
# return x_
# return custom_forward
# assert hidden_states.requires_grad == True, logger.warning(
# "When using recalculation, the input must have grad fn"
# )
# hidden_states = tensor_parallel.checkpoint(
# custom(lid),
# False,
# hidden_states.contiguous()
# )
# else:
hidden_states = layer(hidden_states)
return hidden_states
# from megatron.core.tensor_parallel.layers import non_tensor_paralleled, local_dp_reduce, local_dp_scatter
class VitModel(nn.Module):
def __init__(
self,
cfg,
freeze_embed=False,
freeze_pre_norm=False
) -> None:
super().__init__()
self.embeddings = CLIPVisionEmbeddings(hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size)
if freeze_embed:
for name, param in self.embeddings.named_parameters():
param.requires_grad = False
self.transformer = NoTPTransformer(cfg=cfg)
if cfg.get("fp32norm", False):
logger.info("Load fp32 layernorm for ViT.")
self.pre_layrnorm = LayerNormfp32(
cfg.hidden_size,
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
else:
self.pre_layrnorm = torch.nn.LayerNorm(
cfg.hidden_size,
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
# self.pre_layrnorm = RMSNorm(
# cfg.hidden_size,
# eps=cfg.get("pre_layernorm_epsilon", 1e-5),
# sequence_parallel=False,
# use_fp32=True,
# use_optimus=True,
# )
if freeze_pre_norm:
for name, param in self.pre_layrnorm.named_parameters():
param.requires_grad = False
for p in self.parameters():
p.micro_dp = True
def set_input_tensor(self, input_tensor):
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.transformer.set_input_tensor(input_tensor[0])
def __str__(self) -> str:
return "open_clip"
def forward(
self,
x,
patch_embeds
):
x = self.embeddings(x, patch_embeds)
hidden_states = self.pre_layrnorm(x)
# hidden_states, dis = local_dp_scatter(hidden_states)
output = self.transformer(hidden_states)
# output = local_dp_reduce(output, dis)
return output
vit_model_cfg = adict(
num_layers=24,
hidden_size=1024,
num_heads = 16,
num_attention_heads=16,
ffn_hidden_size=4096,
seq_length=256,
max_position_embeddings=256,
use_flash_attn=False,
understand_projector_stride=2,
hidden_dropout = 0.0,
attention_dropout = 0.0,
no_persist_layer_norm = False,
layernorm_epsilon = 1e-5,
pre_layernorm_epsilon = 1e-5,
image_size = 224,
patch_size = 14,
recompute_list = []
)
def build_clip_l():
return VitModel(
cfg=vit_model_cfg,
freeze_embed=False,
freeze_pre_norm=False,
)
if __name__ == '__main__':
from mmgpt.model.vision_encoder.sam_b import build_sam_vit_b
vit_model_cfg = adict(
num_layers=24,
hidden_size=1024,
num_attention_heads=16,
ffn_hidden_size=4096,
seq_length=256,
max_position_embeddings=256,
use_flash_attn=False,
understand_projector_stride=2,
hidden_dropout = 0.0,
attention_dropout = 0.0,
no_persist_layer_norm = False,
layernorm_epsilon = 1e-5,
pre_layernorm_epsilon = 1e-5,
image_size = 224,
patch_size = 14,
recompute_list = []
)
sam_model = build_sam_vit_b()
vision_model = VitModel(
cfg=vit_model_cfg,
freeze_embed=False,
freeze_pre_norm=False,
)
# model = VitModel(1344)
# x = torch.zeros(2, 3, 224, 224)
x = torch.zeros(2, 3, 1024, 1024)
with torch.no_grad():
# y = vision_model(x)
patch_embed = sam_model(x)
print(patch_embed.shape)
y = vision_model(x, patch_embed)
print(y.shape)
image_feature = torch.add(y[:, 1:], patch_embed.flatten(2).permute(0, 2, 1))
print(image_feature.shape)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from functools import partial
from flash_attn import flash_attn_qkvpacked_func
# from .common import LayerNorm2d, MLPBlock
# from mmgpt.model.vision_encoder.flash_4 import _attention_rel_h_rel_w
def get_abs_pos(abs_pos, tgt_size):
dtype = abs_pos.dtype
src_size = abs_pos.size(1)
if src_size != tgt_size:
old_pos_embed = abs_pos.permute(0, 3, 1, 2)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode='bicubic',
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
return new_pos_embed
else:
return abs_pos
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
# x = x + self.pos_embed
x = x + get_abs_pos(self.pos_embed, x.size(1))
for blk in self.blocks:
x = blk(x)
neck_output = self.neck(x.permute(0, 3, 1, 2))
conv2_output = self.net_2(neck_output)
# print(f"conv2_output shape: {conv2_output.shape}")
conv3_output = self.net_3(conv2_output)
return conv3_output
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
rel_h, rel_w = None, None
if self.use_rel_pos:
rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
q = q.view(B, self.num_heads, H * W, -1)
k = k.view(B, self.num_heads, H * W, -1)
v = v.view(B, self.num_heads, H * W, -1)
if self.use_rel_pos:
rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
# x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
# qkv = torch.stack([q, k, v], dim=1).transpose(1, 3).reshape(B, H * W, 3, self.num_heads, -1)
# x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False).transpose(1, 2)
x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
dtype = rel_pos.dtype
rel_pos = rel_pos.to(torch.float32)
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
).to(dtype)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
return rel_h, rel_w
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
if checkpoint is not None:
# with open(checkpoint, "rb") as f:
state_dict = torch.load(checkpoint)
# print(state_dict.keys())
# for key in state_dict:
# image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
# ocr-anyting
# image_encoder.load_state_dict(state_dict, strict=True)
# tob
image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
print(checkpoint)
return image_encoder
\ No newline at end of file
This diff is collapsed.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DeepSeek OCR API Server (vLLM) - 极简版 + 优化版
"""
import os
import io
import re
import argparse
import asyncio
from io import BytesIO
from typing import List
from concurrent.futures import ThreadPoolExecutor
import torch
from PIL import Image
try:
import fitz
except Exception:
fitz = None
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from deepseek_ocr import DeepseekOCRForCausalLM
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
app = FastAPI(title="DeepSeek OCR API (vLLM) - Optimized", version="2.0.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
llm = None
cpu_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="CPU-Worker")
gpu_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="GPU-Worker")
vllm_lock = asyncio.Lock()
PROMPT_OCR = "<image>\n<|grounding|>Convert the document to markdown."
PROMPT_DESC = "<image>\nDescribe this image in detail."
# -----------------------
# Monkey Patch
# -----------------------
_original_tokenize = DeepseekOCRProcessor.tokenize_with_images
def _patched_tokenize(self, images, bos=True, eos=True, cropping=True, prompt=None):
if prompt is not None:
import config
old = config.PROMPT
config.PROMPT = prompt
try:
return _original_tokenize(self, images, bos, eos, cropping)
finally:
config.PROMPT = old
return _original_tokenize(self, images, bos, eos, cropping)
DeepseekOCRProcessor.tokenize_with_images = _patched_tokenize
def pdf_to_images_sync(pdf_bytes: bytes, dpi: int = 144) -> List[Image.Image]:
"""PDF 转图片 """
if fitz is None:
raise RuntimeError("Please install PyMuPDF")
images = []
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
matrix = fitz.Matrix(dpi / 72.0, dpi / 72.0)
for page in doc:
pix = page.get_pixmap(matrix=matrix, alpha=False)
img = Image.open(io.BytesIO(pix.tobytes("png")))
if img.mode != "RGB":
if img.mode in ("RGBA", "LA"):
bg = Image.new("RGB", img.size, (255, 255, 255))
bg.paste(img, mask=img.split()[-1])
img = bg
else:
img = img.convert("RGB")
images.append(img)
doc.close()
return images
def image_open_sync(image_bytes: bytes) -> Image.Image:
"""打开图片 (同步版本)"""
return Image.open(BytesIO(image_bytes)).convert("RGB")
def clear_vllm_cache_sync():
"""清理 vLLM 缓存 (同步版本)"""
if llm is None:
return
try:
if hasattr(llm.llm_engine, 'input_preprocessor'):
prep = llm.llm_engine.input_preprocessor
if hasattr(prep, '_mm_processor_cache'):
prep._mm_processor_cache.clear()
except:
pass
def tokenize_image_sync(image: Image.Image, prompt: str):
"""
图像 tokenize (同步版本, CPU 密集)
WARNING: 这是最大的优化点!
"""
processor = DeepseekOCRProcessor()
return processor.tokenize_with_images(images=[image], prompt=prompt)
def vllm_generate_sync(tokenized, prompt: str) -> str:
"""
vLLM 推理 (同步版本, GPU 密集)
注意: tokenized 已经在 CPU 线程池完成
"""
batch_inputs = [{
"prompt": prompt,
"multi_modal_data": {"image": tokenized}
}]
if prompt == PROMPT_OCR:
logits_proc = [NoRepeatNGramLogitsProcessor(20, 50, {128821, 128822})]
params = SamplingParams(
temperature=0.0,
max_tokens=8192,
skip_special_tokens=False,
logits_processors=logits_proc,
repetition_penalty=1.05,
)
else:
params = SamplingParams(
temperature=0.0,
max_tokens=8192,
skip_special_tokens=False,
)
outputs = llm.generate(batch_inputs, params)
return outputs[0].outputs[0].text
def clean_markdown_sync(text: str) -> str:
"""清理 Markdown (同步版本)"""
text = re.sub(r'<\|ref\|>.*?<\|/ref\|>', '', text)
text = re.sub(r'<\|det\|>.*?<\|/det\|>', '', text)
text = re.sub(r'<\|.*?\|>', '', text)
text = re.sub(r'\[\[.*?\]\]', '', text)
text = re.sub(r'={50,}.*?={50,}', '', text, flags=re.DOTALL)
text = re.sub(r'\n{3,}', '\n\n', text)
return text.strip()
async def pdf_to_images_async(pdf_bytes: bytes, dpi: int = 144) -> List[Image.Image]:
"""PDF 转图片 (异步)"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(cpu_executor, pdf_to_images_sync, pdf_bytes, dpi)
async def image_open_async(image_bytes: bytes) -> Image.Image:
"""打开图片 (异步)"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(cpu_executor, image_open_sync, image_bytes)
async def tokenize_image_async(image: Image.Image, prompt: str):
"""
图像 tokenize (异步)
NOTE: 关键优化: 在 CPU 线程池执行
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(cpu_executor, tokenize_image_sync, image, prompt)
async def vllm_generate_async(image: Image.Image, prompt: str) -> str:
"""
完整的 vLLM 推理流程 (异步)
优化: 分离 tokenize (CPU) 和 generate (GPU)
"""
# 步骤1: tokenize (CPU 密集, 在 CPU 线程池执行)
tokenized = await tokenize_image_async(image, prompt)
# 步骤2: GPU 推理 (GPU 密集, 在 GPU 线程池执行, 有锁保护)
async with vllm_lock:
# 清理缓存 (在 GPU 线程池执行)
loop = asyncio.get_event_loop()
await loop.run_in_executor(gpu_executor, clear_vllm_cache_sync)
# GPU 推理
result = await loop.run_in_executor(
gpu_executor,
vllm_generate_sync,
tokenized,
prompt
)
return result
async def clean_markdown_async(text: str) -> str:
"""清理 Markdown (异步)"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(cpu_executor, clean_markdown_sync, text)
async def generate_image_description_async(image: Image.Image) -> str:
"""生成图片描述 (异步)"""
try:
# GPU 推理
result = await vllm_generate_async(image, PROMPT_DESC)
# CPU 后处理
loop = asyncio.get_event_loop()
def process_desc(text):
desc = re.sub(r'<\|ref\|>.*?<\|/ref\|>', '', text)
desc = re.sub(r'<\|det\|>.*?<\|/det\|>', '', desc)
desc = re.sub(r'<\|.*?\|>', '', desc)
desc = re.sub(r'\[\[.*?\]\]', '', desc)
desc = re.sub(r'\s+', ' ', desc).strip()
if len(desc) > 200:
cutoff = desc[:200].rfind('.')
if cutoff > 100:
desc = desc[:cutoff + 1]
else:
desc = desc[:200].rsplit(' ', 1)[0] + '...'
return desc
desc = await loop.run_in_executor(cpu_executor, process_desc, result)
return desc
except Exception as e:
print(f"WARNING: 图片描述失败: {e}")
return ""
# -----------------------
# 模型初始化
# -----------------------
def initialize_model(model_path: str, gpu_id: int):
global llm
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
os.environ['VLLM_USE_V1'] = '0'
print(f"[INFO] 加载模型: {model_path}")
llm = LLM(
model=model_path,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
block_size=64,
enforce_eager=False,
trust_remote_code=True,
max_model_len=8192,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
max_num_seqs=20,
disable_mm_preprocessor_cache=True,
)
print("[SUCCESS] 模型加载完成")
#print(f"[INFO] 线程池配置:")
#print(f" - CPU 线程池: {cpu_executor._max_workers} 线程")
#print(f" - GPU 线程池: {gpu_executor._max_workers} 线程")
# -----------------------
# API 路由
# -----------------------
@app.get("/")
async def root():
return {
"service": "DeepSeek OCR (vLLM) - Optimized",
"version": "2.0.0",
"status": "running"
}
@app.get("/health")
async def health():
return {
"status": "healthy",
"model_ready": llm is not None,
"cpu_workers": cpu_executor._max_workers,
"gpu_workers": gpu_executor._max_workers,
}
async def vllm_generate_batch_async(
images: List[Image.Image],
prompt: str,
show_progress: bool = True
) -> List[str]:
"""
批量 vLLM 推理 - 真正的批处理优化
Args:
images: 图片列表
prompt: 提示词
show_progress: 是否显示进度
Returns:
生成的文本列表
"""
total = len(images)
# 步骤1: 并发 tokenize
# 标准化图片 -> Vision Encoder (ViT) -> 图像特征向量 (例如:[196, 1024] - 196个位置,每个1024维)
if show_progress:
print(f" [1/3] Tokenize {total} 页...")
tokenize_tasks = [tokenize_image_async(img, prompt) for img in images]
all_tokenized = await asyncio.gather(*tokenize_tasks)
if show_progress:
print(f" [1/3] Tokenize 完成")
# 步骤2: 构造批量输入
batch_inputs = [
{
"prompt": prompt,
"multi_modal_data": {"image": tok}
}
for tok in all_tokenized
]
# 步骤3: 批量 GPU 推理
async with vllm_lock:
if show_progress:
print(f" [2/3] GPU 批量推理 {total} 页...")
loop = asyncio.get_event_loop()
# 清理缓存
await loop.run_in_executor(gpu_executor, clear_vllm_cache_sync)
# 批量推理
def batch_generate():
# 根据 prompt 类型选择参数
if prompt == PROMPT_OCR:
logits_proc = [NoRepeatNGramLogitsProcessor(20, 50, {128821, 128822})]
params = SamplingParams(
temperature=0.0,
max_tokens=8192,
skip_special_tokens=False,
logits_processors=logits_proc,
repetition_penalty=1.05,
)
else:
params = SamplingParams(
temperature=0.0,
max_tokens=8192,
skip_special_tokens=False,
)
# NOTE: 关键: 批量调用
outputs = llm.generate(batch_inputs, params)
return [out.outputs[0].text for out in outputs]
results = await loop.run_in_executor(gpu_executor, batch_generate)
if show_progress:
print(f" [2/3] GPU 推理完成")
return results
@app.post("/ocr")
async def ocr(
file: UploadFile = File(...),
enable_description: bool = Form(False),
):
"""OCmR 接口 (批量处理)"""
if llm is None:
raise HTTPException(503, "模型未加载")
import time
start_time = time.time()
try:
# 1. 读取文件
contents = await file.read()
t1 = time.time()
# 2. 解析文件
if file.filename.lower().endswith('.pdf'):
# 如果是PDF,则转换为图片列表
images = await pdf_to_images_async(contents)
else:
# 如果是图片,则直接打开
images = [await image_open_async(contents)]
t2 = time.time()
# 3. 批量 OCR
raw_results = await vllm_generate_batch_async(images, PROMPT_OCR)
t3 = time.time()
print(f" OCR 耗时: {t3 - t2:.2f}s")
# 4. 后处理
print(f" [3/3] 后处理...")
async def postprocess(idx: int, raw: str, img: Image.Image) -> str:
# 图片描述
if enable_description:
img_pattern = r'<\|ref\|>image<\|/ref\|><\|det\|>\[\[.*?\]\]<\|/det\|>'
matches = list(re.finditer(img_pattern, raw))
for match in matches:
desc = await generate_image_description_async(img)
replacement = f"[图片: {desc}]" if desc else "[图片]"
raw = raw.replace(match.group(0), replacement)
# 清理 Markdown
cleaned = await clean_markdown_async(raw)
return cleaned if cleaned else ""
tasks = [postprocess(i, raw, img) for i, (raw, img) in enumerate(zip(raw_results, images))]
md_parts = await asyncio.gather(*tasks)
t4 = time.time()
print(f" [3/3] 后处理完成 ({t4 - t3:.2f}s)")
# 5. 合并结果
final_md = "\n\n".join([md for md in md_parts if md])
total_time = time.time() - start_time
print(f"{'='*60}")
print(f"[SUCCESS] 全部完成")
print(f" 总耗时: {total_time:.2f}s")
print(f" 平均: {total_time / len(images):.2f}s/页")
print(f"{'='*60}\n")
return JSONResponse({
"markdown": final_md,
"page_count": len(images),
"processing_time": round(total_time, 2),
})
except Exception as e:
import traceback
print(f"[ERROR] 处理失败: {e}")
print(traceback.format_exc())
raise HTTPException(500, f"处理失败: {e}")
# -----------------------
# 优雅关闭
# -----------------------
@app.on_event("shutdown")
async def shutdown_event():
print("[INFO] 关闭线程池...")
cpu_executor.shutdown(wait=True)
gpu_executor.shutdown(wait=True)
print("[SUCCESS] 线程池已关闭")
# -----------------------
# 启动
# -----------------------
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", required=True, help="模型路径")
parser.add_argument("--gpu-id", type=int, default=0, help="GPU ID")
parser.add_argument("--port", type=int, default=8002, help="端口")
parser.add_argument("--host", default="0.0.0.0", help="监听地址")
parser.add_argument("--cpu-workers", type=int, default=2, help="CPU 线程池大小")
args = parser.parse_args()
# 更新线程池大小
global cpu_executor
cpu_executor = ThreadPoolExecutor(
max_workers=args.cpu_workers,
thread_name_prefix="CPU-Worker"
)
initialize_model(args.model_path, args.gpu_id)
print(f"\n[INFO] 服务启动: http://{args.host}:{args.port}")
print(f"[INFO] 接口文档: http://{args.host}:{args.port}/docs\n")
uvicorn.run(app, host=args.host, port=args.port, workers=1)
if __name__ == "__main__":
main()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment