Commit 6ad287f7 authored by liuxu3's avatar liuxu3
Browse files

added DeepSeek OCR API by liushengtong

parent 80c11a03
import requests
from pathlib import Path
from config import INPUT_PATH,OUTPUT_PATH
import os
def ocr_pdf(pdf_path, server_url="http://localhost:8001", save_result=True):
"""
对 PDF 文档进行 OCR 识别
参数:
pdf_path: PDF 文件路径
server_url: OCR 服务地址
save_result: 是否保存识别结果到文件
返回:
dict: 包含识别结果的字典
"""
# 1. 检查文件
pdf_file = Path(pdf_path)
if not pdf_file.exists():
raise FileNotFoundError(f"PDF文件不存在: {pdf_path}")
if pdf_file.suffix.lower() != '.pdf':
raise ValueError(f"不是 PDF 文件: {pdf_path}")
# 2. 读取文件大小(用于显示)
file_size_mb = pdf_file.stat().st_size / (1024 * 1024)
print(f"文件名: {pdf_file.name}")
print(f"文件大小: {file_size_mb:.2f} MB")
print(f"开始处理...")
# 3. 准备请求
api_url = f"{server_url}/ocr"
# 4. 发送请求
with open(pdf_path, 'rb') as f:
files = {'file': (pdf_file.name, f, 'application/pdf')}
# 这里可以添加额外参数
# data = {'enable_description': True} # 启用图片描述(会增加处理时间)
response = requests.post(api_url, files=files)
# 5. 处理结果
if response.status_code == 200:
result = response.json()
print(f"处理完成!")
print(f"统计信息:")
print(f" - 总页数: {result['page_count']} 页")
print(f" - 处理耗时: {result['processing_time']:.2f} 秒")
print(f" - 平均速度: {result['processing_time'] / result['page_count']:.2f} 秒/页")
# 6. 保存结果到文件
if save_result:
os.makedirs(OUTPUT_PATH, exist_ok=True)
output_file = pdf_file.with_suffix('.md')
file_path = os.path.join(OUTPUT_PATH, output_file)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(result['markdown'])
print(f"结果已保存到: {output_file}")
return result
else:
print(f"处理失败: {response.status_code}")
print(f"错误信息: {response.text}")
return None
# 使用示例
# 替换为你的实际 PDF 路径
#pdf_file = "./doc/DeepSeek_OCR_paper_layouts.pdf"
pdf_file = INPUT_PATH
# 调用 OCR 函数
result = ocr_pdf(pdf_file)
if result:
# 显示部分识别结果
print("\n" + "="*60)
print("识别结果预览:")
print("="*60)
# 显示前 1000 个字符
preview_text = result['markdown'][:1000]
print(preview_text)
if len(result['markdown']) > 1000:
print("\n... (内容过长,已截断)")
print(f"\n完整内容共 {len(result['markdown'])} 个字符")
## 第五章:高阶可视化,满足你一切想象力
## 一、帕累托图
帕累托图是按照一定的类别根据数据计算出其分类所占的比例,用从高到低的顺序排列成矩形,同时展示比例累积和的图形,主要用于分析导致结果的主要因素。帕累托图与帕累托法则(又称为“二八原理”,即80%的结果是20%的原因造成的)一脉相承,通过图形体现两点重要信息:“至关重要的极少数”和“微不足道的大多数”。
接下来我们就通过帕累托图来挖掘《示例-超市》的订单数据,带大家迅速发现隐藏在数据中的重要信息。通过“二八原理”,分析贡献80%收入的产品比例,来评估该企业的销售的健康程度。
## 1.1 帕累托图应用场景
- Step1:将“销售额”拖到行功能区,产品名称拖到“列功能区”,点击下拉菜单,选择排序,依据“销售额”进行降序排序<endofsentence>
## • Step2:创建计算字段“销售额累计百分比”
## • Step3:将“销售额累计百分比”拖到行功能区,点击下拉菜单,计算依据“产品名称”,形状选择线
## • Step4:点击“销售额累计百分比”下拉菜单,选择“双轴”<|end▁of▁sentence|>
## • Step5:创建计算字段“产品名称数量百分比”
- Step6:将“产品名称数量百分比”拖到列功能区,计算依据选择“产品名称”,然后将“产品名称”拖进“标记卡-详细信息”,并设置横坐标轴格式为百分比格式
<endofsentence>
## • Step7:给纵坐标“销售额累计百分比”添加辅助参考线
## • Step8:创建计算字段“产品名称数量百分比”
## - Step9:将创建的字段“产品名称数量百分比”拖进“标记卡 - 详细信息”中,然后给横坐标“产品名称数量百分比”添加辅助参考线
<endofsentence>
## 二、盒须图
盒须图又叫箱线图,是一种常用的统计图形,用以显示数据的位置、分散程度、异常值等。箱线图主要包括6个统计量:下限、第一四分位数、中位数、第三四分位数、上限和异常值。通过绘制盒须图,观测数据在同类群体中的位置,可以知道哪些表现好,哪些表现差;比较四分位全距及线段的长短,可以看出哪些群体分散,哪些群体更集中。
第一四分位数:数据按照大小顺序排列,处于总观测数25%位置的数据
中位数:数据按照大小顺序排列,处于中间位置,即总观测数50%的数据。
第三四分位数:数据按照大小顺序排列,处于总观测数75%位置的数据为第三分位数
• 下限:第一四分位数 - 1.5 * IQR
上限:第三四分位数 + 1.5 * IQR
异常值:在上限和下限之外的数据<endofsentence>
• IQR:四分位全距,即第三四分位数与第一四分位数之差
## 2.1 盒须图应用场景
如果我们想要对《示例-超市》的订单销售额数据进行深入分析,就可以通过构建盒须图分别对比销售额的分位值、上下限值在2015-2018年的变化趋势,从而能够很直观的发现销售额的变化规律。
• Step1:将“订单日期”拖进行列功能区,“销售额”拖进行功能区
• Step2:将“订单ID”拖进“标记卡 - 详细信息”进行解聚,然后在智能推荐区域选择盒须图<endofsentence>
## 三、甘特图
甘特图,又称横道图,是以图示的方式通过活动列表和时间刻度形象地表示出任何特定项目的活动顺序和持续时间。甘特图的横轴表示时间,纵轴表示活动(项目),线条表示在整个期间上该活动或项目的持续时间,因此可以用来比较与日期相关的不同活动(项目)的持续时间长短。甘特图也常用于显示不同任务之间的依赖关系,并被普遍用于项目管理中。
## 3.1 甘特图应用场景
• Step1:创建计算字段“下单到发货间隔天数”<endofsentence>
下单到发货间隔天数
DATEDIFF('day', [订单日期], [发货日期])
计算有效。
全部
输入搜索文本
COS
COT
COUNT
COUNTD
COVAR
COVARP
DATE
DATEADD
DATEDIFF
DATENAME
DATEDIFF(date_part, start_date, end_date, [start_of_week])
返回两个日期之差,即
end_date 减
start_date。日期差的表示单位为 date_part。如果省略 start_of_week,则周起始日由为数据源配置的起始日确定。
示例:
- Step2:将“类别”和“邮寄方式”拖到行功能区,“订单日期”和“下单到发货时间间隔”拖到列功能区,度量方式选择平均值
- Step3:将“订单日期”拖到筛选器中,筛选2018年第四季度,然后在智能推荐区选择甘特图
<endofsentence>
- Step4:把订单日期进行下钻,下钻到周的粒度,然后将“邮寄方式”拖回行功能区
## 四、瀑布图
瀑布图是数据可视化分析中常见的一种图形,采用绝对值与相对值结合的方式,适用于表达数个特定数值之间的数量变化关系。对于一系列具有累计性质的正值/负值具有很好的展示功能,既可以辅助理解数据的大小,又能直观地展示出数据的增减变化,反映数据在不同时期或受不同因素的影响结果。
## 4.1 瀑布图应用场景<|end▁of▁sentence|>
在《示例-超市》订单数据中,我们想要分析不同品类销售额对总销售额的贡献以及影响大小,就可以构建不同品类产品销售额的瀑布图,在该图中,从左到右代表各品类对销售额的贡献依次减少,最右边总和代表所有品类销售额的总和
- Step1:将“子类别”拖进列功能区,“销售额”拖进行功能区,对子类别按照销售额进行降序排序
• Step2: 在“销售额”下拉菜单快速表计算中选择“汇总”
• Step3:在标记卡功能区选择“甘特条形图”<endofsentence>
## • Step4:创建计算字段“销售额负值”
## • Step5:将“销售额负值”拖进标记卡“大小”里,在菜单栏“分析-合计”中选择“显示行总计”
## • Step6:将“销售额”拖进标记卡“颜色”里,然后对颜色进行编辑<|end▁of▁sentence|>
## 五、雷达图
雷达图是专门用来进行多指标体系比较分析的专业图表,主要应用于企业经营状况的展示——收益性、生产性、流动性、安全性和成长性的评价。其主要特点是简洁、方便、精确、直观,可以将多维数据投影到同一平面上,实现多维数据的可视化。
例如,我们想要对金庸武侠小说中东邪、西毒、南帝、北丐四个人物进行综合评价,我们就可以选取几个有代表性的指标(武力值、智力值、权力值、魅力值、颜值),通过构建雷达图进行展示。<endofsentence>
## 5.1 雷达图应用场景
- Step1:导入数据源,同时选中后六列,点击下拉菜单,选择“转置”,然后将后两列分别命名为“变量”、“数值”
<endofsentence>
## Step2:转到工作表中,创建计算字段“路径”
## • Step3:创建计算字段“弧度”
## • Step4:创建横坐标轴的值“x”
## • Step5:创建横坐标轴的值“y”
<endofsentence>
- Step6:把“x”拖进列,“y”拖进行,在标记卡中选择“线”
- Step7:把“路径”拖进标记卡中的路径,然后点击下拉菜单,选择“维度”
- Step8:把“姓名”拖进标记卡中的颜色,然后点击右上角编辑颜色,把Ring1 - Ring5置灰,点击确定<endofsentence>
## 六、动态图
动态图表,顾名思义,就是根据不同的选项设置而动态变化。让读者能够从不同维度动态交互查看复杂数据的信息。
## 6.1 动态图的应用场景
世界各国GDP数值和排名每年都在变化,有些国家掉下去了,有些国家艰难地挤上来。如果我们把世界各国最近几十年的GDP排名变化做成一张动态图表,大家就能非常直观的发现中国GDP的发展,就像万米长跑最后发力冲刺,非常震撼。
• Step1:导入源数据后,将“Rank”转换为维度字段<endofsentence>
- Step2:将“Rank”拖进行功能区,“GDP”拖进列功能区,“Year”拖进“页面卡”,“Place”拖进“标记卡-标签”
## • Step3:创建一个维度计算字段“中国颜色”
• Step4:将“中国颜色”拖进“标记卡 - 颜色”,并对颜色进行编辑<endofsentence>
185个评论 206行÷1列 时间(Gigbp) 854079579879
## • Step5:在右侧“页面卡”筛选播放速度,点击进行自动播放
## 本章小结<|end▁of▁sentence|>
<endofsentence>
\ No newline at end of file
import math
from typing import List, Tuple
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps
from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, MIN_CROPS, MAX_CROPS, PROMPT, TOKENIZER
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
def count_tiles(orig_width, orig_height, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=768, use_thumbnail=False):
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
# print(target_ratios)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
return target_aspect_ratio
def dynamic_preprocess(image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=768, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
# print(target_ratios)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# print(target_aspect_ratio)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
class ImageTransform:
def __init__(self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True):
self.mean = mean
self.std = std
self.normalize = normalize
transform_pipelines = [T.ToTensor()]
if normalize:
transform_pipelines.append(T.Normalize(mean, std))
self.transform = T.Compose(transform_pipelines)
def __call__(self, pil_img: Image.Image):
x = self.transform(pil_img)
return x
class DeepseekOCR2Processor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
def __init__(
self,
tokenizer: LlamaTokenizerFast = TOKENIZER,
candidate_resolutions: Tuple[Tuple[int, int]] = [[1024, 1024]],
patch_size: int = 16,
downsample_ratio: int = 4,
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<|▁pad▁|>",
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
# self.candidate_resolutions = candidate_resolutions # placeholder no use
self.image_size = IMAGE_SIZE
self.base_size = BASE_SIZE
# self.patch_size = patch_size
self.patch_size = 16
self.image_mean = image_mean
self.image_std = image_std
self.normalize = normalize
# self.downsample_ratio = downsample_ratio
self.downsample_ratio = 4
self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize)
self.tokenizer = tokenizer
# self.tokenizer = add_special_token(tokenizer)
self.tokenizer.padding_side = 'left' # must set this,padding side with make a difference in batch inference
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
if self.tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({'pad_token': pad_token})
# add image token
# image_token_id = self.tokenizer.vocab.get(image_token)
# if image_token_id is None:
# special_tokens = [image_token]
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token_id = self.tokenizer.vocab.get(image_token)
# add five special tokens for grounding-related tasks
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
# special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>']
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# special_tokens = ['<image>','<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>', '<td>', '</td>', '<tr>', '</tr>']
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
# # add special tokens for SFT data
# special_tokens = ["<|User|>", "<|Assistant|>"]
# special_tokens_dict = {"additional_special_tokens": special_tokens}
# self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token = image_token
self.pad_token = pad_token
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(
tokenizer,
**kwargs,
)
# def select_best_resolution(self, image_size):
# # used for cropping
# original_width, original_height = image_size
# best_fit = None
# max_effective_resolution = 0
# min_wasted_resolution = float("inf")
# for width, height in self.candidate_resolutions:
# scale = min(width / original_width, height / original_height)
# downscaled_width, downscaled_height = int(
# original_width * scale), int(original_height * scale)
# effective_resolution = min(downscaled_width * downscaled_height,
# original_width * original_height)
# wasted_resolution = (width * height) - effective_resolution
# if effective_resolution > max_effective_resolution or (
# effective_resolution == max_effective_resolution
# and wasted_resolution < min_wasted_resolution):
# max_effective_resolution = effective_resolution
# min_wasted_resolution = wasted_resolution
# best_fit = (width, height)
# return best_fit
@property
def bos_id(self):
return self.tokenizer.bos_token_id
@property
def eos_id(self):
return self.tokenizer.eos_token_id
@property
def pad_id(self):
return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str,
images: List,
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- pixel_values (torch.FloatTensor): [n_patches, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
assert (prompt is not None and images is not None
), "prompt and images must be used at the same time."
sft_format = prompt
input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, _ = images[0]
return {
"input_ids": input_ids,
"pixel_values": pixel_values,
"images_crop": images_crop,
"images_seq_mask": images_seq_mask,
"images_spatial_crop": images_spatial_crop,
"num_image_tokens": num_image_tokens,
}
# prepare = BatchFeature(
# data=dict(
# input_ids=input_ids,
# pixel_values=pixel_values,
# images_crop = images_crop,
# images_seq_mask=images_seq_mask,
# images_spatial_crop=images_spatial_crop,
# num_image_tokens=num_image_tokens,
# ),
# tensor_type="pt",
# )
# return prepare
def __call__(
self,
*,
prompt: str,
images: List,
inference_mode: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
images (List[ImageType]): the list of images;
inference_mode (bool): if True, then remove the last eos token;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prepare = self.process_one(
prompt=prompt,
images=images,
inference_mode=inference_mode,
)
return prepare
def tokenize_with_images(
self,
# conversation: str,
images: List[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,
):
"""Tokenize text with <image> tags."""
# print(conversation)
conversation = PROMPT
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = [], [], [], []
image_shapes = []
num_image_tokens = []
tokenized_str = []
# print('image: ', len(images))
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""select best resolution for anyres"""
# if cropping:
# best_width, best_height = self.select_best_resolution(image.size)
# else:
# best_width, best_height = self.image_size, self.image_size
image_shapes.append(image.size)
if image.size[0] <= 768 and image.size[1] <= 768:
crop_ratio = [1, 1]
else:
if cropping:
# print('image-size: ', image.size)
# best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
# print('image ', image.size)
# print('open_size:', image.size)
images_crop_raw, crop_ratio = dynamic_preprocess(image, image_size=IMAGE_SIZE)
# print('crop_ratio: ', crop_ratio)
else:
# best_width, best_height = self.image_size, self.image_size
crop_ratio = [1, 1]
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func
# print(crop_ratio)
"""process the global view"""
# if cropping
if self.image_size <= 768 and not cropping:
# print('directly resize')
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(image, (self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean))
images_list.append(self.image_transform(global_view))
"""record height / width crop num"""
# width_crop_num, height_crop_num = best_width // self.image_size, best_height // self.image_size
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
"""process the local views"""
# local_view = ImageOps.pad(image, (best_width, best_height),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# for i in range(0, best_height, self.image_size):
# for j in range(0, best_width, self.image_size):
# images_crop_list.append(
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
# """process the global view"""
# global_view = ImageOps.pad(image, (self.image_size, self.image_size),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# images_list.append(self.image_transform(global_view))
# """process the local views"""
# local_view = ImageOps.pad(image, (best_width, best_height),
# color=tuple(int(x * 255) for x in self.image_transform.mean))
# for i in range(0, best_height, self.image_size):
# for j in range(0, best_width, self.image_size):
# images_list.append(
# self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
# """add image tokens"""
"""add image tokens"""
num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
tokenized_image = ([self.image_token_id] * num_queries_base) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += ([self.image_token_id] * (num_queries * num_width_tiles)) * (
num_queries * num_height_tiles)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""add the bos and eos tokens"""
if bos:
tokenized_str = [self.bos_id] + tokenized_str
images_seq_mask = [False] + images_seq_mask
if eos:
tokenized_str = tokenized_str + [self.eos_id]
images_seq_mask = images_seq_mask + [False]
assert len(tokenized_str) == len(
images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \
(f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal")
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) |
(input_ids == self.image_token_id)] = self.ignore_id
input_ids[input_ids < 0] = self.pad_id
inference_mode = True
if inference_mode:
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros((1, 3, self.image_size, self.image_size)).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return [[input_ids, pixel_values, images_crop, images_seq_mask, images_spatial_crop, num_image_tokens, image_shapes]]
AutoProcessor.register("DeepseekVLV2Processor", DeepseekOCR2Processor)
import torch
from transformers import LogitsProcessor
from transformers.generation.logits_process import _calc_banned_ngram_tokens
from typing import List, Set
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
def __init__(self, ngram_size: int, window_size: int = 100, whitelist_token_ids: set = None):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(f"`window_size` has to be a strictly positive integer, but is {window_size}")
self.ngram_size = ngram_size
self.window_size = window_size
self.whitelist_token_ids = whitelist_token_ids or set()
def __call__(self, input_ids: List[int], scores: torch.FloatTensor) -> torch.FloatTensor:
if len(input_ids) < self.ngram_size:
return scores
current_prefix = tuple(input_ids[-(self.ngram_size - 1):])
search_start = max(0, len(input_ids) - self.window_size)
search_end = len(input_ids) - self.ngram_size + 1
banned_tokens = set()
for i in range(search_start, search_end):
ngram = tuple(input_ids[i:i + self.ngram_size])
if ngram[:-1] == current_prefix:
banned_tokens.add(ngram[-1])
banned_tokens = banned_tokens - self.whitelist_token_ids
if banned_tokens:
scores = scores.clone()
for token in banned_tokens:
scores[token] = -float("inf")
return scores
\ No newline at end of file
#!/bin/bash
# =============================================================================
# DeepSeek OCR vLLM 快速启动脚本
# =============================================================================
set -e
# 颜色定义
GREEN='\033[0;32m'
BLUE='\033[0;34m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
NC='\033[0m'
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PYTHON_PATH="/usr/bin/python3"
print_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
print_success() { echo -e "${GREEN}[SUCCESS]${NC} $1"; }
print_warning() { echo -e "${YELLOW}[WARNING]${NC} $1"; }
print_error() { echo -e "${RED}[ERROR]${NC} $1"; }
# 读取 .env 配置
if [ -f "${SCRIPT_DIR}/.env" ]; then
export $(grep -v '^#' "${SCRIPT_DIR}/.env" | xargs)
else
print_error "未找到 .env 配置文件"
print_info "请复制 .env.example 为 .env 并修改配置"
exit 1
fi
# 检查 Python 解释器路径
if [ -n "$PYTHON_PATH" ] && [ -f "$PYTHON_PATH" ]; then
# 使用指定的 Python 路径(miniconda 方式)
PYTHON_CMD="$PYTHON_PATH"
print_info "使用指定 Python: $PYTHON_PATH"
else
# 使用 conda 环境
if ! command -v conda &> /dev/null; then
print_error "未找到 conda 命令,请安装 conda 或设置 PYTHON_PATH"
exit 1
fi
PYTHON_CMD="conda run -n ${CONDA_ENV_NAME} python"
print_info "使用 conda 环境: ${CONDA_ENV_NAME}"
fi
print_info "DeepSeekOCR2 vLLM 快速启动..."
# 检查环境和依赖
print_info "检查环境和依赖..."
if [ -n "$PYTHON_PATH" ]; then
# 使用指定 Python 路径(miniconda 方式)
if ! $PYTHON_CMD -c "import vllm" &>/dev/null; then
print_info "安装依赖包..."
$PYTHON_CMD -m pip install -r "${SCRIPT_DIR}/requirements.txt"
fi
else
# 使用 conda 环境
if ! conda env list | grep -q "^${CONDA_ENV_NAME}\s"; then
print_info "创建conda环境..."
conda create -n "${CONDA_ENV_NAME}" python="${PYTHON_VERSION}" -y
fi
if ! conda run -n "${CONDA_ENV_NAME}" python -c "import vllm" &>/dev/null; then
print_info "安装依赖包..."
conda run -n "${CONDA_ENV_NAME}" pip install -r "${SCRIPT_DIR}/requirements.txt"
fi
fi
# 检查模型路径
if [ ! -d "$MODEL_PATH" ]; then
print_warning "模型路径不存在: $MODEL_PATH"
print_info "请修改 .env 文件中的 MODEL_PATH 变量"
exit 1
fi
# 创建日志目录
mkdir -p "${SCRIPT_DIR}/logs"
# 启动服务
print_info "启动服务 (端口 ${PORT}, GPU ${GPU_ID})..."
LOG_FILE="${SCRIPT_DIR}/logs/deepseek_ocr2_server_${PORT}_$(date +%Y%m%d_%H%M%S).log"
# 启动服务
if [ -n "$PYTHON_PATH" ]; then
# 使用指定 Python 路径(miniconda 方式)
cat > /tmp/quick_start.py << EOF
import subprocess
import sys
import os
cmd = [
sys.executable, "deepseek_ocr2_server.py",
"--model-path", "${MODEL_PATH}",
"--gpu-id", "${GPU_ID}",
"--port", "${PORT}",
"--host", "${HOST}",
"--cpu-workers", "${CPU_WORKERS}"
]
print(f"[INFO] 启动命令: {' '.join(cmd)}")
print(f"[INFO] 日志文件: ${LOG_FILE}")
with open("${LOG_FILE}", 'w') as log_file:
process = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, text=True)
print(f"[SUCCESS] 服务已启动 (PID: {process.pid})")
print(f"[INFO] API文档: http://${HOST}:${PORT}/docs")
print(f"[INFO] 健康检查: curl http://${HOST}:${PORT}/health")
EOF
$PYTHON_CMD /tmp/quick_start.py
else
# 使用 conda 环境
cat > /tmp/quick_start.py << EOF
import subprocess
import sys
import os
cmd = [
sys.executable, "deepseek_ocr2_server.py",
"--model-path", "${MODEL_PATH}",
"--gpu-id", "${GPU_ID}",
"--port", "${PORT}",
"--host", "${HOST}",
"--cpu-workers", "${CPU_WORKERS}"
]
print(f"[INFO] 启动命令: {' '.join(cmd)}")
print(f"[INFO] 日志文件: ${LOG_FILE}")
with open("${LOG_FILE}", 'w') as log_file:
process = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, text=True)
print(f"[SUCCESS] 服务已启动 (PID: {process.pid})")
print(f"[INFO] API文档: http://${HOST}:${PORT}/docs")
print(f"[INFO] 健康检查: curl http://${HOST}:${PORT}/health")
EOF
conda run -n "${CONDA_ENV_NAME}" python /tmp/quick_start.py
fi
rm -f /tmp/quick_start.py
print_success "启动完成!"
print_info "使用以下命令监控:"
echo " tail -f ${LOG_FILE}"
echo " curl http://${HOST}:${PORT}/health"
import os
import re
from tqdm import tqdm
import torch
if torch.version.cuda == '11.8':
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, MAX_CONCURRENCY, CROP_MODE, NUM_WORKERS
from concurrent.futures import ThreadPoolExecutor
import glob
from PIL import Image, ExifTags
from deepseek_ocr2 import DeepseekOCR2ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCR2Processor
ModelRegistry.register_model("DeepseekOCR2ForCausalLM", DeepseekOCR2ForCausalLM)
def correct_image_orientation(image):
try:
exif = image._getexif()
if exif is not None:
# Orientation key
for tag, value in ExifTags.TAGS.items():
if value == 'Orientation':
orientation_key = tag
break
# Orientation value
orientation = exif.get(orientation_key, 1)
#
if orientation == 3:
image = image.rotate(180, expand=True)
elif orientation == 6:
image = image.rotate(270, expand=True)
elif orientation == 8:
image = image.rotate(90, expand=True)
except Exception as e:
print(f"EXIF error: {e}")
return image
llm = LLM(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCR2ForCausalLM"]},
block_size=256,
enforce_eager=False,
trust_remote_code=True,
max_model_len=8192,
swap_space=0,
max_num_seqs = MAX_CONCURRENCY,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=40, window_size=90, whitelist_token_ids= {128821, 128822})] #window for fast;whitelist_token_ids: <td>,</td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
)
class Colors:
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
RESET = '\033[0m'
def clean_formula(text):
formula_pattern = r'\\\[(.*?)\\\]'
def process_formula(match):
formula = match.group(1)
formula = re.sub(r'\\quad\s*\([^)]*\)', '', formula)
formula = formula.strip()
return r'\[' + formula + r'\]'
cleaned_text = re.sub(formula_pattern, process_formula, text)
return cleaned_text
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
# mathes_image = []
mathes_other = []
for a_match in matches:
mathes_other.append(a_match[0])
return matches, mathes_other
def process_single_image(image):
"""single image"""
prompt_in = prompt
cache_item = {
"prompt": prompt_in,
"multi_modal_data": {"image": DeepseekOCR2Processor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
}
return cache_item
if __name__ == "__main__":
# INPUT_PATH = OmniDocBench images path
os.makedirs(OUTPUT_PATH, exist_ok=True)
# print('image processing until processing prompts.....')
print(f'{Colors.RED}glob images.....{Colors.RESET}')
images_path = glob.glob(f'{INPUT_PATH}/*')
images = []
for image_path in images_path:
image = Image.open(image_path)
image = correct_image_orientation(image)
# image = ImageOps.exif_transpose(image)
images.append(image.convert('RGB'))
prompt = PROMPT
# batch_inputs = []
# for image in tqdm(images):
# prompt_in = prompt
# cache_list = [
# {
# "prompt": prompt_in,
# "multi_modal_data": {"image": Image.open(image).convert('RGB')},
# }
# ]
# batch_inputs.extend(cache_list)
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(tqdm(
executor.map(process_single_image, images),
total=len(images),
desc="Pre-processed images"
))
outputs_list = llm.generate(
batch_inputs,
sampling_params=sampling_params
)
output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True)
for output, image in zip(outputs_list, images_path):
content = output.outputs[0].text
# mmd_det_path = output_path + image.split('/')[-1].replace('.jpg', '_det.md')
# with open(mmd_det_path, 'w', encoding='utf-8') as afile:
# afile.write(content)
content = clean_formula(content)
matches_ref, mathes_other = re_match(content)
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
content = content.replace(a_match_other, '').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
mmd_path = output_path + image.split('/')[-1].replace('.jpg', '.md').replace('.png', '.md')
with open(mmd_path, 'w', encoding='utf-8') as afile:
afile.write(content)
\ No newline at end of file
import asyncio
import re
import os
import torch
if torch.version.cuda == '11.8':
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.registry import ModelRegistry
import time
from deepseek_ocr2 import DeepseekOCR2ForCausalLM
from PIL import Image, ImageDraw, ImageFont, ImageOps
import numpy as np
from tqdm import tqdm
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCR2Processor
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
ModelRegistry.register_model("DeepseekOCR2ForCausalLM", DeepseekOCR2ForCausalLM)
def load_image(image_path):
try:
image = Image.open(image_path)
corrected_image = ImageOps.exif_transpose(image)
return corrected_image
except Exception as e:
print(f"error: {e}")
try:
return Image.open(image_path)
except:
return None
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
if '<|ref|>image<|/ref|>' in a_match[0]:
mathes_image.append(a_match[0])
else:
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
except Exception as e:
print(e)
return None
return (label_type, cor_list)
def draw_bounding_boxes(image, refs):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
# except IOError:
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
try:
result = extract_coordinates_and_label(ref, image_width, image_height)
if result:
label_type, points_list = result
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
color_a = color + (20, )
for points in points_list:
x1, y1, x2, y2 = points
x1 = int(x1 / 999 * image_width)
y1 = int(y1 / 999 * image_height)
x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height)
if label_type == 'image':
try:
cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{img_idx}.jpg")
except Exception as e:
print(e)
pass
img_idx += 1
try:
if label_type == 'title':
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
text_x = x1
text_y = max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30))
draw.text((text_x, text_y), label_type, font=font, fill=color)
except:
pass
except:
continue
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts):
result_image = draw_bounding_boxes(image, ref_texts)
return result_image
async def stream_generate(image=None, prompt=''):
engine_args = AsyncEngineArgs(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCR2ForCausalLM"]},
# torch_dtype=torch.bfloat16,
dtype="bfloat16",
# block_size=128,
max_model_len=8192,
enforce_eager=False,
trust_remote_code=True,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=30, window_size=90, whitelist_token_ids= {128821, 128822})] #whitelist: <td>, </td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
# ignore_eos=False,
)
request_id = f"request-{int(time.time())}"
printed_length = 0
if image and '<image>' in prompt:
request = {
"prompt": prompt,
"multi_modal_data": {"image": image}
}
elif prompt:
request = {
"prompt": prompt
}
else:
assert False, f'prompt is none!!!'
async for request_output in engine.generate(
request, sampling_params, request_id
):
if request_output.outputs:
full_text = request_output.outputs[0].text
new_text = full_text[printed_length:]
print(new_text, end='', flush=True)
printed_length = len(full_text)
final_output = full_text
print('\n')
return final_output
if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
image = load_image(INPUT_PATH).convert('RGB')
if '<image>' in PROMPT:
image_features = DeepseekOCR2Processor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)
else:
image_features = ''
prompt = PROMPT
result_out = asyncio.run(stream_generate(image_features, prompt))
save_results = 1
if save_results and '<image>' in prompt:
print('='*15 + 'save results:' + '='*15)
image_draw = image.copy()
outputs = result_out
with open(f'{OUTPUT_PATH}/result_ori.mmd', 'w', encoding = 'utf-8') as afile:
afile.write(outputs)
matches_ref, matches_images, mathes_other = re_match(outputs)
# print(matches_ref)
result = process_image_with_refs(image_draw, matches_ref)
for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
outputs = outputs.replace(a_match_image, f'![](images/' + str(idx) + '.jpg)\n')
for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
# if 'structural formula' in conversation[0]['content']:
# outputs = '<smiles>' + outputs + '</smiles>'
with open(f'{OUTPUT_PATH}/result.mmd', 'w', encoding = 'utf-8') as afile:
afile.write(outputs)
if 'line_type' in outputs:
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
lines = eval(outputs)['Line']['line']
line_type = eval(outputs)['Line']['line_type']
# print(lines)
endpoints = eval(outputs)['Line']['line_endpoint']
fig, ax = plt.subplots(figsize=(3,3), dpi=200)
ax.set_xlim(-15, 15)
ax.set_ylim(-15, 15)
for idx, line in enumerate(lines):
try:
p0 = eval(line.split(' -- ')[0])
p1 = eval(line.split(' -- ')[-1])
if line_type[idx] == '--':
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
else:
ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
ax.scatter(p0[0], p0[1], s=5, color = 'k')
ax.scatter(p1[0], p1[1], s=5, color = 'k')
except:
pass
for endpoint in endpoints:
label = endpoint.split(': ')[0]
(x, y) = eval(endpoint.split(': ')[1])
ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
fontsize=5, fontweight='light')
try:
if 'Circle' in eval(outputs).keys():
circle_centers = eval(outputs)['Circle']['circle_center']
radius = eval(outputs)['Circle']['radius']
for center, r in zip(circle_centers, radius):
center = eval(center.split(': ')[1])
circle = Circle(center, radius=r, fill=False, edgecolor='black', linewidth=0.8)
ax.add_patch(circle)
except:
pass
plt.savefig(f'{OUTPUT_PATH}/geo.jpg')
plt.close()
result.save(f'{OUTPUT_PATH}/result_with_boxes.jpg')
\ No newline at end of file
import os
import fitz
import img2pdf
import io
import re
from tqdm import tqdm
import torch
from concurrent.futures import ThreadPoolExecutor
if torch.version.cuda == '11.8':
os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from deepseek_ocr2 import DeepseekOCR2ForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCR2Processor
ModelRegistry.register_model("DeepseekOCR2ForCausalLM", DeepseekOCR2ForCausalLM)
llm = LLM(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCR2ForCausalLM"]},
block_size=64,
enforce_eager=False,
trust_remote_code=True,
max_model_len=8192,
swap_space=0,
max_num_seqs=MAX_CONCURRENCY,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
disable_mm_preprocessor_cache=True
)
logits_processors = [NoRepeatNGramLogitsProcessor(ngram_size=20, window_size=50, whitelist_token_ids= {128821, 128822})] #window for fast;whitelist_token_ids: <td>,</td>
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
logits_processors=logits_processors,
skip_special_tokens=False,
include_stop_str_in_output=True,
)
class Colors:
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
RESET = '\033[0m'
def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
"""
pdf2images
"""
images = []
pdf_document = fitz.open(pdf_path)
zoom = dpi / 72.0
matrix = fitz.Matrix(zoom, zoom)
for page_num in range(pdf_document.page_count):
page = pdf_document[page_num]
pixmap = page.get_pixmap(matrix=matrix, alpha=False)
Image.MAX_IMAGE_PIXELS = None
if image_format.upper() == "PNG":
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
else:
img_data = pixmap.tobytes("png")
img = Image.open(io.BytesIO(img_data))
if img.mode in ('RGBA', 'LA'):
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
img = background
images.append(img)
pdf_document.close()
return images
def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images:
return
image_bytes_list = []
for img in pil_images:
if img.mode != 'RGB':
img = img.convert('RGB')
img_buffer = io.BytesIO()
img.save(img_buffer, format='JPEG', quality=95)
img_bytes = img_buffer.getvalue()
image_bytes_list.append(img_bytes)
try:
pdf_bytes = img2pdf.convert(image_bytes_list)
with open(output_path, "wb") as f:
f.write(pdf_bytes)
except Exception as e:
print(f"error: {e}")
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
if '<|ref|>image<|/ref|>' in a_match[0]:
mathes_image.append(a_match[0])
else:
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
except Exception as e:
print(e)
return None
return (label_type, cor_list)
def draw_bounding_boxes(image, refs, jdx):
image_width, image_height = image.size
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
draw2 = ImageDraw.Draw(overlay)
# except IOError:
font = ImageFont.load_default()
img_idx = 0
for i, ref in enumerate(refs):
try:
result = extract_coordinates_and_label(ref, image_width, image_height)
if result:
label_type, points_list = result
color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
color_a = color + (20, )
for points in points_list:
x1, y1, x2, y2 = points
x1 = int(x1 / 999 * image_width)
y1 = int(y1 / 999 * image_height)
x2 = int(x2 / 999 * image_width)
y2 = int(y2 / 999 * image_height)
if label_type == 'image':
try:
cropped = image.crop((x1, y1, x2, y2))
cropped.save(f"{OUTPUT_PATH}/images/{jdx}_{img_idx}.jpg")
except Exception as e:
print(e)
pass
img_idx += 1
try:
if label_type == 'title':
draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
else:
draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
text_x = x1
text_y = max(0, y1 - 15)
text_bbox = draw.textbbox((0, 0), label_type, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
fill=(255, 255, 255, 30))
draw.text((text_x, text_y), label_type, font=font, fill=color)
except:
pass
except:
continue
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts, jdx):
result_image = draw_bounding_boxes(image, ref_texts, jdx)
return result_image
def process_single_image(image):
"""single image"""
prompt_in = prompt
cache_item = {
"prompt": prompt_in,
"multi_modal_data": {"image": DeepseekOCR2Processor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
}
return cache_item
if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(f'{OUTPUT_PATH}/images', exist_ok=True)
print(f'{Colors.RED}PDF loading .....{Colors.RESET}')
images = pdf_to_images_high_quality(INPUT_PATH)
prompt = PROMPT
# batch_inputs = []
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
batch_inputs = list(tqdm(
executor.map(process_single_image, images),
total=len(images),
desc="Pre-processed images"
))
# for image in tqdm(images):
# prompt_in = prompt
# cache_list = [
# {
# "prompt": prompt_in,
# "multi_modal_data": {"image": DeepseekOCR2Processor().tokenize_with_images(images = [image], bos=True, eos=True, cropping=CROP_MODE)},
# }
# ]
# batch_inputs.extend(cache_list)
outputs_list = llm.generate(
batch_inputs,
sampling_params=sampling_params
)
output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True)
mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd')
pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
contents_det = ''
contents = ''
draw_images = []
jdx = 0
for output, img in zip(outputs_list, images):
content = output.outputs[0].text
if '<|end▁of▁sentence|>' in content: # repeat no eos
content = content.replace('<|end▁of▁sentence|>', '')
else:
if SKIP_REPEAT:
continue
page_num = f'\n<--- Page Split --->'
contents_det += content + f'\n{page_num}\n'
image_draw = img.copy()
matches_ref, matches_images, mathes_other = re_match(content)
# print(matches_ref)
result_image = process_image_with_refs(image_draw, matches_ref, jdx)
draw_images.append(result_image)
for idx, a_match_image in enumerate(matches_images):
content = content.replace(a_match_image, f'![](images/' + str(jdx) + '_' + str(idx) + '.jpg)\n')
for idx, a_match_other in enumerate(mathes_other):
content = content.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:').replace('\n\n\n\n', '\n\n').replace('\n\n\n', '\n\n')
contents += content + f'\n{page_num}\n'
jdx += 1
with open(mmd_det_path, 'w', encoding='utf-8') as afile:
afile.write(contents_det)
with open(mmd_path, 'w', encoding='utf-8') as afile:
afile.write(contents)
pil_to_pdf_img2pdf(draw_images, pdf_out_path)
# DeepSeekOCR/DeepSeekOCR2 在线 API 服务部署说明
该服务基于DeepSeekOCR/DeepSeekOCR2模型实现了光学字符识别能力,支持通过HTTP API调用进行图像与PDF文字识别。
## 依赖环境
VLLM镜像:image.sourcefind.cn:5000/dcu/admin/base/vllm:0.8.5-ubuntu22.04-dtk25.04.1-rc5-das1.6-py3.10-20250724
百度网盘链接如下:
通过网盘分享的文件:ds-ocr
链接: https://pan.baidu.com/s/1BJ_gmtScvboEbabTXjHn4Q?pwd=jbi8 提取码: jbi8
--来自百度网盘超级会员v6的分享
## 快速开始
### 1. 在线推理
cd DeepSeek-OCR-vllm
服务启动
python deepseek_ocr_server.py --model-path DeepSeekOCR模型路径
在线推理
python online_test.py
### 2. 离线推理
cd DeepSeek-OCR-vllm
离线推理
bash offline_test.sh
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment