Commit 66792d8d authored by chenych's avatar chenych
Browse files

Modify README

parent d890d8fe
......@@ -4,9 +4,6 @@ import argparse
from transformers import AutoModel, AutoTokenizer
os.environ["HIP_VISIBLE_DEVICES"] = '0'
parse = argparse.ArgumentParser()
parse.add_argument('--model_name_or_path', type=str, default='deepseek-ai/DeepSeek-OCR')
parse.add_argument('--image_file', type=str, default='./doc/test.png')
......
import asyncio
import re
import os
import argparse
os.environ["HIP_VISIBLE_DEVICES"] = '0'
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
......@@ -20,14 +17,10 @@ from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
def load_image(image_path):
try:
image = Image.open(image_path)
corrected_image = ImageOps.exif_transpose(image)
return corrected_image
except Exception as e:
print(f"error: {e}")
try:
......@@ -35,7 +28,6 @@ def load_image(image_path):
except:
return None
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
......@@ -50,10 +42,7 @@ def re_match(text):
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
......@@ -63,7 +52,6 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
return (label_type, cor_list)
def draw_bounding_boxes(image, refs):
image_width, image_height = image.size
......@@ -136,7 +124,6 @@ def process_image_with_refs(image, ref_texts):
async def stream_generate(image=None, prompt=''):
engine_args = AsyncEngineArgs(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
......@@ -189,8 +176,6 @@ async def stream_generate(image=None, prompt=''):
return final_output
if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True)
......@@ -208,8 +193,6 @@ if __name__ == "__main__":
prompt = PROMPT
result_out = asyncio.run(stream_generate(image_features, prompt))
save_results = 1
if save_results and '<image>' in prompt:
......
......@@ -3,27 +3,20 @@ import fitz
import img2pdf
import io
import re
import numpy as np
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from deepseek_ocr import DeepseekOCRForCausalLM
from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
llm = LLM(
model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
......@@ -90,7 +83,6 @@ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
return images
def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images:
return
......@@ -113,13 +105,10 @@ def pil_to_pdf_img2pdf(pil_images, output_path):
except Exception as e:
print(f"error: {e}")
def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL)
mathes_image = []
mathes_other = []
for a_match in matches:
......@@ -129,10 +118,7 @@ def re_match(text):
mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height):
try:
label_type = ref_text[1]
cor_list = eval(ref_text[2])
......@@ -142,7 +128,6 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
return (label_type, cor_list)
def draw_bounding_boxes(image, refs, jdx):
image_width, image_height = image.size
......@@ -209,12 +194,10 @@ def draw_bounding_boxes(image, refs, jdx):
img_draw.paste(overlay, (0, 0), overlay)
return img_draw
def process_image_with_refs(image, ref_texts, jdx):
    """Render the detected reference boxes onto *image*.

    Thin wrapper that delegates to draw_bounding_boxes and returns the
    annotated image it produces for page index *jdx*.
    """
    return draw_bounding_boxes(image, ref_texts, jdx)
def process_single_image(image):
"""single image"""
prompt_in = prompt
......@@ -232,10 +215,7 @@ if __name__ == "__main__":
print(f'{Colors.RED}PDF loading .....{Colors.RESET}')
images = pdf_to_images_high_quality(INPUT_PATH)
prompt = PROMPT
# batch_inputs = []
......@@ -247,7 +227,6 @@ if __name__ == "__main__":
desc="Pre-processed images"
))
# for image in tqdm(images):
# prompt_in = prompt
......@@ -265,12 +244,10 @@ if __name__ == "__main__":
sampling_params=sampling_params
)
output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True)
mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd')
pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
......@@ -287,7 +264,6 @@ if __name__ == "__main__":
if SKIP_REPEAT:
continue
page_num = f'\n<--- Page Split --->'
contents_det += content + f'\n{page_num}\n'
......
......@@ -65,13 +65,15 @@ pip install -r requirements.txt
### transformers
> 模型地址,测试图片路径,输出路径根据实际情况修改。
```bash
python DeepSeek-OCR-hf/run_dpsk_ocr.py --model_name_or_path=deepseek-ai/DeepSeek-OCR --image_path=./doc/test.png --output_path=./output
export HIP_VISIBLE_DEVICES=0
python DeepSeek-OCR-hf/run_dpsk_ocr.py --model_name_or_path=deepseek-ai/DeepSeek-OCR --image_file=./doc/test.png --output_path=./output
```
### vllm
> 模型地址,测试图片路径,输出路径请根据实际情况在`DeepSeek-OCR-vllm/config.py`中修改。
```bash
export VLLM_USE_V1=0
export HIP_VISIBLE_DEVICES=0
# image:流式输出
python DeepSeek-OCR-vllm/run_dpsk_ocr_image.py
# pdf
......@@ -85,7 +87,7 @@ python DeepSeek-OCR-vllm/run_dpsk_ocr_pdf.py
</div>
### 精度
DCU与GPU精度一致,推理框架:vllm
DCU与GPU精度一致,推理框架:pytorch
## 应用场景
### 算法类别
......
......@@ -7,6 +7,6 @@ modelDescription=DeepSeek 推出了全新的视觉文本压缩模型 DeepSeek-OC
# 应用场景
appScenario=推理,OCR,制造,金融,交通,教育,医疗
# 框架类型
frameType=pytorch,vllm
frameType=pytorch
# 加速卡类型
accelerateType=K100AI
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment