Unverified Commit 770529a7 authored by Mick's avatar Mick Committed by GitHub
Browse files

model: support deepseek-ocr (#11891)


Co-authored-by: default avataryhyang201 <47235274+yhyang201@users.noreply.github.com>
Co-authored-by: default avataryhyang201 <yhyang201@gmail.com>
Co-authored-by: default avatarShi Shuai <126407087+shuaills@users.noreply.github.com>
Co-authored-by: default avatarXinyuan Tong <xinyuantong.cs@gmail.com>
parent 39c237f0
from typing import Tuple
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig
BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
MIN_CROPS = 2
MAX_CROPS = 6 # max:9; If your GPU memory is small, it is recommended to set it to 6.
MAX_CONCURRENCY = 100 # If you have limited GPU memory, lower the concurrency count.
NUM_WORKERS = 64 # image pre-process (resize/padding) workers
PRINT_NUM_VIS_TOKENS = False
SKIP_REPEAT = True
MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
class ImageTransform:
def __init__(
self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
):
self.mean = mean
self.std = std
self.normalize = normalize
transform_pipelines = [T.ToTensor()]
if normalize:
transform_pipelines.append(T.Normalize(mean, std))
self.transform = T.Compose(transform_pipelines)
def __call__(self, pil_img: Image.Image):
x = self.transform(pil_img)
return x
class VisionEncoderConfig(PretrainedConfig):
model_type: str = "vision"
model_name: str = "vit_so400m_patch14_siglip_384.webli"
image_size: int = 384
patch_size: int = 16
width: int = 1024
layers: int = 24
heads: int = 16
mlp_ratio: int = 4
global_pool: str = "map"
ignore_head: bool = True
class_token: bool = False
num_classes: int = 0
use_checkpoint: bool = False
weight_init: str = "skip"
deterministic: bool = False
num_recomputing_layers: int = 0
def __init__(
self,
model_name: str = "vit_so400m_patch14_siglip_384.webli",
image_size: int = 384,
patch_size: int = 16,
width: int = 1024,
layers: int = 24,
heads: int = 16,
mlp_ratio: int = 4,
global_pool: str = "map",
ignore_head: bool = True,
class_token: bool = False,
num_classes: int = 0,
use_checkpoint: bool = False,
**kwargs,
):
self.model_name = model_name
self.image_size = image_size
self.patch_size = patch_size
self.width = width
self.layers = layers
self.heads = heads
self.mlp_ratio = mlp_ratio
self.global_pool = global_pool
self.ignore_head = ignore_head
self.class_token = class_token
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
super().__init__(**kwargs)
class MlpProjectorConfig(PretrainedConfig):
model_type = "mlp_projector"
projector_type: str = "downsample_mlp_gelu"
input_dim: int = 1152
n_embed: int = 2048
depth: int = 2
mlp_ratio: int = 1
downsample_ratio: int = 2
token_pooling: bool = False
def __init__(
self,
projector_type: str = "downsample_mlp_gelu",
input_dim: int = 1152,
n_embed: int = 2048,
depth: int = 2,
mlp_ratio: int = 1,
downsample_ratio: int = 2,
**kwargs,
):
self.projector_type = projector_type
self.input_dim = input_dim
self.n_embed = n_embed
self.depth = depth
self.mlp_ratio = mlp_ratio
self.downsample_ratio = downsample_ratio
super().__init__(**kwargs)
class DeepseekV2Config(PretrainedConfig):
model_type = "deepseek_v2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=None,
n_routed_experts=None,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=None,
topk_group=None,
num_experts_per_tok=None,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
use_mla=True,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = float(rms_norm_eps)
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_mla = use_mla
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class DeepseekVLV2Config(PretrainedConfig):
# model_type = "deepseek_vl_v2"
model_type = "deepseek-ocr"
vision_config: VisionEncoderConfig
projector_config: MlpProjectorConfig
tile_tag: str = "2D"
global_view_pos: str = "head"
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)
def __init__(
self,
tile_tag: str = "tile_tag",
global_view_pos: str = "head",
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),),
**kwargs,
):
super().__init__(**kwargs)
vision_config = kwargs.get("vision_config", {})
self.vision_config = VisionEncoderConfig(**vision_config)
projector_config = kwargs.get("projector_config", {})
self.projector_config = MlpProjectorConfig(**projector_config)
language_config = kwargs.get("language_config", {})
self.text_config = DeepseekV2Config(**language_config)
self.tile_tag = tile_tag
self.global_view_pos = global_view_pos
self.candidate_resolutions = candidate_resolutions
self.vocab_size = self.text_config.vocab_size
self.hidden_size = self.text_config.hidden_size
class DeepseekOCRConfig(DeepseekV2Config):
model_type = "DeepseekOCR"
......@@ -11,6 +11,8 @@ from transformers import (
ProcessorMixin,
)
from sglang.srt.configs.deepseek_ocr import BASE_SIZE, IMAGE_SIZE, MAX_CROPS, MIN_CROPS
def select_best_resolution(image_size, candidate_resolutions):
# used for cropping
......@@ -61,6 +63,7 @@ class DictOutput(object):
class VLChatProcessorOutput(DictOutput):
input_ids: torch.LongTensor
target_ids: torch.LongTensor
images_crop: torch.LongTensor
pixel_values: (
torch.Tensor
) # rename from "images" to "pixel_values" for compatibility
......@@ -104,6 +107,68 @@ class ImageTransform(object):
return x
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(
image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
class DeepseekVLV2Processor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
......@@ -133,7 +198,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = downsample_ratio
self.base_size = BASE_SIZE
self.image_transform = ImageTransform(
mean=image_mean, std=image_std, normalize=normalize
)
......@@ -176,7 +241,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
**kwargs,
)
def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
"""play the role of format_messages_v2 and get_images_info in the last version"""
tokenized_data = []
masked_tokenized_data = [] # labels
......@@ -186,35 +251,34 @@ class DeepseekVLV2Processor(ProcessorMixin):
image_index = 0
image_token_cnt = messages.count(self.image_token)
tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
(
input_ids,
images,
images_crop,
seq_mask,
spatial_crop,
num_image_tokens,
image_shapes,
) = self.tokenize_with_images(
messages,
pil_images[image_index : image_index + image_token_cnt],
bos=True,
eos=True,
cropping=len(pil_images) <= 2,
max_req_input_len=max_req_input_len,
)
image_index = image_token_cnt
tokenized_data += tokenized_str
if self.mask_prompt:
masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
else:
masked_tokenized_data += tokenized_str
images_list += images
images_seq_mask += seq_mask
images_spatial_crop += spatial_crop
assert len(tokenized_data) == len(
images_seq_mask
), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
images_spatial_crop = spatial_crop
return (
tokenized_data,
input_ids,
masked_tokenized_data,
images_list,
images_seq_mask,
images_spatial_crop,
images_crop,
)
@property
......@@ -251,6 +315,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
cropping: bool = True,
**kwargs,
):
"""
......@@ -274,47 +339,22 @@ class DeepseekVLV2Processor(ProcessorMixin):
- num_image_tokens (List[int]): the number of image tokens
"""
assert (
prompt is None or conversations is None
), "prompt and conversations cannot be used at the same time."
prompt = conversations or prompt
(
tokenized_str,
input_ids,
masked_tokenized_str,
images_list,
images_seq_mask,
images_spatial_crop,
) = self.format_messages_v2(conversations, images, max_req_input_len)
images_crop,
) = self.format_messages_v2(prompt, images, max_req_input_len)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
if inference_mode:
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
images = torch.zeros((1, 3, self.image_size, self.image_size))
images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
else:
images = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
images_spatial_crop = torch.stack(
[images_spatial_crop], dim=0
......@@ -323,6 +363,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
prepare = VLChatProcessorOutput(
input_ids=input_ids,
target_ids=target_ids,
images_crop=images_crop,
pixel_values=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
......@@ -340,10 +381,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
text: list[str] = None,
**kwargs,
):
assert text is None or isinstance(text, list)
if text is not None:
text = text[0]
prepare = self.process_one(
prompt=prompt,
prompt=prompt or text,
conversations=conversations,
images=images,
apply_sft_format=apply_sft_format,
......@@ -368,85 +413,83 @@ class DeepseekVLV2Processor(ProcessorMixin):
bos: bool = True,
eos: bool = True,
cropping: bool = True,
max_req_input_len: int = -1,
):
"""Tokenize text with <image> tags."""
images_list, images_seq_mask, images_spatial_crop = [], [], []
conversation = conversation
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
[],
[],
[],
[],
)
image_shapes = []
num_image_tokens = []
tokenized_str = []
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""select best resolution for anyres"""
if cropping:
best_width, best_height = select_best_resolution(
image.size, self.candidate_resolutions
)
image_shapes.append(image.size)
if image.size[0] <= 640 and image.size[1] <= 640:
crop_ratio = [1, 1]
else:
best_width, best_height = self.image_size, self.image_size
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func
if cropping:
images_crop_raw, crop_ratio = dynamic_preprocess(
image, image_size=IMAGE_SIZE
)
else:
crop_ratio = [1, 1]
"""process the global view"""
if self.image_size <= 640 and not cropping:
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(
image,
(self.image_size, self.image_size),
(self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
images_list.append(self.image_transform(global_view))
"""process the local views"""
local_view = ImageOps.pad(
image,
(best_width, best_height),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
for i in range(0, best_height, self.image_size):
for j in range(0, best_width, self.image_size):
images_list.append(
self.image_transform(
local_view.crop(
(j, i, j + self.image_size, i + self.image_size)
)
)
)
"""record height / width crop num"""
num_width_tiles, num_height_tiles = (
best_width // self.image_size,
best_height // self.image_size,
)
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
"""add image tokens"""
h = w = math.ceil(
num_queries = math.ceil(
(self.image_size // self.patch_size) / self.downsample_ratio
)
# global views tokens h * (w + 1), 1 is for line separator
tokenized_image = [self.image_token_id] * h * (w + 1)
# add a separator between global and local views
tokenized_image += [self.image_token_id]
# local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
tokenized_image += (
[self.image_token_id]
* (num_height_tiles * h)
* (num_width_tiles * w + 1)
num_queries_base = math.ceil(
(self.base_size // self.patch_size) / self.downsample_ratio
)
tokenized_image = (
[self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += (
[self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
# print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens
num_image_tokens.append(len(tokenized_image))
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
# deal with video, limit with request len
if max_req_input_len > -1:
if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
tokenized_str = tokenized_str[:rest]
images_seq_mask = images_seq_mask[:rest]
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
......@@ -462,7 +505,64 @@ class DeepseekVLV2Processor(ProcessorMixin):
images_seq_mask
), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
return tokenized_str, images_list, images_seq_mask, images_spatial_crop
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
inference_mode = True
if inference_mode:
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return (
input_ids,
pixel_values,
images_crop,
images_seq_mask,
images_spatial_crop,
num_image_tokens,
image_shapes,
)
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
......@@ -547,7 +647,6 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
class DeepseekV2Config(PretrainedConfig):
model_type = "deepseek_v2"
keys_to_ignore_at_inference = ["past_key_values"]
......
......@@ -921,6 +921,7 @@ multimodal_model_archs = [
"DotsVLMForCausalLM",
"DotsOCRForCausalLM",
"Sarashina2VisionForCausalLM",
"DeepseekOCRForCausalLM",
]
......
......@@ -99,7 +99,6 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
architectures = resolve_transformers_arch(model_config, architectures)
return ModelRegistry.resolve_model_cls(architectures)
......
This diff is collapsed.
......@@ -200,7 +200,6 @@ _is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
logger = logging.getLogger(__name__)
......
......@@ -178,6 +178,7 @@ class BaseMultimodalProcessor(ABC):
"image_attention_mask": Modality.IMAGE,
"image_emb_mask": Modality.IMAGE,
"images_spatial_crop": Modality.IMAGE,
"images_crop": Modality.IMAGE,
"tgt_size": Modality.IMAGE,
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
......
from typing import List, Union
from sglang.srt.models.deepseek_ocr import DeepseekOCRForCausalLM
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
class DeepseekOCRProcessor(BaseMultimodalProcessor):
models = [DeepseekOCRForCausalLM]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
_processor.image_size = 640
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<image>", image_token_id=self._processor.image_token_id
).build(_processor)
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
base_output = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.mm_tokens,
image_data=image_data,
)
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
"im_token_id": self.mm_tokens.image_token_id,
}
......@@ -838,6 +838,19 @@ register_conv_template(
)
)
register_conv_template(
Conversation(
name="deepseek-ocr",
system_message="",
system_template="",
roles=("", ""),
sep="",
sep_style=SeparatorStyle.NO_COLON_SINGLE,
stop_str=["<|end▁of▁sentence|>"],
image_token="<image>",
)
)
register_conv_template(
Conversation(
name="deepseek-vl2",
......@@ -981,6 +994,7 @@ MODEL_TYPE_TO_TEMPLATE = {
"phi4mm": "phi-4-mm",
"minicpmv": "minicpmv",
"minicpmo": "minicpmo",
"deepseek-ocr": "deepseek-ocr",
}
......@@ -1057,3 +1071,11 @@ def match_phi_4_mm(model_path: str):
return "phi-4-mm"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
@register_conv_template_matching_function
def match_deepseek_ocr(model_path: str):
if "deepseek-ocr" in model_path.lower():
return "deepseek-ocr"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
......@@ -19,7 +19,7 @@ import os
import tempfile
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union
import torch
from huggingface_hub import snapshot_download
......@@ -51,26 +51,32 @@ from sglang.srt.configs import (
Qwen3NextConfig,
Step3VLConfig,
)
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig,
DeepseekVL2Config.model_type: DeepseekVL2Config,
MultiModalityConfig.model_type: MultiModalityConfig,
KimiVLConfig.model_type: KimiVLConfig,
InternVLChatConfig.model_type: InternVLChatConfig,
Step3VLConfig.model_type: Step3VLConfig,
LongcatFlashConfig.model_type: LongcatFlashConfig,
Olmo3Config.model_type: Olmo3Config,
Qwen3NextConfig.model_type: Qwen3NextConfig,
FalconH1Config.model_type: FalconH1Config,
DotsVLMConfig.model_type: DotsVLMConfig,
DotsOCRConfig.model_type: DotsOCRConfig,
NemotronHConfig.model_type: NemotronHConfig,
_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
ChatGLMConfig,
DbrxConfig,
ExaoneConfig,
DeepseekVL2Config,
MultiModalityConfig,
KimiVLConfig,
InternVLChatConfig,
Step3VLConfig,
LongcatFlashConfig,
Olmo3Config,
Qwen3NextConfig,
FalconH1Config,
DotsVLMConfig,
DotsOCRConfig,
NemotronHConfig,
DeepseekVLV2Config,
]
_CONFIG_REGISTRY = {
config_cls.model_type: config_cls for config_cls in _CONFIG_REGISTRY
}
for name, cls in _CONFIG_REGISTRY.items():
......@@ -191,6 +197,11 @@ def get_config(
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if "deepseek-ai/DeepSeek-OCR" in model:
config.model_type = "deepseek-ocr"
# Due to an unknown reason, Hugging Face’s AutoConfig mistakenly recognizes the configuration of deepseek-ocr as deepseekvl2.
# This is a temporary workaround and will require further optimization.
except ValueError as e:
if not "deepseek_v32" in str(e):
raise e
......@@ -213,7 +224,8 @@ def get_config(
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 26, # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
"num_hidden_layers": 26,
# Model is originally 27-layer, we only need the first 26 layers for feature extraction.
"patch_size": 14,
}
config.vision_config = SiglipVisionConfig(**vision_config)
......
......@@ -619,7 +619,6 @@ def popen_launch_server(
start_time = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start_time < timeout:
return_code = process.poll()
if return_code is not None:
# Server failed to start (non-zero exit code) or crashed
......
......@@ -150,6 +150,62 @@ class TestQwen2AudioServer(AudioOpenAITestMixin):
model = "Qwen/Qwen2-Audio-7B-Instruct"
class TestDeepseekOCRServer(TestOpenAIMLLMServerBase):
model = "deepseek-ai/DeepSeek-OCR"
trust_remote_code = False
def verify_single_image_response_for_ocr(self, response):
"""Verify DeepSeek-OCR grounding output with coordinates"""
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
# DeepSeek-OCR uses grounding format, outputs coordinates
assert "text" in text.lower(), f"OCR text: {text}, should contain 'text'"
# Verify coordinate format [[x1, y1, x2, y2]]
import re
coord_pattern = r"\[\[[\d\s,]+\]\]"
assert re.search(
coord_pattern, text
), f"OCR text: {text}, should contain coordinate format [[x1, y1, x2, y2]]"
# Verify basic response fields
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
image_url = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/ocr-text.png"
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": image_url},
},
{
"type": "text",
"text": "<|grounding|>Convert the document to markdown.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
self.verify_single_image_response_for_ocr(response)
if __name__ == "__main__":
del (
TestOpenAIMLLMServerBase,
......
......@@ -32,6 +32,7 @@ class TestOpenAIMLLMServerBase(CustomTestCase):
model: str
extra_args: list = []
fixed_args: list = ["--trust-remote-code", "--enable-multimodal"]
trust_remote_code: bool = True
@classmethod
def setUpClass(cls):
......@@ -42,7 +43,11 @@ class TestOpenAIMLLMServerBase(CustomTestCase):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=cls.extra_args + cls.fixed_args,
other_args=(
cls.extra_args + cls.fixed_args + ["--trust-remote-code"]
if cls.trust_remote_code
else []
),
)
cls.base_url += "/v1"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment