Commit 8e55a526 authored by Jin Zhen Jiang's avatar Jin Zhen Jiang
Browse files

feat: add mineru-vlm backend.

parent 6f8a9610
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterable, Iterable, List, Optional, Union
DEFAULT_SYSTEM_PROMPT = (
"A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
)
DEFAULT_USER_PROMPT = "Document Parsing:"
DEFAULT_TEMPERATURE = 0.0
DEFAULT_TOP_P = 0.01
DEFAULT_TOP_K = 1
DEFAULT_REPETITION_PENALTY = 1.0
DEFAULT_PRESENCE_PENALTY = 0.0
DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
DEFAULT_MAX_NEW_TOKENS = 16384
class BasePredictor(ABC):
system_prompt = DEFAULT_SYSTEM_PROMPT
def __init__(
self,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> None:
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.repetition_penalty = repetition_penalty
self.presence_penalty = presence_penalty
self.no_repeat_ngram_size = no_repeat_ngram_size
self.max_new_tokens = max_new_tokens
@abstractmethod
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str: ...
@abstractmethod
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]: ...
@abstractmethod
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]: ...
async def aio_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
return await asyncio.to_thread(
self.predict,
image,
prompt,
temperature,
top_p,
top_k,
repetition_penalty,
presence_penalty,
no_repeat_ngram_size,
max_new_tokens,
)
async def aio_batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]:
return await asyncio.to_thread(
self.batch_predict,
images,
prompts,
temperature,
top_p,
top_k,
repetition_penalty,
presence_penalty,
no_repeat_ngram_size,
max_new_tokens,
)
async def aio_stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> AsyncIterable[str]:
queue = asyncio.Queue()
loop = asyncio.get_running_loop()
def synced_predict():
for chunk in self.stream_predict(
image=image,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
):
asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
asyncio.create_task(
asyncio.to_thread(synced_predict),
)
while True:
chunk = await queue.get()
if chunk is None:
return
assert isinstance(chunk, str)
yield chunk
def build_prompt(self, prompt: str) -> str:
if prompt.startswith("<|im_start|>"):
return prompt
if not prompt:
prompt = DEFAULT_USER_PROMPT
return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
# Modify here. We add <|box_start|> at the end of the prompt to force the model to generate bounding box.
# if "Document OCR" in prompt:
# return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
# else:
# return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
def close(self):
pass
from io import BytesIO
from typing import Iterable, List, Optional, Union
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer, BitsAndBytesConfig
from ...model.vlm_hf_model import Mineru2QwenForCausalLM
from ...model.vlm_hf_model.image_processing_mineru2 import process_images
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .utils import load_resource
class HuggingfacePredictor(BasePredictor):
def __init__(
self,
model_path: str,
device_map="auto",
device="cuda",
torch_dtype="auto",
load_in_8bit=False,
load_in_4bit=False,
use_flash_attn=False,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
**kwargs,
):
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
kwargs = {"device_map": device_map, **kwargs}
if device != "cuda":
kwargs["device_map"] = {"": device}
if load_in_8bit:
kwargs["load_in_8bit"] = True
elif load_in_4bit:
kwargs["load_in_4bit"] = True
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
else:
kwargs["torch_dtype"] = torch_dtype
if use_flash_attn:
kwargs["attn_implementation"] = "flash_attention_2"
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = Mineru2QwenForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**kwargs,
)
self.model.eval()
vision_tower = self.model.get_model().vision_tower
if device_map != "auto":
vision_tower.to(device=device_map, dtype=self.model.dtype)
self.image_processor = vision_tower.image_processor
self.eos_token_id = self.model.config.eos_token_id
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
**kwargs,
) -> str:
prompt = self.build_prompt(prompt)
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
do_sample = (temperature > 0.0) and (top_k > 1)
generate_kwargs = {
"repetition_penalty": repetition_penalty,
"no_repeat_ngram_size": no_repeat_ngram_size,
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
}
if do_sample:
generate_kwargs["temperature"] = temperature
generate_kwargs["top_p"] = top_p
generate_kwargs["top_k"] = top_k
if isinstance(image, str):
image = load_resource(image)
image_obj = Image.open(BytesIO(image))
image_tensor = process_images([image_obj], self.image_processor, self.model.config)
image_tensor = image_tensor[0].unsqueeze(0)
image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
image_sizes = [[*image_obj.size]]
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(device=self.model.device)
with torch.inference_mode():
output_ids = self.model.generate(
input_ids,
images=image_tensor,
image_sizes=image_sizes,
use_cache=True,
**generate_kwargs,
**kwargs,
)
# Remove the last token if it is the eos_token_id
if len(output_ids[0]) > 0 and output_ids[0, -1] == self.eos_token_id:
output_ids = output_ids[:, :-1]
output = self.tokenizer.batch_decode(
output_ids,
skip_special_tokens=False,
)[0].strip()
return output
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None, # not supported by hf
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
**kwargs,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
outputs = []
for prompt, image in tqdm(zip(prompts, images), total=len(images), desc="Predict"):
output = self.predict(
image,
prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
**kwargs,
)
outputs.append(output)
return outputs
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]:
raise NotImplementedError("Streaming is not supported yet.")
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .sglang_client_predictor import SglangClientPredictor
hf_loaded = False
try:
from .hf_predictor import HuggingfacePredictor
hf_loaded = True
except ImportError as e:
logger.warning("hf is not installed. If you are not using huggingface, you can ignore this warning.")
engine_loaded = False
try:
from sglang.srt.server_args import ServerArgs
from .sglang_engine_predictor import SglangEnginePredictor
engine_loaded = True
except Exception as e:
logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
def get_predictor(
backend: str = "sglang-client",
model_path: str | None = None,
server_url: str | None = None,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
http_timeout: int = 600,
**kwargs,
) -> BasePredictor:
start_time = time.time()
if backend == "huggingface":
if not model_path:
raise ValueError("model_path must be provided for huggingface backend.")
if not hf_loaded:
raise ImportError(
"transformers is not installed, so huggingface backend cannot be used. "
"If you need to use huggingface backend, please install transformers first."
)
predictor = HuggingfacePredictor(
model_path=model_path,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
**kwargs,
)
elif backend == "sglang-engine":
if not model_path:
raise ValueError("model_path must be provided for sglang-engine backend.")
if not engine_loaded:
raise ImportError(
"sglang is not installed, so sglang-engine backend cannot be used. "
"If you need to use sglang-engine backend for inference, "
"please install sglang[all]==0.4.6.post4 or a newer version."
)
predictor = SglangEnginePredictor(
server_args=ServerArgs(model_path, **kwargs),
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
elif backend == "sglang-client":
if not server_url:
raise ValueError("server_url must be provided for sglang-client backend.")
predictor = SglangClientPredictor(
server_url=server_url,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
http_timeout=http_timeout,
)
else:
raise ValueError(f"Unsupported backend: {backend}. Supports: huggingface, sglang-engine, sglang-client.")
elapsed = round(time.time() - start_time, 2)
logger.info(f"get_predictor cost: {elapsed}s")
return predictor
import asyncio
import json
import re
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
import httpx
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
from .utils import aio_load_resource, load_resource
class SglangClientPredictor(BasePredictor):
def __init__(
self,
server_url: str,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
http_timeout: int = 600,
) -> None:
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
self.http_timeout = http_timeout
base_url = self.get_base_url(server_url)
self.check_server_health(base_url)
self.model_path = self.get_model_path(base_url)
self.server_url = f"{base_url}/generate"
@staticmethod
def get_base_url(server_url: str) -> str:
matched = re.match(r"^(https?://[^/]+)", server_url)
if not matched:
raise ValueError(f"Invalid server URL: {server_url}")
return matched.group(1)
def check_server_health(self, base_url: str):
try:
response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
except httpx.ConnectError:
raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
if response.status_code != 200:
raise RuntimeError(
f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
)
def get_model_path(self, base_url: str) -> str:
try:
response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
except httpx.ConnectError:
raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
if response.status_code != 200:
raise RuntimeError(
f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
)
return response.json()["model_path"]
def build_sampling_params(
self,
temperature: Optional[float],
top_p: Optional[float],
top_k: Optional[int],
repetition_penalty: Optional[float],
presence_penalty: Optional[float],
no_repeat_ngram_size: Optional[int],
max_new_tokens: Optional[int],
) -> dict:
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if presence_penalty is None:
presence_penalty = self.presence_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
# see SamplingParams for more details
return {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"presence_penalty": presence_penalty,
"custom_params": {
"no_repeat_ngram_size": no_repeat_ngram_size,
},
"max_new_tokens": max_new_tokens,
"skip_special_tokens": False,
}
def build_request_body(
self,
image: bytes,
prompt: str,
sampling_params: dict,
) -> dict:
image_base64 = b64encode(image).decode("utf-8")
return {
"text": prompt,
"image_data": image_base64,
"sampling_params": sampling_params,
"modalities": ["image"],
}
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
response_body = response.json()
return response_body["text"]
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
max_concurrency: int = 100,
) -> List[str]:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
task = self.aio_batch_predict(
images=images,
prompts=prompts,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
max_concurrency=max_concurrency,
)
if loop is not None:
return loop.run_until_complete(task)
else:
return asyncio.run(task)
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
request_body["stream"] = True
with httpx.stream(
"POST",
self.server_url,
json=request_body,
timeout=self.http_timeout,
) as response:
pos = 0
for chunk in response.iter_lines():
if not (chunk or "").startswith("data:"):
continue
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
chunk_text = data["text"][pos:]
# meta_info = data["meta_info"]
pos += len(chunk_text)
yield chunk_text
async def aio_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
async_client: Optional[httpx.AsyncClient] = None,
) -> str:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = await aio_load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
if async_client is None:
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
response = await client.post(self.server_url, json=request_body)
response_body = response.json()
else:
response = await async_client.post(self.server_url, json=request_body)
response_body = response.json()
return response_body["text"]
async def aio_batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
max_concurrency: int = 100,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
semaphore = asyncio.Semaphore(max_concurrency)
outputs = [""] * len(images)
async def predict_with_semaphore(
idx: int,
image: str | bytes,
prompt: str,
async_client: httpx.AsyncClient,
):
async with semaphore:
output = await self.aio_predict(
image=image,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
async_client=async_client,
)
outputs[idx] = output
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
tasks = []
for idx, (prompt, image) in enumerate(zip(prompts, images)):
tasks.append(predict_with_semaphore(idx, image, prompt, client))
await asyncio.gather(*tasks)
return outputs
async def aio_batch_predict_as_iter(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
max_concurrency: int = 100,
) -> AsyncIterable[Tuple[int, str]]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
semaphore = asyncio.Semaphore(max_concurrency)
async def predict_with_semaphore(
idx: int,
image: str | bytes,
prompt: str,
async_client: httpx.AsyncClient,
):
async with semaphore:
output = await self.aio_predict(
image=image,
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
async_client=async_client,
)
return (idx, output)
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
pending: Set[asyncio.Task[Tuple[int, str]]] = set()
for idx, (prompt, image) in enumerate(zip(prompts, images)):
pending.add(
asyncio.create_task(
predict_with_semaphore(idx, image, prompt, client),
)
)
while len(pending) > 0:
done, pending = await asyncio.wait(
pending,
return_when=asyncio.FIRST_COMPLETED,
)
for task in done:
yield task.result()
async def aio_stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> AsyncIterable[str]:
prompt = self.build_prompt(prompt)
sampling_params = self.build_sampling_params(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
if isinstance(image, str):
image = await aio_load_resource(image)
request_body = self.build_request_body(image, prompt, sampling_params)
request_body["stream"] = True
async with httpx.AsyncClient(timeout=self.http_timeout) as client:
async with client.stream(
"POST",
self.server_url,
json=request_body,
) as response:
pos = 0
async for chunk in response.aiter_lines():
if not (chunk or "").startswith("data:"):
continue
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
chunk_text = data["text"][pos:]
# meta_info = data["meta_info"]
pos += len(chunk_text)
yield chunk_text
from base64 import b64encode
from typing import AsyncIterable, Iterable, List, Optional, Union
from sglang.srt.server_args import ServerArgs
from ...model.vlm_sglang_model.engine import BatchEngine
from .base_predictor import (
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_NO_REPEAT_NGRAM_SIZE,
DEFAULT_PRESENCE_PENALTY,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
BasePredictor,
)
class SglangEnginePredictor(BasePredictor):
def __init__(
self,
server_args: ServerArgs,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> None:
super().__init__(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
self.engine = BatchEngine(server_args=server_args)
def load_image_string(self, image: str | bytes) -> str:
if not isinstance(image, (str, bytes)):
raise ValueError("Image must be a string or bytes.")
if isinstance(image, bytes):
return b64encode(image).decode("utf-8")
if image.startswith("file://"):
return image[len("file://") :]
return image
def predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
return self.batch_predict(
[image], # type: ignore
[prompt],
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)[0]
def batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
prompts = [self.build_prompt(prompt) for prompt in prompts]
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if presence_penalty is None:
presence_penalty = self.presence_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
# see SamplingParams for more details
sampling_params = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"presence_penalty": presence_penalty,
"custom_params": {
"no_repeat_ngram_size": no_repeat_ngram_size,
},
"max_new_tokens": max_new_tokens,
"skip_special_tokens": False,
}
image_strings = [self.load_image_string(img) for img in images]
output = self.engine.generate(
prompt=prompts,
image_data=image_strings,
sampling_params=sampling_params,
)
return [item["text"] for item in output]
def stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> Iterable[str]:
raise NotImplementedError("Streaming is not supported yet.")
async def aio_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> str:
output = await self.aio_batch_predict(
[image], # type: ignore
[prompt],
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
presence_penalty=presence_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
max_new_tokens=max_new_tokens,
)
return output[0]
async def aio_batch_predict(
self,
images: List[str] | List[bytes],
prompts: Union[List[str], str] = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> List[str]:
if not isinstance(prompts, list):
prompts = [prompts] * len(images)
assert len(prompts) == len(images), "Length of prompts and images must match."
prompts = [self.build_prompt(prompt) for prompt in prompts]
if temperature is None:
temperature = self.temperature
if top_p is None:
top_p = self.top_p
if top_k is None:
top_k = self.top_k
if repetition_penalty is None:
repetition_penalty = self.repetition_penalty
if presence_penalty is None:
presence_penalty = self.presence_penalty
if no_repeat_ngram_size is None:
no_repeat_ngram_size = self.no_repeat_ngram_size
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
# see SamplingParams for more details
sampling_params = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"presence_penalty": presence_penalty,
"custom_params": {
"no_repeat_ngram_size": no_repeat_ngram_size,
},
"max_new_tokens": max_new_tokens,
"skip_special_tokens": False,
}
image_strings = [self.load_image_string(img) for img in images]
output = await self.engine.async_generate(
prompt=prompts,
image_data=image_strings,
sampling_params=sampling_params,
)
ret = []
for item in output: # type: ignore
ret.append(item["text"])
return ret
async def aio_stream_predict(
self,
image: str | bytes,
prompt: str = "",
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
repetition_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
max_new_tokens: Optional[int] = None,
) -> AsyncIterable[str]:
raise NotImplementedError("Streaming is not supported yet.")
def close(self):
self.engine.shutdown()
import re
from ...libs.cut_image import cut_image_and_table
from ...libs.enum_class import BlockType, ContentType
from ...libs.hash_utils import str_md5
from ...libs.magic_model import fix_two_layer_blocks
from ...libs.version import __version__
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
"""将token转换为页面信息"""
# 解析token,提取坐标和类型
# 假设token格式为:<|box_start|>x0 y0 x1 y1<|box_end|><|ref_start|>type<|ref_end|><|md_start|>content<|md_end|>
# 这里需要根据实际的token格式进行解析
# 提取所有完整块,每个块从<|box_start|>开始到<|md_end|>或<|im_end|>结束
scale = image_dict["scale"]
page_pil_img = image_dict["img_pil"]
page_img_md5 = str_md5(image_dict["img_base64"])
width, height = map(int, page.get_size())
# 使用正则表达式查找所有块
pattern = (
r"<\|box_start\|>(.*?)<\|box_end\|><\|ref_start\|>(.*?)<\|ref_end\|><\|md_start\|>(.*?)(?:<\|md_end\|>|<\|im_end\|>)"
)
block_infos = re.findall(pattern, token, re.DOTALL)
blocks = []
# 解析每个块
for index, block_info in enumerate(block_infos):
block_bbox = block_info[0].strip()
x1, y1, x2, y2 = map(int, block_bbox.split())
x_1, y_1, x_2, y_2 = (
int(x1 * width / 1000),
int(y1 * height / 1000),
int(x2 * width / 1000),
int(y2 * height / 1000),
)
if x_2 < x_1:
x_1, x_2 = x_2, x_1
if y_2 < y_1:
y_1, y_2 = y_2, y_1
block_bbox = (x_1, y_1, x_2, y_2)
block_type = block_info[1].strip()
block_content = block_info[2].strip()
# print(f"坐标: {block_bbox}")
# print(f"类型: {block_type}")
# print(f"内容: {block_content}")
# print("-" * 50)
span_type = "unknown"
if block_type in [
"text",
"title",
"image_caption",
"image_footnote",
"table_caption",
"table_footnote",
"list",
"index",
]:
span_type = ContentType.TEXT
elif block_type in ["image"]:
block_type = BlockType.IMAGE_BODY
span_type = ContentType.IMAGE
elif block_type in ["table"]:
block_type = BlockType.TABLE_BODY
span_type = ContentType.TABLE
elif block_type in ["equation"]:
block_type = BlockType.INTERLINE_EQUATION
span_type = ContentType.INTERLINE_EQUATION
if span_type in ["image", "table"]:
span = {
"bbox": block_bbox,
"type": span_type,
}
if span_type == ContentType.TABLE:
span["html"] = block_content
span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
else:
span = {
"bbox": block_bbox,
"type": span_type,
"content": block_content,
}
line = {
"bbox": block_bbox,
"spans": [span],
}
blocks.append(
{
"bbox": block_bbox,
"type": block_type,
"lines": [line],
"index": index,
}
)
image_blocks = fix_two_layer_blocks(blocks, BlockType.IMAGE)
table_blocks = fix_two_layer_blocks(blocks, BlockType.TABLE)
page_blocks = [
block
for block in blocks
if block["type"] in [BlockType.TEXT, BlockType.TITLE, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]
]
page_blocks.extend([*image_blocks, *table_blocks])
# 对page_blocks根据index的值进行排序
page_blocks.sort(key=lambda x: x["index"])
page_info = {"para_blocks": page_blocks, "page_size": [width, height], "page_idx": page_index}
return page_info
def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
middle_json = {"pdf_info": [], "_version_name": __version__}
for index, token in enumerate(token_list):
page = pdf_doc[index]
image_dict = images_list[index]
page_info = token_to_page_info(token, image_dict, page, image_writer, index)
middle_json["pdf_info"].append(page_info)
return middle_json
if __name__ == "__main__":
output = r"<|box_start|>088 119 472 571<|box_end|><|ref_start|>image<|ref_end|><|md_start|>![]('img_url')<|md_end|>\n<|box_start|>079 582 482 608<|box_end|><|ref_start|>image_caption<|ref_end|><|md_start|>Fig. 2. (a) Schematic of the change in the FDC over time, and (b) definition of model parameters.<|md_end|>\n<|box_start|>079 624 285 638<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.2. Zero flow day analysis<|md_end|>\n<|box_start|>079 656 482 801<|box_end|><|ref_start|>text<|ref_end|><|md_start|>A notable feature of Fig. 1 is the increase in the number of zero flow days. A similar approach to Eq. (2), using an inverse sigmoidal function was employed to assess the impact of afforestation on the number of zero flow days per year \((N_{\mathrm{zero}})\). In this case, the left hand side of Eq. (2) is replaced by \(N_{\mathrm{zero}}\) and \(b\) and \(S\) are constrained to negative as \(N_{\mathrm{zero}}\) decreases as rainfall increases, and increases with plantation growth:<|md_end|>\n<|box_start|>076 813 368 853<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nN_{\mathrm{zero}}=a+b(\Delta P)+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}}{S}\right)}\n\]<|md_end|>\n<|box_start|>079 865 482 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For the average pre-treatment condition \(\Delta P=0\) and \(T=0\), \(N_{\mathrm{zero}}\) approximately equals \(a\). \(Y\) gives<|md_end|>\n<|box_start|>525 119 926 215<|box_end|><|ref_start|>text<|ref_end|><|md_start|>the magnitude of change in zero flow days due to afforestation, and \(S\) describes the shape of the response. For the average climate condition \(\Delta P=0\), \(a+Y\) becomes the number of zero flow days when the new equilibrium condition under afforestation is reached.<|md_end|>\n<|box_start|>525 240 704 253<|box_end|><|ref_start|>title<|ref_end|><|md_start|># 2.3. Statistical analyses<|md_end|>\n<|box_start|>525 271 926 368<|box_end|><|ref_start|>text<|ref_end|><|md_start|>The coefficient of efficiency \((E)\) (Nash and Sutcliffe, 1970; Chiew and McMahon, 1993; Legates and McCabe, 1999) was used as the 'goodness of fit' measure to evaluate the fit between observed and predicted flow deciles (2) and zero flow days (3). \(E\) is given by:<|md_end|>\n<|box_start|>520 375 735 415<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nE=1.0-\frac{\sum_{i=1}^{N}(O_{i}-P_{i})^{2}}{\sum_{i=1}^{N}(O_{i}-\bar{O})^{2}}\n\]<|md_end|>\n<|box_start|>525 424 926 601<|box_end|><|ref_start|>text<|ref_end|><|md_start|>where \(O\) are observed data, \(P\) are predicted values, and \(\bar{O}\) is the mean for the entire period. \(E\) is unity minus the ratio of the mean square error to the variance in the observed data, and ranges from \(-\infty\) to 1.0. Higher values indicate greater agreement between observed and predicted data as per the coefficient of determination \((r^{2})\). \(E\) is used in preference to \(r^{2}\) in evaluating hydrologic modelling because it is a measure of the deviation from the 1:1 line. As \(E\) is always \(<r^{2}\) we have arbitrarily considered \(E>0.7\) to indicate adequate model fits.<|md_end|>\n<|box_start|>525 603 926 731<|box_end|><|ref_start|>text<|ref_end|><|md_start|>It is important to assess the significance of the model parameters to check the model assumptions that rainfall and forest age are driving changes in the FDC. The model (2) was split into simplified forms, where only the rainfall or time terms were included by setting \(b=0\), as shown in Eq. (5), or \(Y=0\) as shown in Eq. (6). The component models (5) and (6) were then tested against the complete model, (2).<|md_end|>\n<|box_start|>520 739 735 778<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+\frac{Y}{1+\exp\left(\frac{T-T_{\mathrm{half}}^{\prime}}{S}\right)}\n\]<|md_end|>\n<|box_start|>525 787 553 799<|box_end|><|ref_start|>text<|ref_end|><|md_start|>and<|md_end|>\n<|box_start|>520 807 646 825<|box_end|><|ref_start|>equation<|ref_end|><|md_start|>\[\nQ_{\%}=a+b\Delta P\n\]<|md_end|>\n<|box_start|>525 833 926 895<|box_end|><|ref_start|>text<|ref_end|><|md_start|>For both the flow duration curve analysis and zero flow days analysis, a \(t\)-test was then performed to test whether (5) and (6) were significantly different to (2). A critical value of \(t\) exceeding the calculated \(t\)-value<|md_end|><|im_end|>"
p_info = token_to_page_info(output)
# 将blocks 转换为json文本
import json
json_str = json.dumps(p_info, ensure_ascii=False, indent=4)
print(json_str)
import os
import re
from base64 import b64decode
import httpx
_timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
_file_exts = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".pdf")
_data_uri_regex = re.compile(r"^data:[^;,]+;base64,")
def load_resource(uri: str) -> bytes:
if uri.startswith("http://") or uri.startswith("https://"):
response = httpx.get(uri, timeout=_timeout)
return response.content
if uri.startswith("file://"):
with open(uri[len("file://") :], "rb") as file:
return file.read()
if uri.lower().endswith(_file_exts):
with open(uri, "rb") as file:
return file.read()
if re.match(_data_uri_regex, uri):
return b64decode(uri.split(",")[1])
return b64decode(uri)
async def aio_load_resource(uri: str) -> bytes:
if uri.startswith("http://") or uri.startswith("https://"):
async with httpx.AsyncClient(timeout=_timeout) as client:
response = await client.get(uri)
return response.content
if uri.startswith("file://"):
with open(uri[len("file://") :], "rb") as file:
return file.read()
if uri.lower().endswith(_file_exts):
with open(uri, "rb") as file:
return file.read()
if re.match(_data_uri_regex, uri):
return b64decode(uri.split(",")[1])
return b64decode(uri)
# Copyright (c) Opendatalab. All rights reserved.
import time
from loguru import logger
from ...data.data_reader_writer import DataWriter
from ...libs.pdf_image_tools import load_images_from_pdf
from .base_predictor import BasePredictor
from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json
class ModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_model(
self,
backend: str,
model_path: str | None,
server_url: str | None,
) -> BasePredictor:
key = (backend,)
if key not in self._models:
self._models[key] = get_predictor(
backend=backend,
model_path=model_path,
server_url=server_url,
)
return self._models[key]
def doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: BasePredictor | None = None,
backend="huggingface",
model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415", # TODO: change to formal path after release.
server_url: str | None = None,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
load_images_time = round(time.time() - load_images_start, 2)
logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
infer_start = time.time()
results = predictor.batch_predict(images=images_base64_list)
infer_time = round(time.time() - infer_start, 2)
logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json, results
async def aio_doc_analyze(
pdf_bytes,
image_writer: DataWriter | None,
predictor: BasePredictor | None = None,
backend="huggingface",
model_path="jinzhenj/OEEzRkQ3RTAtMDMx-0415", # TODO: change to formal path after release.
server_url: str | None = None,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
load_images_time = round(time.time() - load_images_start, 2)
logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
infer_start = time.time()
results = await predictor.aio_batch_predict(images=images_base64_list)
infer_time = round(time.time() - infer_start, 2)
logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json
from ..model.vlm_sglang_model.server import main
if __name__ == "__main__":
main()
import math
def is_in(box1, box2) -> bool:
"""box1是否完全在box2里面."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return (
x0_1 >= x0_2 # box1的左边界不在box2的左边外
and y0_1 >= y0_2 # box1的上边界不在box2的上边外
and x1_1 <= x1_2 # box1的右边界不在box2的右边外
and y1_1 <= y1_2
) # box1的下边界不在box2的下边外
def bbox_relative_pos(bbox1, bbox2):
"""判断两个矩形框的相对位置关系.
Args:
bbox1: 一个四元组,表示第一个矩形框的左上角和右下角的坐标,格式为(x1, y1, x1b, y1b)
bbox2: 一个四元组,表示第二个矩形框的左上角和右下角的坐标,格式为(x2, y2, x2b, y2b)
Returns:
一个四元组,表示矩形框1相对于矩形框2的位置关系,格式为(left, right, bottom, top)
其中,left表示矩形框1是否在矩形框2的左侧,right表示矩形框1是否在矩形框2的右侧,
bottom表示矩形框1是否在矩形框2的下方,top表示矩形框1是否在矩形框2的上方
"""
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left = x2b < x1
right = x1b < x2
bottom = y2b < y1
top = y1b < y2
return left, right, bottom, top
def bbox_distance(bbox1, bbox2):
"""计算两个矩形框的距离。
Args:
bbox1 (tuple): 第一个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
bbox2 (tuple): 第二个矩形框的坐标,格式为 (x1, y1, x2, y2),其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
Returns:
float: 矩形框之间的距离。
"""
def dist(point1, point2):
return math.sqrt((point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2)
x1, y1, x1b, y1b = bbox1
x2, y2, x2b, y2b = bbox2
left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
if top and left:
return dist((x1, y1b), (x2b, y2))
elif left and bottom:
return dist((x1, y1), (x2b, y2b))
elif bottom and right:
return dist((x1b, y1), (x2, y2b))
elif right and top:
return dist((x1b, y1b), (x2, y2))
elif left:
return x1 - x2b
elif right:
return x2 - x1b
elif bottom:
return y1 - y2b
elif top:
return y2 - y1b
return 0.0
from loguru import logger
from .pdf_image_tools import cut_image
def cut_image_and_table(span, page_pil_img, page_img_md5, page_id, imageWriter, scale=2):
def return_path(path_type):
return f"{path_type}/{page_img_md5}"
span_type = span["type"]
if not check_img_bbox(span["bbox"]) or not imageWriter:
span["image_path"] = ""
else:
span["image_path"] = cut_image(
span["bbox"], page_id, page_pil_img, return_path=return_path(span_type), imageWriter=imageWriter, scale=scale
)
return span
def check_img_bbox(bbox) -> bool:
if any([bbox[0] >= bbox[2], bbox[1] >= bbox[3]]):
logger.warning(f"image_bboxes: 错误的box, {bbox}")
return False
return True
import json
from io import BytesIO
from PyPDF2 import PdfReader, PdfWriter
from reportlab.pdfgen import canvas
from .enum_class import BlockType
def draw_bbox_without_number(i, bbox_list, page, c, rgb_config, fill_config):
new_rgb = [float(color) / 255 for color in rgb_config]
page_data = bbox_list[i]
page_width, page_height = page.cropbox[2], page.cropbox[3]
for bbox in page_data:
width = bbox[2] - bbox[0]
height = bbox[3] - bbox[1]
rect = [bbox[0], page_height - bbox[3], width, height] # Define the rectangle
if fill_config: # filled rectangle
c.setFillColorRGB(new_rgb[0], new_rgb[1], new_rgb[2], 0.3)
c.rect(rect[0], rect[1], rect[2], rect[3], stroke=0, fill=1)
else: # bounding box
c.setStrokeColorRGB(new_rgb[0], new_rgb[1], new_rgb[2])
c.rect(rect[0], rect[1], rect[2], rect[3], stroke=1, fill=0)
return c
def draw_bbox_with_number(i, bbox_list, page, c, rgb_config, fill_config, draw_bbox=True):
new_rgb = [float(color) / 255 for color in rgb_config]
page_data = bbox_list[i]
# 强制转换为 float
page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
for j, bbox in enumerate(page_data):
# 确保bbox的每个元素都是float
x0, y0, x1, y1 = map(float, bbox)
width = x1 - x0
height = y1 - y0
rect = [x0, page_height - y1, width, height]
if draw_bbox:
if fill_config:
c.setFillColorRGB(*new_rgb, 0.3)
c.rect(rect[0], rect[1], rect[2], rect[3], stroke=0, fill=1)
else:
c.setStrokeColorRGB(*new_rgb)
c.rect(rect[0], rect[1], rect[2], rect[3], stroke=1, fill=0)
c.setFillColorRGB(*new_rgb, 1.0)
c.setFontSize(size=10)
# 这里也要用float
c.drawString(x1 + 2, page_height - y0 - 10, str(j + 1))
return c
def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
# dropped_bbox_list = []
tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], []
imgs_list, imgs_body_list, imgs_caption_list = [], [], []
imgs_footnote_list = []
titles_list = []
texts_list = []
interequations_list = []
lists_list = []
indexs_list = []
for page in pdf_info:
# page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = []
texts = []
interequations = []
lists = []
indices = []
# for dropped_bbox in page['discarded_blocks']:
# page_dropped_list.append(dropped_bbox['bbox'])
# dropped_bbox_list.append(page_dropped_list)
for block in page["para_blocks"]:
bbox = block["bbox"]
if block["type"] == BlockType.TABLE:
tables.append(bbox)
for nested_block in block["blocks"]:
bbox = nested_block["bbox"]
if nested_block["type"] == BlockType.TABLE_BODY:
tables_body.append(bbox)
elif nested_block["type"] == BlockType.TABLE_CAPTION:
tables_caption.append(bbox)
elif nested_block["type"] == BlockType.TABLE_FOOTNOTE:
tables_footnote.append(bbox)
elif block["type"] == BlockType.IMAGE:
imgs.append(bbox)
for nested_block in block["blocks"]:
bbox = nested_block["bbox"]
if nested_block["type"] == BlockType.IMAGE_BODY:
imgs_body.append(bbox)
elif nested_block["type"] == BlockType.IMAGE_CAPTION:
imgs_caption.append(bbox)
elif nested_block["type"] == BlockType.IMAGE_FOOTNOTE:
imgs_footnote.append(bbox)
elif block["type"] == BlockType.TITLE:
titles.append(bbox)
elif block["type"] == BlockType.TEXT:
texts.append(bbox)
elif block["type"] == BlockType.INTERLINE_EQUATION:
interequations.append(bbox)
elif block["type"] == BlockType.LIST:
lists.append(bbox)
elif block["type"] == BlockType.INDEX:
indices.append(bbox)
tables_list.append(tables)
tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption)
tables_footnote_list.append(tables_footnote)
imgs_list.append(imgs)
imgs_body_list.append(imgs_body)
imgs_caption_list.append(imgs_caption)
imgs_footnote_list.append(imgs_footnote)
titles_list.append(titles)
texts_list.append(texts)
interequations_list.append(interequations)
lists_list.append(lists)
indexs_list.append(indices)
layout_bbox_list = []
table_type_order = {"table_caption": 1, "table_body": 2, "table_footnote": 3}
for page in pdf_info:
page_block_list = []
for block in page["para_blocks"]:
if block["type"] in [
BlockType.TEXT,
BlockType.TITLE,
BlockType.INTERLINE_EQUATION,
BlockType.LIST,
BlockType.INDEX,
]:
bbox = block["bbox"]
page_block_list.append(bbox)
elif block["type"] in [BlockType.IMAGE]:
for sub_block in block["blocks"]:
bbox = sub_block["bbox"]
page_block_list.append(bbox)
elif block["type"] in [BlockType.TABLE]:
sorted_blocks = sorted(block["blocks"], key=lambda x: table_type_order[x["type"]])
for sub_block in sorted_blocks:
bbox = sub_block["bbox"]
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_bytes_io = BytesIO(pdf_bytes)
pdf_docs = PdfReader(pdf_bytes_io)
output_pdf = PdfWriter()
for i, page in enumerate(pdf_docs.pages):
# 获取原始页面尺寸
page_width, page_height = float(page.cropbox[2]), float(page.cropbox[3])
custom_page_size = (page_width, page_height)
packet = BytesIO()
# 使用原始PDF的尺寸创建canvas
c = canvas.Canvas(packet, pagesize=custom_page_size)
# c = draw_bbox_without_number(i, dropped_bbox_list, page, c, [158, 158, 158], True)
c = draw_bbox_without_number(i, tables_body_list, page, c, [204, 204, 0], True)
c = draw_bbox_without_number(i, tables_caption_list, page, c, [255, 255, 102], True)
c = draw_bbox_without_number(i, tables_footnote_list, page, c, [229, 255, 204], True)
c = draw_bbox_without_number(i, imgs_body_list, page, c, [153, 255, 51], True)
c = draw_bbox_without_number(i, imgs_caption_list, page, c, [102, 178, 255], True)
c = draw_bbox_without_number(i, imgs_footnote_list, page, c, [255, 178, 102], True)
c = draw_bbox_without_number(i, titles_list, page, c, [102, 102, 255], True)
c = draw_bbox_without_number(i, texts_list, page, c, [153, 0, 76], True)
c = draw_bbox_without_number(i, interequations_list, page, c, [0, 255, 0], True)
c = draw_bbox_without_number(i, lists_list, page, c, [40, 169, 92], True)
c = draw_bbox_without_number(i, indexs_list, page, c, [40, 169, 92], True)
c = draw_bbox_with_number(i, layout_bbox_list, page, c, [255, 0, 0], False, draw_bbox=False)
c.save()
packet.seek(0)
overlay_pdf = PdfReader(packet)
page.merge_page(overlay_pdf.pages[0])
output_pdf.add_page(page)
# 保存结果
with open(f"{out_path}/{filename}", "wb") as f:
output_pdf.write(f)
if __name__ == "__main__":
# 读取PDF文件
pdf_path = "examples/demo1.pdf"
with open(pdf_path, "rb") as f:
pdf_bytes = f.read()
# 从json文件读取pdf_info
json_path = "examples/demo1_1746005777.0863056_middle.json"
with open(json_path, "r", encoding="utf-8") as f:
pdf_ann = json.load(f)
pdf_info = pdf_ann["pdf_info"]
# 调用可视化函数,输出到examples目录
draw_layout_bbox(pdf_info, pdf_bytes, "examples", "output_with_layout.pdf")
class BlockType:
IMAGE = 'image'
TABLE = 'table'
IMAGE_BODY = 'image_body'
TABLE_BODY = 'table_body'
IMAGE_CAPTION = 'image_caption'
TABLE_CAPTION = 'table_caption'
IMAGE_FOOTNOTE = 'image_footnote'
TABLE_FOOTNOTE = 'table_footnote'
TEXT = 'text'
TITLE = 'title'
INTERLINE_EQUATION = 'interline_equation'
LIST = 'list'
INDEX = 'index'
class ContentType:
IMAGE = 'image'
TABLE = 'table'
TEXT = 'text'
INTERLINE_EQUATION = 'interline_equation'
class MakeMode:
MM_MD = 'mm_markdown'
NLP_MD = 'nlp_markdown'
STANDARD_FORMAT = 'standard_format'
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved.
import hashlib
import json
def bytes_md5(file_bytes):
hasher = hashlib.md5()
hasher.update(file_bytes)
return hasher.hexdigest().upper()
def str_md5(input_string):
hasher = hashlib.md5()
# 在Python3中,需要将字符串转化为字节对象才能被哈希函数处理
input_bytes = input_string.encode('utf-8')
hasher.update(input_bytes)
return hasher.hexdigest()
def str_sha256(input_string):
hasher = hashlib.sha256()
# 在Python3中,需要将字符串转化为字节对象才能被哈希函数处理
input_bytes = input_string.encode('utf-8')
hasher.update(input_bytes)
return hasher.hexdigest()
def dict_md5(d):
json_str = json.dumps(d, sort_keys=True, ensure_ascii=False)
return hashlib.md5(json_str.encode('utf-8')).hexdigest()
\ No newline at end of file
from typing import Literal
from .boxbase import bbox_distance, is_in
def __reduct_overlap(bboxes):
N = len(bboxes)
keep = [True] * N
for i in range(N):
for j in range(N):
if i == j:
continue
if is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]]
def __tie_up_category_by_distance_v3(
blocks: list,
subject_block_type: str,
object_block_type: str,
):
subjects = __reduct_overlap(
list(
map(
lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
filter(
lambda x: x["type"] == subject_block_type,
blocks,
),
)
)
)
objects = __reduct_overlap(
list(
map(
lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
filter(
lambda x: x["type"] == object_block_type,
blocks,
),
)
)
)
ret = []
N, M = len(subjects), len(objects)
subjects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
objects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
OBJ_IDX_OFFSET = 10000
SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
all_boxes_with_idx = [(i, SUB_BIT_KIND, sub["bbox"][0], sub["bbox"][1]) for i, sub in enumerate(subjects)] + [
(i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj["bbox"][0], obj["bbox"][1]) for i, obj in enumerate(objects)
]
seen_idx = set()
seen_sub_idx = set()
while N > len(seen_sub_idx):
candidates = []
for idx, kind, x0, y0 in all_boxes_with_idx:
if idx in seen_idx:
continue
candidates.append((idx, kind, x0, y0))
if len(candidates) == 0:
break
left_x = min([v[2] for v in candidates])
top_y = min([v[3] for v in candidates])
candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
fst_idx, fst_kind, left_x, top_y = candidates[0]
candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
nxt = None
for i in range(1, len(candidates)):
if candidates[i][1] ^ fst_kind == 1:
nxt = candidates[i]
break
if nxt is None:
break
if fst_kind == SUB_BIT_KIND:
sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
else:
sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
pair_dis = bbox_distance(subjects[sub_idx]["bbox"], objects[obj_idx]["bbox"])
nearest_dis = float("inf")
for i in range(N):
if i in seen_idx or i == sub_idx:
continue
nearest_dis = min(nearest_dis, bbox_distance(subjects[i]["bbox"], objects[obj_idx]["bbox"]))
if pair_dis >= 3 * nearest_dis:
seen_idx.add(sub_idx)
continue
seen_idx.add(sub_idx)
seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
seen_sub_idx.add(sub_idx)
ret.append(
{
"sub_bbox": {
"bbox": subjects[sub_idx]["bbox"],
"lines": subjects[sub_idx]["lines"],
"index": subjects[sub_idx]["index"],
},
"obj_bboxes": [
{"bbox": objects[obj_idx]["bbox"], "lines": objects[obj_idx]["lines"], "index": objects[obj_idx]["index"]}
],
"sub_idx": sub_idx,
}
)
for i in range(len(objects)):
j = i + OBJ_IDX_OFFSET
if j in seen_idx:
continue
seen_idx.add(j)
nearest_dis, nearest_sub_idx = float("inf"), -1
for k in range(len(subjects)):
dis = bbox_distance(objects[i]["bbox"], subjects[k]["bbox"])
if dis < nearest_dis:
nearest_dis = dis
nearest_sub_idx = k
for k in range(len(subjects)):
if k != nearest_sub_idx:
continue
if k in seen_sub_idx:
for kk in range(len(ret)):
if ret[kk]["sub_idx"] == k:
ret[kk]["obj_bboxes"].append(
{"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
)
break
else:
ret.append(
{
"sub_bbox": {
"bbox": subjects[k]["bbox"],
"lines": subjects[k]["lines"],
"index": subjects[k]["index"],
},
"obj_bboxes": [
{"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
],
"sub_idx": k,
}
)
seen_sub_idx.add(k)
seen_idx.add(k)
for i in range(len(subjects)):
if i in seen_sub_idx:
continue
ret.append(
{
"sub_bbox": {
"bbox": subjects[i]["bbox"],
"lines": subjects[i]["lines"],
"index": subjects[i]["index"],
},
"obj_bboxes": [],
"sub_idx": i,
}
)
return ret
def get_type_blocks(blocks, block_type: Literal["image", "table"]):
with_captions = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
with_footnotes = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
ret = []
for v in with_captions:
record = {
f"{block_type}_body": v["sub_bbox"],
f"{block_type}_caption_list": v["obj_bboxes"],
}
filter_idx = v["sub_idx"]
d = next(filter(lambda x: x["sub_idx"] == filter_idx, with_footnotes))
record[f"{block_type}_footnote_list"] = d["obj_bboxes"]
ret.append(record)
return ret
def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
need_fix_blocks = get_type_blocks(blocks, fix_type)
fixed_blocks = []
for block in need_fix_blocks:
body = block[f"{fix_type}_body"]
caption_list = block[f"{fix_type}_caption_list"]
footnote_list = block[f"{fix_type}_footnote_list"]
body["type"] = f"{fix_type}_body"
for caption in caption_list:
caption["type"] = f"{fix_type}_caption"
for footnote in footnote_list:
footnote["type"] = f"{fix_type}_footnote"
two_layer_block = {
"type": fix_type,
"bbox": body["bbox"],
"blocks": [
body,
],
"index": body["index"],
}
two_layer_block["blocks"].extend([*caption_list, *footnote_list])
fixed_blocks.append(two_layer_block)
return fixed_blocks
# Copyright (c) Opendatalab. All rights reserved.
from io import BytesIO
import pypdfium2 as pdfium
from loguru import logger
from PIL import Image
from ..data.data_reader_writer import FileBasedDataWriter
from ..utils.pdf_reader import image_to_b64str, image_to_bytes, page_to_image
from .hash_utils import str_sha256
def pdf_page_to_image(page: pdfium.PdfPage, dpi=200) -> dict:
"""Convert pdfium.PdfDocument to image, Then convert the image to base64.
Args:
page (_type_): pdfium.PdfPage
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
Returns:
dict: {'img_base64': str, 'img_pil': pil_img, 'scale': float }
"""
pil_img, scale = page_to_image(page, dpi=dpi)
img_base64 = image_to_b64str(pil_img)
image_dict = {
"img_base64": img_base64,
"img_pil": pil_img,
"scale": scale,
}
return image_dict
def load_images_from_pdf(
pdf_bytes: bytes,
dpi=200,
start_page_id=0,
end_page_id=None,
):
images_list = []
pdf_doc = pdfium.PdfDocument(pdf_bytes)
pdf_page_num = len(pdf_doc)
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
for index in range(0, pdf_page_num):
if start_page_id <= index <= end_page_id:
page = pdf_doc[index]
image_dict = pdf_page_to_image(page, dpi=dpi)
images_list.append(image_dict)
return images_list, pdf_doc
def cut_image(bbox: tuple, page_num: int, page_pil_img, return_path, imageWriter: FileBasedDataWriter, scale=3):
"""从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
图片存放在save_path下,文件名是:
{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
# 拼接文件名
filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
# 老版本返回不带bucket的路径
img_path = f"{return_path}_{filename}" if return_path is not None else None
# 新版本生成平铺路径
img_hash256_path = f"{str_sha256(img_path)}.jpg"
# img_hash256_path = f'{img_path}.jpg'
crop_img = get_crop_img(bbox, page_pil_img, scale=scale)
img_bytes = image_to_bytes(crop_img, image_format="JPEG")
imageWriter.write(img_hash256_path, img_bytes)
return img_hash256_path
def get_crop_img(bbox: tuple, pil_img, scale=2):
scale_bbox = (
int(bbox[0] * scale),
int(bbox[1] * scale),
int(bbox[2] * scale),
int(bbox[3] * scale),
)
return pil_img.crop(scale_bbox)
def images_bytes_to_pdf_bytes(image_bytes):
# 内存缓冲区
pdf_buffer = BytesIO()
# 载入并转换所有图像为 RGB 模式
image = Image.open(BytesIO(image_bytes)).convert("RGB")
# 第一张图保存为 PDF,其余追加
image.save(pdf_buffer, format="PDF", save_all=True)
# 获取 PDF bytes 并重置指针(可选)
pdf_bytes = pdf_buffer.getvalue()
pdf_buffer.close()
return pdf_bytes
__version__ = "0.0.1"
\ No newline at end of file
from transformers import AutoConfig, AutoImageProcessor, AutoModelForCausalLM
from .configuration_mineru2 import Mineru2QwenConfig
from .image_processing_mineru2 import Mineru2ImageProcessor
from .modeling_mineru2 import Mineru2QwenForCausalLM
AutoConfig.register(Mineru2QwenConfig.model_type, Mineru2QwenConfig)
AutoModelForCausalLM.register(Mineru2QwenConfig, Mineru2QwenForCausalLM)
AutoImageProcessor.register(Mineru2QwenConfig, slow_image_processor_class=Mineru2ImageProcessor)
from transformers import Qwen2Config
class Mineru2QwenConfig(Qwen2Config):
model_type = "mineru2_qwen"
def __init__(
self,
ignore_index=-100,
image_aspect_ratio="square_anyres_max_9",
image_grid_pinpoints="(1x1),...,(4x4)",
image_token_index=151646,
mm_hidden_size=1152,
mm_patch_merge_type="spatial_unpad",
mm_projector_type="mlp2x_gelu",
mm_vision_select_feature="full",
mm_vision_select_layer=-2,
mm_vision_tower="google/siglip-so400m-patch14-384",
tie_word_embeddings=False,
tokenizer_model_max_length=16384,
tokenizer_padding_side="right",
unfreeze_mm_vision_tower=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_aspect_ratio = image_aspect_ratio
self.image_grid_pinpoints = image_grid_pinpoints
self.image_token_index = image_token_index
self.mm_hidden_size = mm_hidden_size
self.mm_patch_merge_type = mm_patch_merge_type
self.mm_projector_type = mm_projector_type
self.mm_vision_select_feature = mm_vision_select_feature
self.mm_vision_select_layer = mm_vision_select_layer
self.mm_vision_tower = mm_vision_tower
self.tokenizer_model_max_length = tokenizer_model_max_length
self.tokenizer_padding_side = tokenizer_padding_side
self.unfreeze_mm_vision_tower = unfreeze_mm_vision_tower
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment