Commit a29489ef authored by myhloli's avatar myhloli
Browse files

refactor: update config file name and enhance model path handling

parent a149a49c
...@@ -77,6 +77,7 @@ class HuggingfacePredictor(BasePredictor): ...@@ -77,6 +77,7 @@ class HuggingfacePredictor(BasePredictor):
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
**kwargs, **kwargs,
) )
setattr(self.model.config, "_name_or_path", model_path)
self.model.eval() self.model.eval()
vision_tower = self.model.get_model().vision_tower vision_tower = self.model.get_model().vision_tower
......
...@@ -158,6 +158,9 @@ def do_parse( ...@@ -158,6 +158,9 @@ def do_parse(
logger.info(f"local output dir is {local_md_dir}") logger.info(f"local output dir is {local_md_dir}")
else: else:
if backend.startswith("vlm-"):
backend = backend[4:]
f_draw_span_bbox = False f_draw_span_bbox = False
parse_method = "vlm" parse_method = "vlm"
for idx, pdf_bytes in enumerate(pdf_bytes_list): for idx, pdf_bytes in enumerate(pdf_bytes_list):
...@@ -216,10 +219,10 @@ def do_parse( ...@@ -216,10 +219,10 @@ def do_parse(
if __name__ == "__main__": if __name__ == "__main__":
# pdf_path = "../../demo/pdfs/demo2.pdf" pdf_path = "../../demo/pdfs/demo2.pdf"
pdf_path = "C:/Users/zhaoxiaomeng/Downloads/input_img_0.jpg" # pdf_path = "C:/Users/zhaoxiaomeng/Downloads/input_img_0.jpg"
try: try:
do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"], end_page_id=20,) do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"], end_page_id=1, backend='vlm-huggingface')
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
...@@ -79,8 +79,12 @@ class SiglipVisionTower(nn.Module): ...@@ -79,8 +79,12 @@ class SiglipVisionTower(nn.Module):
def build_vision_tower(config: Mineru2QwenConfig): def build_vision_tower(config: Mineru2QwenConfig):
vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", "")) vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
model_path = getattr(config, "_name_or_path", "")
if "siglip" in vision_tower.lower(): if "siglip" in vision_tower.lower():
return SiglipVisionTower(vision_tower) if model_path:
return SiglipVisionTower(f"{model_path}/{vision_tower}")
else:
return SiglipVisionTower(vision_tower)
raise ValueError(f"Unknown vision tower: {vision_tower}") raise ValueError(f"Unknown vision tower: {vision_tower}")
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from loguru import logger from loguru import logger
# 定义配置文件名常量 # 定义配置文件名常量
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json') CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
def read_config(): def read_config():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment