Commit 8e55a526 authored by Jin Zhen Jiang's avatar Jin Zhen Jiang
Browse files

feat: add mineru-vlm backend.

parent 6f8a9610
import ast
import math
import re
from functools import partial, reduce
from typing import Dict, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
to_numpy_array,
)
from transformers.utils import TensorType
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """Choose the candidate resolution that best fits an image.

    Each candidate is scored by the effective (non-upscaled) pixel count the
    image would keep after aspect-preserving downscale, breaking ties by the
    least wasted canvas area. Returns (0, 0) for an empty candidate list.
    """
    orig_w, orig_h = original_size

    def _score(resolution):
        cand_w, cand_h = resolution
        ratio = min(cand_w / orig_w, cand_h / orig_h)
        fitted = int(orig_w * ratio) * int(orig_h * ratio)
        # Never count upscaled pixels as a gain.
        effective = min(fitted, orig_w * orig_h)
        wasted = cand_w * cand_h - effective
        # Higher effective area wins; on ties, less waste wins.
        return (effective, -wasted)

    if not possible_resolutions:
        return (0, 0)
    best_w, best_h = max(possible_resolutions, key=_score)
    return (best_w, best_h)
def divide_to_patches(image, patch_size):
    """Cut *image* into square tiles of side *patch_size*, in row-major order.

    Tiles at the right/bottom edge may extend past the image; PIL's crop pads
    those regions. Returns the list of cropped tiles.
    """
    img_w, img_h = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, img_h, patch_size)
        for left in range(0, img_w, patch_size)
    ]
def expand2square(pil_img, background_color):
    """Pad *pil_img* to a centered square canvas filled with *background_color*.

    Already-square images are returned unchanged. Grayscale ("L") images are
    promoted to RGB first so the RGB background tuple is valid.
    """
    img_w, img_h = pil_img.size
    if img_w == img_h:
        return pil_img
    if pil_img.mode == "L":
        pil_img = pil_img.convert("RGB")
    side = max(img_w, img_h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Center along whichever dimension is shorter.
    canvas.paste(pil_img, ((side - img_w) // 2, (side - img_h) // 2))
    return canvas
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """Return the (cols, rows) patch-grid shape chosen for *image_size*.

    ``grid_pinpoints`` may be a list of pixel resolutions, a string literal of
    one, or a range spec like ``"(1x1),...,(6x6)"`` which is expanded to every
    grid between the first and last pair (inclusive), scaled by *patch_size*.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Expand the "(AxB),...,(CxD)" range spec into concrete pixel sizes.
        pairs = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        lo_w, lo_h = (int(v) for v in pairs[0])
        hi_w, hi_h = (int(v) for v in pairs[-1])
        grid_pinpoints = [
            [w * patch_size, h * patch_size]
            for w in range(lo_w, hi_w + 1)
            for h in range(lo_h, hi_h + 1)
        ]
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
# NOTE: this function is currently unused (process_anyres_image resizes
# without padding instead).
def resize_and_pad_image(image, target_resolution):
    """Fit *image* inside *target_resolution* preserving aspect ratio, then
    letterbox it onto a black RGB canvas, centered."""
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution
    ratio_w = dst_w / src_w
    ratio_h = dst_h / src_h
    # Scale by the tighter ratio so the result fits inside the target box.
    if ratio_w < ratio_h:
        fit_w = dst_w
        fit_h = min(math.ceil(src_h * ratio_w), dst_h)
    else:
        fit_h = dst_h
        fit_w = min(math.ceil(src_w * ratio_h), dst_w)
    # Resize the image
    resized = image.resize((fit_w, fit_h))
    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    canvas.paste(resized, ((dst_w - fit_w) // 2, (dst_h - fit_h) // 2))
    return canvas
# DIFFERENT from sglang.srt.mm_utils.process_anyres_image: the best-fit image
# is only resized, never padded.
def process_anyres_image(image, processor, grid_pinpoints):
    """Preprocess *image* for anyres inference.

    Picks the best-fitting resolution from *grid_pinpoints*, resizes (without
    padding), tiles the result, prepends a low-res overview of the full image,
    and runs every tile through ``processor.preprocess``. Returns a stacked
    tensor of shape (num_tiles + 1, C, H, W).
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        patch_size = processor.crop_size["height"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Expand the "(AxB),...,(CxD)" range spec into pixel resolutions.
        pairs = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        lo_w, lo_h = (int(v) for v in pairs[0])
        hi_w, hi_h = (int(v) for v in pairs[-1])
        grid_pinpoints = [
            [w * patch_size, h * patch_size]
            for w in range(lo_w, hi_w + 1)
            for h in range(lo_h, hi_h + 1)
        ]
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)  # type: ignore
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    # image_padded = resize_and_pad_image(image, best_resolution)
    resized = image.resize(best_resolution)
    tiles = divide_to_patches(resized, processor.crop_size["height"])
    overview = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
    tensors = [
        processor.preprocess(tile, return_tensors="pt")["pixel_values"][0]
        for tile in [overview, *tiles]
    ]
    return torch.stack(tensors, dim=0)
def process_images(images, image_processor, model_cfg):
    """Preprocess a list of PIL images according to the model's aspect-ratio
    strategy (``model_cfg.image_aspect_ratio``).

    - "pad": pad each image to a square (mean-color background) and preprocess.
    - "anyres"/"anyres_max...": tile each image via process_anyres_image.
    - anything else: delegate directly to the image processor for the batch.

    Returns a stacked tensor when every per-image result has the same shape,
    otherwise a list of tensors (anyres tile counts can differ per image).
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", "")
    new_images = []
    if image_aspect_ratio == "pad":
        for image in images:
            # Background is the normalization mean scaled back to 0-255.
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors="pt")["pixel_values"]
    # Stack only when shapes agree across the batch.
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
class Mineru2ImageProcessor(BaseImageProcessor):
    """HuggingFace-style image processor for the Mineru2 VLM.

    Supports a plain resize/normalize pipeline plus the "pad" and
    "anyres"/"anyres_max" aspect-ratio strategies selected via
    ``image_aspect_ratio``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(384, 384),
        crop_size: Optional[Dict[str, int]] = None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[list] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size
        self.image_aspect_ratio = image_aspect_ratio
        self.image_grid_pinpoints = image_grid_pinpoints
        # Re-entrancy guard: process_anyres_image() calls preprocess() once per
        # tile; this flag routes those nested calls to the plain pipeline.
        self.in_e2e_processing = False

    def _preprocess(self, images):
        """Run the basic transform chain: RGB -> resize -> rescale -> normalize
        -> channel format. Returns {"pixel_values": [ndarray, ...]}."""
        if isinstance(images, Image.Image):
            images = [images]
        else:
            # to adapt video data (sequence of frames)
            images = [to_numpy_array(image) for image in images]
        assert isinstance(images, list)
        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        # Pipe every image through each transform in order.
        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
        return {"pixel_values": images}

    def _preprocess_end_to_end(self, images):
        """Preprocess per the configured aspect-ratio strategy and also return
        the original (height, width) of every image."""
        image_aspect_ratio = self.image_aspect_ratio
        image_grid_pinpoints = self.image_grid_pinpoints
        assert image_aspect_ratio is not None
        assert image_grid_pinpoints is not None
        pixel_values = []
        if image_aspect_ratio == "pad":
            for image in images:
                # Pad to square with the normalization mean as background.
                image = expand2square(image, tuple(int(x * 255) for x in self.image_mean))
                image = self._preprocess(image)["pixel_values"][0]
                pixel_values.append(image)
        elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
            for image in images:
                image = process_anyres_image(image, self, self.image_grid_pinpoints)
                pixel_values.append(image.numpy())
        else:
            pixel_values = self._preprocess(images)["pixel_values"]
        # Stack only when all entries agree in shape (anyres counts may vary).
        if isinstance(pixel_values, list) and all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = np.stack(pixel_values, axis=0)
        # CAUTION: here used (height, width).
        image_sizes = [(image.height, image.width) for image in images]
        assert len(pixel_values) == len(image_sizes)
        return {"pixel_values": pixel_values, "image_sizes": image_sizes}

    def preprocess(
        self,
        images,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        """Transformers entry point; dispatches to the plain pipeline or the
        end-to-end strategy depending on configuration and re-entrancy."""
        if self.image_aspect_ratio is None or self.in_e2e_processing:
            data = self._preprocess(images)
        else:
            assert self.image_grid_pinpoints is not None
            self.in_e2e_processing = True
            try:
                data = self._preprocess_end_to_end(images)
            finally:
                self.in_e2e_processing = False
        return BatchFeature(data=data, tensor_type=return_tensors)
This diff is collapsed.
# Register the Mineru2 model and its image processor with sglang at import time.
from sglang.srt.configs.model_config import multimodal_model_archs
from sglang.srt.models.registry import ModelRegistry

try:
    # sglang==0.4.5.post3
    from sglang.srt.managers.multimodal_processor import (
        PROCESSOR_MAPPING as PROCESSOR_MAPPING,
    )
except ImportError:
    # sglang==0.4.4.post1
    from sglang.srt.managers.image_processor import (
        IMAGE_PROCESSOR_MAPPING as PROCESSOR_MAPPING,
    )

# Imported for its import-time side effects only.
from .. import vlm_hf_model as _
from .image_processor import Mineru2ImageProcessor
from .model import Mineru2QwenForCausalLM

# Make sglang aware of the model class, its processor, and that it is multimodal.
ModelRegistry.models[Mineru2QwenForCausalLM.__name__] = Mineru2QwenForCausalLM
PROCESSOR_MAPPING[Mineru2QwenForCausalLM] = Mineru2ImageProcessor
multimodal_model_archs.append(Mineru2QwenForCausalLM.__name__)
import asyncio
import time
from types import MethodType
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
import fastapi
from sglang.srt.entrypoints.engine import Engine as _Engine
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.tokenizer_manager import (
TokenizerManager,
dataclass_to_string_truncated,
logger,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from ...utils.run_async import run_async
from .logit_processor import Mineru2LogitProcessor
class BatchEngine(_Engine):
    """
    The engine is patched to support batch multi-modal generate, and early image preprocessing.
    """

    def __init__(self, server_args: ServerArgs, **kwargs):
        # Custom logit processors must be enabled so requests can attach
        # Mineru2LogitProcessor (see generate/async_generate below).
        server_args.enable_custom_logit_processor = True
        super().__init__(server_args=server_args, **kwargs)
        _patch_tokenizer_manager(self.tokenizer_manager)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a url, or base64 encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
    ) -> Union[Dict, Iterator[Dict]]:
        """
        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        modalities_list = []
        # EDIT: one "image" modality entry per request in the batch.
        if isinstance(image_data, list):
            for _ in range(len(image_data)):
                modalities_list.append(["image"])
        elif image_data is not None:
            modalities_list.append("image")
        # ADD: default to the Mineru2 anti-repetition logit processor.
        if custom_logit_processor is None:
            custom_logit_processor = Mineru2LogitProcessor().to_str()
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            modalities=modalities_list,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=return_hidden_states,
            stream=stream,
        )
        # Use the patched request path (supports batched multimodal requests).
        generator = _generate_request(self.tokenizer_manager, obj, None)
        if stream:

            def generator_wrapper():
                # Drive the async generator from synchronous code.
                while True:
                    try:
                        chunk = run_async(generator.__anext__())
                        yield chunk
                    except StopAsyncIteration:
                        break

            return generator_wrapper()
        else:
            ret = run_async(generator.__anext__())
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict], Dict]] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        # The image input. It can be a file name, a url, or base64 encoded string.
        # See also python/sglang/srt/utils.py:load_image.
        image_data: Optional[Union[List[str], str]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
        return_hidden_states: bool = False,
        stream: bool = False,
    ) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
        """
        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
        Please refer to `GenerateReqInput` for the documentation.
        """
        modalities_list = []
        # EDIT: one "image" modality entry per request in the batch.
        if isinstance(image_data, list):
            for _ in range(len(image_data)):
                modalities_list.append(["image"])
        elif image_data is not None:
            modalities_list.append("image")
        # ADD: default to the Mineru2 anti-repetition logit processor.
        if custom_logit_processor is None:
            custom_logit_processor = Mineru2LogitProcessor().to_str()
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            image_data=image_data,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            token_ids_logprob=token_ids_logprob,
            lora_path=lora_path,
            modalities=modalities_list,
            custom_logit_processor=custom_logit_processor,
            return_hidden_states=return_hidden_states,
            stream=stream,
        )
        generator = _generate_request(self.tokenizer_manager, obj, None)
        if stream is True:
            # Caller iterates the async generator itself.
            return generator
        else:
            return await generator.__anext__()
def _auto_create_handle_loop(self: TokenizerManager):
    """
    Patched auto_create_handle_loop(): clears ``no_create_loop`` whenever the
    current event loop differs from the one seen on the previous call, then
    delegates to the original implementation.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Not inside an event loop.
        running_loop = None
    if getattr(self, "_last_handle_loop", None) != running_loop:
        # The loop changed (or this is the first call): allow re-creation.
        self.no_create_loop = False
        setattr(self, "_last_handle_loop", running_loop)
    return TokenizerManager.auto_create_handle_loop(self)
def _patch_tokenizer_manager(self: TokenizerManager):
    """Bind the patched auto_create_handle_loop onto this TokenizerManager instance."""
    self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
async def _one_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request],
    created_time: Optional[float],
):
    """Tokenize and submit a single request, then yield its responses as they arrive."""
    tokenized_obj = await self._tokenize_one_request(obj)
    self._send_one_request(obj, tokenized_obj, created_time)
    async for out in self._wait_one_response(obj, request):
        yield out
async def _handle_batch_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
    created_time: Optional[float] = None,
):
    """Submit every sub-request of a batch concurrently and yield responses.

    Non-streaming: yields one list containing the single response of each
    sub-request. Streaming: multiplexes chunks from all sub-requests as they
    complete, tagging each with its batch position under ``result["index"]``.
    """
    batch_size = obj.batch_size
    generators = []
    rids = []
    if getattr(obj, "parallel_sample_num", 1) != 1:
        raise Exception("parallel_sample_num != 1 is not supported in this patched code.")
    # Send all requests
    for i in range(batch_size):
        tmp_obj = obj[i]
        generators.append(_one_request(self, tmp_obj, request, created_time))
        rids.append(tmp_obj.rid)
    # Wait for all requests
    is_stream = hasattr(obj, "stream") and obj.stream
    if not is_stream:
        # Each generator yields exactly one response in non-stream mode.
        outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
        yield outputs
    else:
        rid_to_index = {rid: i for i, rid in enumerate(rids)}
        # Map each in-flight __anext__ task back to its source generator.
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while task_map:
            done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                gen = task_map.pop(task)
                try:
                    result = task.result()
                    # Recover the batch position from the request id.
                    result["index"] = rid_to_index[result["meta_info"]["id"]]
                    yield result
                    # Schedule the next chunk from this generator.
                    new_task = asyncio.create_task(gen.__anext__())
                    task_map[new_task] = gen
                except StopAsyncIteration:
                    # This sub-request is finished; drop its generator.
                    pass
async def _generate_request(
    self: TokenizerManager,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
):
    """Patched request entry point: single requests follow the normal path,
    batches are routed through the patched _handle_batch_request."""
    created_time = time.time()
    # Ensure handler tasks exist on the current event loop (patched via
    # _auto_create_handle_loop when _patch_tokenizer_manager was applied).
    self.auto_create_handle_loop()
    if isinstance(obj, EmbeddingReqInput) and self.is_generation:
        raise ValueError(
            "This model does not appear to be an embedding model by default. "
            "Please add `--is-embedding` when launching the server or try another model."
        )
    obj.normalize_batch_and_arguments()
    if self.log_requests:
        max_length, skip_names, _ = self.log_request_metadata
        logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")
    # Hold the reader lock so a model update cannot happen mid-request.
    async with self.model_update_lock.reader_lock:
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self._send_one_request(obj, tokenized_obj, created_time)
            async for response in self._wait_one_response(obj, request):
                yield response
        else:
            async for response in _handle_batch_request(self, obj, request, created_time):
                yield response
import ast
import asyncio
import re
from typing import List, Optional, Union
import numpy as np
try:
# sglang==0.4.5.post3
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as BaseProcessor,
)
get_global_processor = None
except ImportError:
# sglang==0.4.4.post1
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor as BaseProcessor,
get_global_processor,
)
from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
from .model import Mineru2QwenForCausalLM
# image_best_res is only resized (not padded), unlike
# sglang.srt.mm_utils.process_anyres_image.
def process_anyres_image(image, processor, grid_pinpoints):
    """Tile *image* for anyres inference and preprocess every tile.

    Returns an ndarray of shape (num_tiles + 1, C, H, W): a low-res overview
    of the whole image followed by tiles of the best-fit resized image.
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        patch_size = processor.crop_size["height"]
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Expand the "(AxB),...,(CxD)" range spec to an inclusive grid of pairs.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [
            (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
        ]
        # Scale grid counts to pixel resolutions.
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_best_res = image.resize(best_resolution)  # <<<<<<< Here changed
    patches = divide_to_patches(image_best_res, processor.crop_size["height"])
    # Prepend a square overview of the full image.
    image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
    return np.stack(image_patches, axis=0)
class Mineru2ImageProcessor(BaseProcessor):
    """sglang-side multimodal processor for Mineru2.

    Loads images (path/url/base64), applies the model's aspect-ratio strategy,
    and packages pixel values for the scheduler. Supports both the
    sglang==0.4.4.post1 and ==0.4.5.post3 processor APIs.
    """

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        image_processor=None,
    ):
        """Load and preprocess one image (or frame sequence).

        Returns (pixel_values, image_hash, image_size). Runs in an executor,
        so it must be self-contained and picklable.
        """
        if image_processor is None:
            # Old API: the processor lives in a process-global singleton.
            assert get_global_processor is not None
            image_processor = get_global_processor().image_processor
        try:
            image, image_size = load_image(image_data)
            if image_size is not None:
                # It is a video with multiple images
                image_hash = hash(image_data)
                pixel_values = image_processor(image)["pixel_values"]
                pixel_values = np.stack(pixel_values, axis=0)
                return pixel_values, image_hash, image_size
            else:
                # It is an image
                image_hash = hash(image_data)
                if image_aspect_ratio == "pad":
                    image = expand2square(
                        image,
                        tuple(int(x * 255) for x in image_processor.image_mean),
                    )
                    pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
                elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
                    pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
                else:
                    pixel_values = image_processor(image)["pixel_values"][0]
                return pixel_values, image_hash, image.size
        except Exception:
            # NOTE(review): on failure this logs and implicitly returns None;
            # callers unpack a 3-tuple, so the error surfaces later as a
            # TypeError — confirm this is the intended behavior.
            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())

    async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
        """Dispatch one image to the worker executor (or run inline if none)."""
        if hasattr(self, "cpu_executor"):
            executor = self.cpu_executor
        else:
            executor = self.executor
        if get_global_processor is not None:
            image_processor = None  # save ipc cost
        else:
            image_processor = self._processor.image_processor
        if executor is not None:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                executor,
                Mineru2ImageProcessor._process_single_image_task,
                image_data,
                aspect_ratio,
                grid_pinpoints,
                image_processor,
            )
        else:
            return self._process_single_image_task(
                image_data,
                aspect_ratio,
                grid_pinpoints,
                image_processor,
            )

    # sglang==0.4.4.post1
    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Preprocess all images of a request concurrently.

        Returns a dict with pixel_values, image_hashes, image_sizes, and
        modalities, or None when there is no image data.
        """
        if not image_data:
            return None
        modalities = request_obj.modalities or ["image"]
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", "")
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints") and "anyres" in aspect_ratio
            else None
        )
        if isinstance(image_data, str):
            image_data = [image_data]
        if isinstance(image_data, list) and len(image_data) > 0:
            if "multi-images" in modalities or "video" in modalities:
                # Multiple images
                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
                pixel_values, image_hashes, image_sizes = [], [], []
                res = []
                for img_data in image_data:
                    res.append(self._process_single_image(img_data, aspect_ratio, grid_pinpoints))
                res = await asyncio.gather(*res)
                for pixel_v, image_h, image_s in res:
                    pixel_values.append(pixel_v)
                    image_hashes.append(image_h)
                    image_sizes.append(image_s)
                # Stack only homogeneous ndarray results.
                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.stack(pixel_values, axis=0)
            else:
                # A single image
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
                image_hashes = [image_hash]
                image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")
        return {
            "pixel_values": pixel_values,
            "image_hashes": image_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities or ["image"],
        }

    # sglang==0.4.5.post3
    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Newer-API wrapper: reuse process_images_async and repackage the
        result as a list of MultimodalDataItem."""
        from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

        result = await self.process_images_async(image_data, input_text, request_obj, *args, **kwargs)
        if result is None:
            return None
        # Derive the modality from the first entry of the request's modalities.
        modality = Modality.IMAGE
        if isinstance(request_obj.modalities, list):
            if request_obj.modalities[0] == "multi-images":
                modality = Modality.MULTI_IMAGES
            elif request_obj.modalities[0] == "video":
                modality = Modality.VIDEO
        return {
            "mm_items": [
                MultimodalDataItem(
                    pixel_values=result["pixel_values"],
                    image_sizes=result["image_sizes"],
                    modality=modality,
                )
            ],
        }
ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}
from typing import List
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
class Mineru2LogitProcessor(CustomLogitProcessor):
    """
    Stateless logit processor for Mineru2.
    (base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
    This processor applies token-level constraints to prevent repetition during generation.
    It supports two main constraints:
    - no_repeat_ngram_size (int):
        Prevents repeating the same n-gram of specified size in the output.
        Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
        This implementation is slower due to its lack of specialized optimization.
    - no_repeat_token_count (int):
        (Placeholder for future logic)
        Intended to prevent repeating the same token multiple times.
        Not yet implemented in this version.
    """

    def __init__(self) -> None:
        super().__init__()
        self._generated_ngrams = {}  # Cache of generated n-grams by request ID
        self._time = {}  # Timestamp of the last update for each request
        self._gen_step = 0  # Global generation step counter

    def __call__(self, logits, batch_info: List[dict]):
        """
        Applies repetition constraints to the logits before sampling tokens.
        Args:
            logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
            batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
                - "__req__": Request object containing request ID and output_ids.
                - "no_repeat_ngram_size": Size of n-gram to avoid repeating.
        Returns:
            FloatTensor: The modified logits tensor with banned token logits set to -inf.
        """
        from sglang.srt.managers.schedule_batch import Req

        self._gen_step += 1  # Update global generation step
        for idx, info in enumerate(batch_info):
            if not isinstance(info, dict) or "__req__" not in info:
                continue
            req: Req = info["__req__"]
            rid = req.rid
            output_ids = req.output_ids
            ngram_size = info.get("no_repeat_ngram_size", 0)
            # Skip if there are not enough tokens to form an n-gram
            # NOTE(review): ngram_size == 1 yields an empty stored prefix but a
            # full-output lookup prefix below — likely inconsistent; confirm
            # whether size-1 n-grams are ever configured.
            if ngram_size <= 0 or len(output_ids) < ngram_size:
                continue
            # Record the current step for cache cleanup tracking
            self._time[rid] = self._gen_step
            # Initialize n-gram cache for this request if it doesn't exist
            if rid not in self._generated_ngrams:
                self._generated_ngrams[rid] = {}
            # Get the n-gram prefix (all but the last token)
            prev_ngram = tuple(output_ids[-ngram_size:-1])
            last_token = output_ids[-1]
            # Store this n-gram occurrence
            self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]
            # Get the next-token candidates to ban based on current prefix
            current_prefix = tuple(output_ids[-ngram_size + 1 :])
            banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])
            # Set the logits of banned tokens to negative infinity
            for token in banned_tokens:
                logits[idx][token] = -float("inf")
        # Clean up cache for expired requests: any request not seen in this
        # call (its timestamp was not bumped to the current step) is dropped.
        expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
        for rid in expired_rids:
            self._generated_ngrams.pop(rid, None)
            self._time.pop(rid, None)
        return logits
This diff is collapsed.
import os
import sys
from fastapi import Request
from sglang.srt.entrypoints.http_server import app, generate_request, launch_server
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
from .logit_processor import Mineru2LogitProcessor
# Serialized once at import time; reused for every request without a processor.
_custom_logit_processor_str = Mineru2LogitProcessor().to_str()

# remove the existing /generate route
for route in app.routes[:]:
    if hasattr(route, "path") and getattr(route, "path") == "/generate":
        app.routes.remove(route)


# add the custom /generate route
@app.api_route("/generate", methods=["POST", "PUT"])
async def custom_generate_request(obj: GenerateReqInput, request: Request):
    """Like sglang's /generate, but injects the Mineru2 logit processor when
    the request does not specify one."""
    if obj.custom_logit_processor is None:
        obj.custom_logit_processor = _custom_logit_processor_str
    return await generate_request(obj, request)
def main():
    """Parse CLI arguments, apply Mineru2 defaults, and launch the sglang server."""
    server_args = prepare_server_args(sys.argv[1:])
    # Default to the ChatML template unless one was given explicitly.
    if server_args.chat_template is None:
        server_args.chat_template = "chatml"
    # Required so requests may attach the Mineru2 custom logit processor.
    server_args.enable_custom_logit_processor = True
    try:
        launch_server(server_args)
    finally:
        # Reap worker subprocesses on shutdown, keeping this process alive.
        kill_process_tree(os.getpid(), include_parent=False)


if __name__ == "__main__":
    main()
# Copyright (c) Opendatalab. All rights reserved.
import base64
from io import BytesIO
from loguru import logger
from PIL import Image
from pypdfium2 import PdfBitmap, PdfDocument, PdfPage
def page_to_image(
    page: PdfPage,
    dpi: int = 144,  # changed from 200 to 144
    max_width_or_height: int = 2560,  # changed from 4500 to 2560
) -> "tuple[Image.Image, float]":
    """Render a PDF page to a PIL image.

    Args:
        page: the pypdfium2 page to render.
        dpi: target rendering resolution (PDF points are 1/72 inch).
        max_width_or_height: cap on the longer side of the *rendered* image,
            in pixels.

    Returns:
        (image, scale) where scale is the points-to-pixels factor actually used.
    """
    scale = dpi / 72
    long_side_length = max(*page.get_size())
    # Fix: cap the RENDERED size. The previous check compared the unscaled
    # page size (in points) against the pixel limit, so a page under the limit
    # could still render larger than max_width_or_height once dpi scaling was
    # applied.
    if long_side_length * scale > max_width_or_height:
        scale = max_width_or_height / long_side_length
    bitmap: PdfBitmap = page.render(scale=scale)  # type: ignore
    try:
        image = bitmap.to_pil()
    finally:
        # Always release the native bitmap, even if to_pil() fails.
        try:
            bitmap.close()
        except Exception:
            pass
    return image, scale
def image_to_bytes(
    image: Image.Image,
    image_format: str = "PNG",  # "JPEG" also works
) -> bytes:
    """Encode *image* in the given format and return the raw bytes."""
    buffer = BytesIO()
    with buffer:
        image.save(buffer, format=image_format)
        return buffer.getvalue()
def image_to_b64str(
    image: Image.Image,
    image_format: str = "PNG",  # "JPEG" also works
) -> str:
    """Encode *image* in the given format and return it as a base64 string."""
    raw = image_to_bytes(image, image_format)
    return base64.b64encode(raw).decode("utf-8")
def pdf_to_images(
    pdf: str | bytes | PdfDocument,
    dpi: int = 144,
    max_width_or_height: int = 2560,
    start_page_id: int = 0,
    end_page_id: int | None = None,
) -> list[Image.Image]:
    """Render a page range of a PDF to PIL images.

    ``pdf`` may be a path, raw bytes, or an already-open PdfDocument.
    ``end_page_id`` is inclusive; None or a negative value means the last page.

    NOTE(review): the document is always closed in the finally block, even
    when the caller passed in their own PdfDocument — confirm callers do not
    reuse the document afterwards.
    """
    doc = pdf if isinstance(pdf, PdfDocument) else PdfDocument(pdf)
    page_num = len(doc)
    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else page_num - 1
    if end_page_id > page_num - 1:
        logger.warning("end_page_id is out of range, use images length")
        end_page_id = page_num - 1
    images = []
    try:
        for i in range(start_page_id, end_page_id + 1):
            image, _ = page_to_image(doc[i], dpi, max_width_or_height)
            images.append(image)
    finally:
        # Best-effort close of the native document handle.
        try:
            doc.close()
        except Exception:
            pass
    return images
def pdf_to_images_bytes(
    pdf: str | bytes | PdfDocument,
    dpi: int = 144,
    max_width_or_height: int = 2560,
    start_page_id: int = 0,
    end_page_id: int | None = None,
    image_format: str = "PNG",
) -> list[bytes]:
    """Render a page range of a PDF and return each page as encoded image bytes."""
    rendered = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
    return [image_to_bytes(page_image, image_format) for page_image in rendered]
def pdf_to_images_b64strs(
    pdf: str | bytes | PdfDocument,
    dpi: int = 144,
    max_width_or_height: int = 2560,
    start_page_id: int = 0,
    end_page_id: int | None = None,
    image_format: str = "PNG",
) -> list[str]:
    """Render a page range of a PDF and return each page as a base64 string."""
    rendered = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
    return [image_to_b64str(page_image, image_format) for page_image in rendered]
import asyncio
import threading
from queue import Queue
from typing import Any, AsyncIterable, Coroutine, Iterable, TypeVar
T = TypeVar("T")
def run_async(coroutine: Coroutine[Any, Any, T]) -> T:
    """Run *coroutine* to completion from synchronous code and return its result.

    Fix: the previous version called ``loop.run_until_complete`` when a loop
    was already running, which always raises "This event loop is already
    running". When called from inside a running loop, the coroutine is now
    driven on a dedicated thread with its own event loop instead.

    Raises:
        ValueError: if *coroutine* is not actually a coroutine.
    """
    if not asyncio.iscoroutine(coroutine):
        raise ValueError("a coroutine was expected, got {!r}".format(coroutine))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: drive the coroutine directly.
        return asyncio.run(coroutine)

    # A loop is already running here; blocking it would deadlock (and
    # run_until_complete would raise). Use a short-lived helper thread.
    result: list = []
    error: list = []

    def _worker():
        try:
            result.append(asyncio.run(coroutine))
        except BaseException as exc:  # re-raised in the caller's thread below
            error.append(exc)

    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    thread.join()
    if error:
        raise error[0]
    return result[0]
def iter_async(iterable: AsyncIterable[T]) -> Iterable[T]:
    """Consume an async iterable from synchronous code, yielding each item.

    A daemon thread drives the async iteration on its own event loop and
    forwards items through a queue; exceptions raised by the iterable are
    re-raised in the caller's thread.

    Fixes over the previous version:
    - ``None`` was used as the end-of-stream sentinel, so an iterable that
      legitimately yielded ``None`` terminated early; a unique marker object
      is used now.
    - A yielded value that happened to be an Exception instance was raised
      instead of yielded; items and signals are now tagged explicitly.

    Raises:
        ValueError: if *iterable* is not an AsyncIterable.
    """
    if not isinstance(iterable, AsyncIterable):
        raise ValueError("an async iterable was expected, got {!r}".format(iterable))

    done = object()  # unique end-of-stream marker; never equal to a real item
    queue: Queue = Queue()

    async def _drain():
        try:
            async for item in iterable:
                queue.put((False, item))  # (is_signal, payload)
        except Exception as exc:
            queue.put((True, exc))
        else:
            queue.put((True, done))

    # The worker thread has no running loop, so asyncio.run is always safe here.
    thread = threading.Thread(target=lambda: asyncio.run(_drain()), daemon=True)
    thread.start()
    try:
        while True:
            is_signal, payload = queue.get()
            if is_signal:
                if payload is done:
                    break
                raise payload
            yield payload
    finally:
        thread.join()
# Formatter configuration; keep line-length in sync with ruff below.
[tool.black]
line-length = 128
# Linter configuration.
[tool.ruff]
line-length = 128
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment