Commit 66792d8d authored by chenych
Browse files

Modify README

parent d890d8fe
...@@ -4,9 +4,6 @@ import argparse ...@@ -4,9 +4,6 @@ import argparse
from transformers import AutoModel, AutoTokenizer from transformers import AutoModel, AutoTokenizer
os.environ["HIP_VISIBLE_DEVICES"] = '0'
parse = argparse.ArgumentParser() parse = argparse.ArgumentParser()
parse.add_argument('--model_name_or_path', type=str, default='deepseek-ai/DeepSeek-OCR') parse.add_argument('--model_name_or_path', type=str, default='deepseek-ai/DeepSeek-OCR')
parse.add_argument('--image_file', type=str, default='./doc/test.png') parse.add_argument('--image_file', type=str, default='./doc/test.png')
......
import asyncio import asyncio
import re import re
import os import os
import argparse
os.environ["HIP_VISIBLE_DEVICES"] = '0'
from vllm import AsyncLLMEngine, SamplingParams from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
...@@ -20,14 +17,10 @@ from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE ...@@ -20,14 +17,10 @@ from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, CROP_MODE
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM) ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
def load_image(image_path): def load_image(image_path):
try: try:
image = Image.open(image_path) image = Image.open(image_path)
corrected_image = ImageOps.exif_transpose(image) corrected_image = ImageOps.exif_transpose(image)
return corrected_image return corrected_image
except Exception as e: except Exception as e:
print(f"error: {e}") print(f"error: {e}")
try: try:
...@@ -35,7 +28,6 @@ def load_image(image_path): ...@@ -35,7 +28,6 @@ def load_image(image_path):
except: except:
return None return None
def re_match(text): def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL) matches = re.findall(pattern, text, re.DOTALL)
...@@ -50,10 +42,7 @@ def re_match(text): ...@@ -50,10 +42,7 @@ def re_match(text):
mathes_other.append(a_match[0]) mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height): def extract_coordinates_and_label(ref_text, image_width, image_height):
try: try:
label_type = ref_text[1] label_type = ref_text[1]
cor_list = eval(ref_text[2]) cor_list = eval(ref_text[2])
...@@ -63,7 +52,6 @@ def extract_coordinates_and_label(ref_text, image_width, image_height): ...@@ -63,7 +52,6 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
return (label_type, cor_list) return (label_type, cor_list)
def draw_bounding_boxes(image, refs): def draw_bounding_boxes(image, refs):
image_width, image_height = image.size image_width, image_height = image.size
...@@ -136,7 +124,6 @@ def process_image_with_refs(image, ref_texts): ...@@ -136,7 +124,6 @@ def process_image_with_refs(image, ref_texts):
async def stream_generate(image=None, prompt=''): async def stream_generate(image=None, prompt=''):
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model=MODEL_PATH, model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]}, hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
...@@ -189,8 +176,6 @@ async def stream_generate(image=None, prompt=''): ...@@ -189,8 +176,6 @@ async def stream_generate(image=None, prompt=''):
return final_output return final_output
if __name__ == "__main__": if __name__ == "__main__":
os.makedirs(OUTPUT_PATH, exist_ok=True) os.makedirs(OUTPUT_PATH, exist_ok=True)
...@@ -208,8 +193,6 @@ if __name__ == "__main__": ...@@ -208,8 +193,6 @@ if __name__ == "__main__":
prompt = PROMPT prompt = PROMPT
result_out = asyncio.run(stream_generate(image_features, prompt)) result_out = asyncio.run(stream_generate(image_features, prompt))
save_results = 1 save_results = 1
if save_results and '<image>' in prompt: if save_results and '<image>' in prompt:
......
...@@ -3,27 +3,20 @@ import fitz ...@@ -3,27 +3,20 @@ import fitz
import img2pdf import img2pdf
import io import io
import re import re
import numpy as np
from tqdm import tqdm from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE from config import MODEL_PATH, INPUT_PATH, OUTPUT_PATH, PROMPT, SKIP_REPEAT, MAX_CONCURRENCY, NUM_WORKERS, CROP_MODE
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import numpy as np
from deepseek_ocr import DeepseekOCRForCausalLM from deepseek_ocr import DeepseekOCRForCausalLM
from vllm.model_executor.models.registry import ModelRegistry from vllm.model_executor.models.registry import ModelRegistry
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor from process.image_process import DeepseekOCRProcessor
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM) ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)
llm = LLM( llm = LLM(
model=MODEL_PATH, model=MODEL_PATH,
hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]}, hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
...@@ -90,7 +83,6 @@ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"): ...@@ -90,7 +83,6 @@ def pdf_to_images_high_quality(pdf_path, dpi=144, image_format="PNG"):
return images return images
def pil_to_pdf_img2pdf(pil_images, output_path): def pil_to_pdf_img2pdf(pil_images, output_path):
if not pil_images: if not pil_images:
return return
...@@ -113,13 +105,10 @@ def pil_to_pdf_img2pdf(pil_images, output_path): ...@@ -113,13 +105,10 @@ def pil_to_pdf_img2pdf(pil_images, output_path):
except Exception as e: except Exception as e:
print(f"error: {e}") print(f"error: {e}")
def re_match(text): def re_match(text):
pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
matches = re.findall(pattern, text, re.DOTALL) matches = re.findall(pattern, text, re.DOTALL)
mathes_image = [] mathes_image = []
mathes_other = [] mathes_other = []
for a_match in matches: for a_match in matches:
...@@ -129,10 +118,7 @@ def re_match(text): ...@@ -129,10 +118,7 @@ def re_match(text):
mathes_other.append(a_match[0]) mathes_other.append(a_match[0])
return matches, mathes_image, mathes_other return matches, mathes_image, mathes_other
def extract_coordinates_and_label(ref_text, image_width, image_height): def extract_coordinates_and_label(ref_text, image_width, image_height):
try: try:
label_type = ref_text[1] label_type = ref_text[1]
cor_list = eval(ref_text[2]) cor_list = eval(ref_text[2])
...@@ -142,7 +128,6 @@ def extract_coordinates_and_label(ref_text, image_width, image_height): ...@@ -142,7 +128,6 @@ def extract_coordinates_and_label(ref_text, image_width, image_height):
return (label_type, cor_list) return (label_type, cor_list)
def draw_bounding_boxes(image, refs, jdx): def draw_bounding_boxes(image, refs, jdx):
image_width, image_height = image.size image_width, image_height = image.size
...@@ -209,12 +194,10 @@ def draw_bounding_boxes(image, refs, jdx): ...@@ -209,12 +194,10 @@ def draw_bounding_boxes(image, refs, jdx):
img_draw.paste(overlay, (0, 0), overlay) img_draw.paste(overlay, (0, 0), overlay)
return img_draw return img_draw
def process_image_with_refs(image, ref_texts, jdx): def process_image_with_refs(image, ref_texts, jdx):
result_image = draw_bounding_boxes(image, ref_texts, jdx) result_image = draw_bounding_boxes(image, ref_texts, jdx)
return result_image return result_image
def process_single_image(image): def process_single_image(image):
"""single image""" """single image"""
prompt_in = prompt prompt_in = prompt
...@@ -232,10 +215,7 @@ if __name__ == "__main__": ...@@ -232,10 +215,7 @@ if __name__ == "__main__":
print(f'{Colors.RED}PDF loading .....{Colors.RESET}') print(f'{Colors.RED}PDF loading .....{Colors.RESET}')
images = pdf_to_images_high_quality(INPUT_PATH) images = pdf_to_images_high_quality(INPUT_PATH)
prompt = PROMPT prompt = PROMPT
# batch_inputs = [] # batch_inputs = []
...@@ -247,7 +227,6 @@ if __name__ == "__main__": ...@@ -247,7 +227,6 @@ if __name__ == "__main__":
desc="Pre-processed images" desc="Pre-processed images"
)) ))
# for image in tqdm(images): # for image in tqdm(images):
# prompt_in = prompt # prompt_in = prompt
...@@ -265,12 +244,10 @@ if __name__ == "__main__": ...@@ -265,12 +244,10 @@ if __name__ == "__main__":
sampling_params=sampling_params sampling_params=sampling_params
) )
output_path = OUTPUT_PATH output_path = OUTPUT_PATH
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd') mmd_det_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_det.mmd')
mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd') mmd_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('pdf', 'mmd')
pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf') pdf_out_path = output_path + '/' + INPUT_PATH.split('/')[-1].replace('.pdf', '_layouts.pdf')
...@@ -287,7 +264,6 @@ if __name__ == "__main__": ...@@ -287,7 +264,6 @@ if __name__ == "__main__":
if SKIP_REPEAT: if SKIP_REPEAT:
continue continue
page_num = f'\n<--- Page Split --->' page_num = f'\n<--- Page Split --->'
contents_det += content + f'\n{page_num}\n' contents_det += content + f'\n{page_num}\n'
......
...@@ -65,13 +65,15 @@ pip install -r requirements.txt ...@@ -65,13 +65,15 @@ pip install -r requirements.txt
### transformers ### transformers
> 模型地址,测试图片路径,输出路径根据实际情况修改。 > 模型地址,测试图片路径,输出路径根据实际情况修改。
```bash ```bash
python DeepSeek-OCR-hf/run_dpsk_ocr.py --model_name_or_path=deepseek-ai/DeepSeek-OCR --image_path=./doc/test.png --output_path=./output export HIP_VISIBLE_DEVICES=0
python DeepSeek-OCR-hf/run_dpsk_ocr.py --model_name_or_path=deepseek-ai/DeepSeek-OCR --image_file=./doc/test.png --output_path=./output
``` ```
### vllm ### vllm
> 模型地址,测试图片路径,输出路径请根据实际情况在`DeepSeek-OCR-vllm/config.py`中修改。 > 模型地址,测试图片路径,输出路径请根据实际情况在`DeepSeek-OCR-vllm/config.py`中修改。
```bash ```bash
export VLLM_USE_V1=0 export VLLM_USE_V1=0
export HIP_VISIBLE_DEVICES=0
# image:流式输出 # image:流式输出
python DeepSeek-OCR-vllm/run_dpsk_ocr_image.py python DeepSeek-OCR-vllm/run_dpsk_ocr_image.py
# pdf # pdf
...@@ -85,7 +87,7 @@ python DeepSeek-OCR-vllm/run_dpsk_ocr_pdf.py ...@@ -85,7 +87,7 @@ python DeepSeek-OCR-vllm/run_dpsk_ocr_pdf.py
</div> </div>
### 精度 ### 精度
DCU与GPU精度一致,推理框架:vllm DCU与GPU精度一致,推理框架:pytorch
## 应用场景 ## 应用场景
### 算法类别 ### 算法类别
......
...@@ -7,6 +7,6 @@ modelDescription=DeepSeek 推出了全新的视觉文本压缩模型 DeepSeek-OC ...@@ -7,6 +7,6 @@ modelDescription=DeepSeek 推出了全新的视觉文本压缩模型 DeepSeek-OC
# 应用场景 # 应用场景
appScenario=推理,OCR,制造,金融,交通,教育,医疗 appScenario=推理,OCR,制造,金融,交通,教育,医疗
# 框架类型 # 框架类型
frameType=pytorch,vllm frameType=pytorch
# 加速卡类型 # 加速卡类型
accelerateType=K100AI accelerateType=K100AI
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment