Unverified Commit 770529a7 authored by Mick, committed by GitHub

model: support deepseek-ocr (#11891)


Co-authored-by: yhyang201 <47235274+yhyang201@users.noreply.github.com>
Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: Shi Shuai <126407087+shuaills@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
parent 39c237f0
from typing import Tuple
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig
BASE_SIZE = 1024
IMAGE_SIZE = 640
CROP_MODE = True
MIN_CROPS = 2
MAX_CROPS = 6 # max:9; If your GPU memory is small, it is recommended to set it to 6.
MAX_CONCURRENCY = 100 # If you have limited GPU memory, lower the concurrency count.
NUM_WORKERS = 64 # image pre-process (resize/padding) workers
PRINT_NUM_VIS_TOKENS = False
SKIP_REPEAT = True
MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
class ImageTransform:
def __init__(
self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
):
self.mean = mean
self.std = std
self.normalize = normalize
transform_pipelines = [T.ToTensor()]
if normalize:
transform_pipelines.append(T.Normalize(mean, std))
self.transform = T.Compose(transform_pipelines)
def __call__(self, pil_img: Image.Image):
x = self.transform(pil_img)
return x
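A minimal usage sketch of the transform above (illustrative only, not part of the shipped file): with the default mean/std of 0.5, pixel values land roughly in [-1, 1].

transform = ImageTransform()
x = transform(Image.new("RGB", (640, 640), (128, 128, 128)))
print(x.shape)  # torch.Size([3, 640, 640]); values are ~0.004 for this mid-gray input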
class VisionEncoderConfig(PretrainedConfig):
model_type: str = "vision"
model_name: str = "vit_so400m_patch14_siglip_384.webli"
image_size: int = 384
patch_size: int = 16
width: int = 1024
layers: int = 24
heads: int = 16
mlp_ratio: int = 4
global_pool: str = "map"
ignore_head: bool = True
class_token: bool = False
num_classes: int = 0
use_checkpoint: bool = False
weight_init: str = "skip"
deterministic: bool = False
num_recomputing_layers: int = 0
def __init__(
self,
model_name: str = "vit_so400m_patch14_siglip_384.webli",
image_size: int = 384,
patch_size: int = 16,
width: int = 1024,
layers: int = 24,
heads: int = 16,
mlp_ratio: int = 4,
global_pool: str = "map",
ignore_head: bool = True,
class_token: bool = False,
num_classes: int = 0,
use_checkpoint: bool = False,
**kwargs,
):
self.model_name = model_name
self.image_size = image_size
self.patch_size = patch_size
self.width = width
self.layers = layers
self.heads = heads
self.mlp_ratio = mlp_ratio
self.global_pool = global_pool
self.ignore_head = ignore_head
self.class_token = class_token
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
super().__init__(**kwargs)
class MlpProjectorConfig(PretrainedConfig):
model_type = "mlp_projector"
projector_type: str = "downsample_mlp_gelu"
input_dim: int = 1152
n_embed: int = 2048
depth: int = 2
mlp_ratio: int = 1
downsample_ratio: int = 2
token_pooling: bool = False
def __init__(
self,
projector_type: str = "downsample_mlp_gelu",
input_dim: int = 1152,
n_embed: int = 2048,
depth: int = 2,
mlp_ratio: int = 1,
downsample_ratio: int = 2,
**kwargs,
):
self.projector_type = projector_type
self.input_dim = input_dim
self.n_embed = n_embed
self.depth = depth
self.mlp_ratio = mlp_ratio
self.downsample_ratio = downsample_ratio
super().__init__(**kwargs)
class DeepseekV2Config(PretrainedConfig):
model_type = "deepseek_v2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=None,
n_routed_experts=None,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=None,
topk_group=None,
num_experts_per_tok=None,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
use_mla=True,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = float(rms_norm_eps)
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.use_mla = use_mla
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class DeepseekVLV2Config(PretrainedConfig):
# model_type = "deepseek_vl_v2"
model_type = "deepseek-ocr"
vision_config: VisionEncoderConfig
projector_config: MlpProjectorConfig
tile_tag: str = "2D"
global_view_pos: str = "head"
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)
def __init__(
self,
tile_tag: str = "tile_tag",
global_view_pos: str = "head",
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),),
**kwargs,
):
super().__init__(**kwargs)
vision_config = kwargs.get("vision_config", {})
self.vision_config = VisionEncoderConfig(**vision_config)
projector_config = kwargs.get("projector_config", {})
self.projector_config = MlpProjectorConfig(**projector_config)
language_config = kwargs.get("language_config", {})
self.text_config = DeepseekV2Config(**language_config)
self.tile_tag = tile_tag
self.global_view_pos = global_view_pos
self.candidate_resolutions = candidate_resolutions
self.vocab_size = self.text_config.vocab_size
self.hidden_size = self.text_config.hidden_size
class DeepseekOCRConfig(DeepseekV2Config):
model_type = "DeepseekOCR"
@@ -11,6 +11,8 @@ from transformers import (
ProcessorMixin,
)
+ from sglang.srt.configs.deepseek_ocr import BASE_SIZE, IMAGE_SIZE, MAX_CROPS, MIN_CROPS
def select_best_resolution(image_size, candidate_resolutions):
# used for cropping
@@ -61,6 +63,7 @@ class DictOutput(object):
class VLChatProcessorOutput(DictOutput):
input_ids: torch.LongTensor
target_ids: torch.LongTensor
+ images_crop: torch.LongTensor
pixel_values: (
torch.Tensor
) # rename from "images" to "pixel_values" for compatibility
@@ -104,6 +107,68 @@ class ImageTransform(object):
return x
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float("inf")
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+ def dynamic_preprocess(
+ image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
+ ):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ )
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
+ )
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images, target_aspect_ratio
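Illustrative sketch of what the added tiling helper does (assuming PIL.Image and the constants imported above are in scope): a 4:3 page snaps to the closest allowed tile grid under MIN_CROPS=2 / MAX_CROPS=6, here 3x2, so it is resized to 1920x1280 and cut into six 640x640 crops.

page = Image.new("RGB", (1280, 960))
tiles, ratio = dynamic_preprocess(page, image_size=IMAGE_SIZE)
print(ratio, len(tiles), tiles[0].size)  # (3, 2) 6 (640, 640)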
class DeepseekVLV2Processor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
@@ -133,7 +198,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = downsample_ratio
+ self.base_size = BASE_SIZE
self.image_transform = ImageTransform(
mean=image_mean, std=image_std, normalize=normalize
)
@@ -176,7 +241,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
**kwargs,
)
- def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
+ def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
"""play the role of format_messages_v2 and get_images_info in the last version"""
tokenized_data = []
masked_tokenized_data = [] # labels
@@ -186,35 +251,34 @@ class DeepseekVLV2Processor(ProcessorMixin):
image_index = 0
image_token_cnt = messages.count(self.image_token)
- tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
+ (
+ input_ids,
+ images,
+ images_crop,
+ seq_mask,
+ spatial_crop,
+ num_image_tokens,
+ image_shapes,
+ ) = self.tokenize_with_images(
messages,
pil_images[image_index : image_index + image_token_cnt],
bos=True,
eos=True,
cropping=len(pil_images) <= 2,
- max_req_input_len=max_req_input_len,
)
image_index = image_token_cnt
- tokenized_data += tokenized_str
- if self.mask_prompt:
- masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
- else:
- masked_tokenized_data += tokenized_str
images_list += images
images_seq_mask += seq_mask
- images_spatial_crop += spatial_crop
+ images_spatial_crop = spatial_crop
- assert len(tokenized_data) == len(
- images_seq_mask
- ), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
return (
- tokenized_data,
+ input_ids,
masked_tokenized_data,
images_list,
images_seq_mask,
images_spatial_crop,
+ images_crop,
)
@property
@@ -251,6 +315,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
+ cropping: bool = True,
**kwargs,
):
"""
@@ -274,47 +339,22 @@ class DeepseekVLV2Processor(ProcessorMixin):
- num_image_tokens (List[int]): the number of image tokens
"""
- assert (
- prompt is None or conversations is None
- ), "prompt and conversations cannot be used at the same time."
+ prompt = conversations or prompt
(
- tokenized_str,
+ input_ids,
masked_tokenized_str,
images_list,
images_seq_mask,
images_spatial_crop,
- ) = self.format_messages_v2(conversations, images, max_req_input_len)
+ images_crop,
+ ) = self.format_messages_v2(prompt, images, max_req_input_len)
- assert (
- len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
- ), (
- f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
- f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
- )
- input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
- images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
- # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
- target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
- self.ignore_id
- )
- input_ids[input_ids < 0] = self.pad_id
- if inference_mode:
- assert input_ids[-1] == self.eos_id
- input_ids = input_ids[:-1]
- target_ids = target_ids[:-1]
- images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
images = torch.zeros((1, 3, self.image_size, self.image_size))
- images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
else:
images = torch.stack(images_list, dim=0)
- images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
images_spatial_crop = torch.stack(
[images_spatial_crop], dim=0
@@ -323,6 +363,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
prepare = VLChatProcessorOutput(
input_ids=input_ids,
target_ids=target_ids,
+ images_crop=images_crop,
pixel_values=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
@@ -340,10 +381,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
+ text: list[str] = None,
**kwargs,
):
+ assert text is None or isinstance(text, list)
+ if text is not None:
+ text = text[0]
prepare = self.process_one(
- prompt=prompt,
+ prompt=prompt or text,
conversations=conversations,
images=images,
apply_sft_format=apply_sft_format,
...@@ -368,85 +413,83 @@ class DeepseekVLV2Processor(ProcessorMixin): ...@@ -368,85 +413,83 @@ class DeepseekVLV2Processor(ProcessorMixin):
bos: bool = True, bos: bool = True,
eos: bool = True, eos: bool = True,
cropping: bool = True, cropping: bool = True,
max_req_input_len: int = -1,
): ):
"""Tokenize text with <image> tags.""" """Tokenize text with <image> tags."""
images_list, images_seq_mask, images_spatial_crop = [], [], []
conversation = conversation
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token) text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
[],
[],
[],
[],
)
image_shapes = []
num_image_tokens = []
tokenized_str = [] tokenized_str = []
for text_sep, image in zip(text_splits, images): for text_sep, image in zip(text_splits, images):
"""encode text_sep""" """encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False) tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep) images_seq_mask += [False] * len(tokenized_sep)
"""select best resolution for anyres""" image_shapes.append(image.size)
if cropping:
best_width, best_height = select_best_resolution( if image.size[0] <= 640 and image.size[1] <= 640:
image.size, self.candidate_resolutions crop_ratio = [1, 1]
)
else: else:
best_width, best_height = self.image_size, self.image_size if cropping:
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func images_crop_raw, crop_ratio = dynamic_preprocess(
image, image_size=IMAGE_SIZE
)
else:
crop_ratio = [1, 1]
"""process the global view""" """process the global view"""
if self.image_size <= 640 and not cropping:
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad( global_view = ImageOps.pad(
image, image,
(self.image_size, self.image_size), (self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean), color=tuple(int(x * 255) for x in self.image_transform.mean),
) )
images_list.append(self.image_transform(global_view)) images_list.append(self.image_transform(global_view))
"""process the local views""" num_width_tiles, num_height_tiles = crop_ratio
local_view = ImageOps.pad(
image,
(best_width, best_height),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
for i in range(0, best_height, self.image_size):
for j in range(0, best_width, self.image_size):
images_list.append(
self.image_transform(
local_view.crop(
(j, i, j + self.image_size, i + self.image_size)
)
)
)
"""record height / width crop num"""
num_width_tiles, num_height_tiles = (
best_width // self.image_size,
best_height // self.image_size,
)
images_spatial_crop.append([num_width_tiles, num_height_tiles]) images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
"""add image tokens""" """add image tokens"""
h = w = math.ceil( num_queries = math.ceil(
(self.image_size // self.patch_size) / self.downsample_ratio (self.image_size // self.patch_size) / self.downsample_ratio
) )
# global views tokens h * (w + 1), 1 is for line separator num_queries_base = math.ceil(
tokenized_image = [self.image_token_id] * h * (w + 1) (self.base_size // self.patch_size) / self.downsample_ratio
# add a separator between global and local views
tokenized_image += [self.image_token_id]
# local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
tokenized_image += (
[self.image_token_id]
* (num_height_tiles * h)
* (num_width_tiles * w + 1)
) )
tokenized_image = (
[self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += (
[self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image) images_seq_mask += [True] * len(tokenized_image)
# print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens num_image_tokens.append(len(tokenized_image))
"""process the last text split""" """process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
# deal with video, limit with request len
if max_req_input_len > -1:
if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
tokenized_str = tokenized_str[:rest]
images_seq_mask = images_seq_mask[:rest]
tokenized_str += tokenized_sep tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep) images_seq_mask += [False] * len(tokenized_sep)
@@ -462,7 +505,64 @@ class DeepseekVLV2Processor(ProcessorMixin):
images_seq_mask
), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
- return tokenized_str, images_list, images_seq_mask, images_spatial_crop
+ masked_tokenized_str = []
+ for token_index in tokenized_str:
+ if token_index != self.image_token_id:
+ masked_tokenized_str.append(token_index)
+ else:
+ masked_tokenized_str.append(self.ignore_id)
+ assert (
+ len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+ ), (
+ f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+ f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+ )
+ input_ids = torch.LongTensor(tokenized_str)
+ target_ids = torch.LongTensor(masked_tokenized_str)
+ images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+ # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
+ target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+ self.ignore_id
+ )
+ input_ids[input_ids < 0] = self.pad_id
+ inference_mode = True
+ if inference_mode:
+ # Remove the ending eos token
+ assert input_ids[-1] == self.eos_id
+ input_ids = input_ids[:-1]
+ target_ids = target_ids[:-1]
+ images_seq_mask = images_seq_mask[:-1]
+ if len(images_list) == 0:
+ pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
+ images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
+ images_crop = torch.zeros(
+ (1, 3, self.image_size, self.image_size)
+ ).unsqueeze(0)
+ else:
+ pixel_values = torch.stack(images_list, dim=0)
+ images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+ if images_crop_list:
+ images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
+ else:
+ images_crop = torch.zeros(
+ (1, 3, self.image_size, self.image_size)
+ ).unsqueeze(0)
+ input_ids = input_ids.unsqueeze(0)
+ return (
+ input_ids,
+ pixel_values,
+ images_crop,
+ images_seq_mask,
+ images_spatial_crop,
+ num_image_tokens,
+ image_shapes,
+ )
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
@@ -547,7 +647,6 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
class DeepseekV2Config(PretrainedConfig):
model_type = "deepseek_v2"
keys_to_ignore_at_inference = ["past_key_values"]
@@ -921,6 +921,7 @@ multimodal_model_archs = [
"DotsVLMForCausalLM",
"DotsOCRForCausalLM",
"Sarashina2VisionForCausalLM",
+ "DeepseekOCRForCausalLM",
]
@@ -99,7 +99,6 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
architectures = resolve_transformers_arch(model_config, architectures)
return ModelRegistry.resolve_model_cls(architectures)
# Copyright 2025 The SwissAI Initiative
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only Apertus model compatible with HuggingFace weights."""
import copy
import logging
import math
from functools import partial
from typing import Iterable, List, Optional, Set, Tuple, Type, TypeAlias, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers.models.vitdet.modeling_vitdet import get_rel_pos
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek import DeepseekForCausalLM
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM
from sglang.srt.models.transformers import maybe_prefix
NestedTensors: TypeAlias = Union[
list["NestedTensors"],
list["torch.Tensor"],
"torch.Tensor",
tuple["torch.Tensor", ...],
]
MultiModalEmbeddings: TypeAlias = list[Tensor] | Tensor | tuple[Tensor, ...]
logger = logging.getLogger(__name__)
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
"""
Recursively flattens and concatenates NestedTensors on all but the last
dimension.
"""
if isinstance(embeddings, torch.Tensor):
# Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
def _embedding_count_expression(embeddings: NestedTensors) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
def _merge_multimodal_embeddings(
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
is_multimodal: torch.Tensor,
) -> torch.Tensor:
"""
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
Note:
This updates `inputs_embeds` in place.
"""
if len(multimodal_embeddings) == 0:
return inputs_embeds
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype
try:
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds.masked_scatter_(
is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
)
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
if num_actual_tokens != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {num_actual_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders"
) from e
raise ValueError("Error during masked scatter operation") from e
return inputs_embeds
def isin_list(
elements: torch.Tensor,
test_elements_list: list[int],
) -> torch.Tensor:
test_elements = torch.tensor(test_elements_list, pin_memory=True).to(
device=elements.device, non_blocking=True
)
return torch.isin(elements, test_elements)
def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int | list[int],
) -> torch.Tensor:
"""
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
`placeholder_token_id` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the `input_ids` MUST MATCH the order of
their embeddings in `multimodal_embeddings` since we need to
slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
- T is text token
- S is image start token
- I is image embedding token
- B is image break token
- E is image end token.
Then the image embeddings (that correspond to I's) from vision encoder
must be padded with embeddings of S, B, and E in the same order of
input_ids for a correct embedding merge.
Note:
This updates `inputs_embeds` in place.
"""
if isinstance(placeholder_token_id, list):
is_multimodal = isin_list(input_ids, placeholder_token_id)
else:
is_multimodal = input_ids == placeholder_token_id
return _merge_multimodal_embeddings(
inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
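A toy, hedged example of the scatter-merge above (the token id 100 and the tiny shapes are made up for illustration): the rows of inputs_embeds at placeholder positions are overwritten in place by the flattened image embeddings.

input_ids = torch.tensor([5, 7, 100, 100, 100, 9])  # 100 plays the image-placeholder id here
inputs_embeds = torch.zeros(6, 4)
image_embeds = torch.ones(3, 4)                     # one embedding per placeholder token
out = merge_multimodal_embeddings(input_ids, inputs_embeds, image_embeds, placeholder_token_id=100)
print(out[2:5].sum().item(), out[0].sum().item())   # 12.0 0.0 -> only the placeholder rows changed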
class MlpProjector(nn.Module):
def __init__(
self,
projector_type,
input_dim,
n_embed,
depth=1,
mlp_ratio=1,
downsample_ratio=4,
):
self.projector_type = projector_type
self.input_dim = input_dim
self.n_embed = n_embed
self.depth = depth
self.downsample_ratio = downsample_ratio  # forward() reads this attribute for the downsample_* projector types
self.token_pooling = False
self.conv_fusion_high_low_features = False
super().__init__()
if projector_type == "identity":
modules = nn.Identity()
elif projector_type == "linear":
modules = nn.Linear(input_dim, n_embed)
elif projector_type == "mlp_gelu":
mlp_depth = depth
modules = [nn.Linear(input_dim, n_embed)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed, n_embed))
modules = nn.Sequential(*modules)
elif projector_type == "normlayer_downsample_mlp_gelu":
mlp_depth = depth
mlp_ratio = mlp_ratio
modules = [
nn.LayerNorm(input_dim * downsample_ratio * downsample_ratio),
nn.Linear(
input_dim * downsample_ratio * downsample_ratio,
n_embed * mlp_ratio,
),
]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed * mlp_ratio, n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed * mlp_ratio, n_embed))
modules = nn.Sequential(*modules)
elif projector_type == "downsample_mlp_gelu":
mlp_depth = depth
mlp_ratio = mlp_ratio
modules = [
nn.Linear(
input_dim * downsample_ratio * downsample_ratio,
n_embed * mlp_ratio,
)
]
for _ in range(1, mlp_depth - 1):
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed * mlp_ratio, n_embed * mlp_ratio))
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed * mlp_ratio, n_embed))
modules = nn.Sequential(*modules)
elif projector_type == "low_high_hybrid_split_mlp_gelu":
mlp_depth = depth
self.high_up_proj = nn.Linear(input_dim, n_embed // 2)
self.low_up_proj = nn.Linear(input_dim, n_embed // 2)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed, n_embed))
modules = nn.Sequential(*modules)
elif projector_type == "hybrid_split_feature_mlp_gelu":
mlp_depth = depth
channel_div = 0.5
self.high_up_proj = nn.Linear(input_dim[0], int(n_embed * channel_div))
self.low_up_proj = nn.Linear(
input_dim[1], n_embed - int(n_embed * channel_div)
)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed, n_embed))
modules = nn.Sequential(*modules)
elif projector_type == "low_high_split_mlp_gelu":
mlp_depth = depth
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(n_embed // 2, n_embed // 2))
modules = nn.Sequential(*modules)
self.high_layers = nn.Sequential(*modules)
self.low_layers = copy.deepcopy(modules)
else:
raise ValueError(f"Unknown projector type: {projector_type}")
self.layers = modules
def forward(self, x):
if self.token_pooling:
batch_size, wxh, channels = x.shape
w = h = int(wxh**0.5)
x = x.view(batch_size, w, h, channels)
x = x.permute(0, 3, 1, 2)
patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
batch_size, channels, h_patches, w_patches, _, _ = patches.size()
# Concatenate on channel dimension
patches = patches.contiguous().view(
batch_size, channels, h_patches * w_patches, -1
)
# Pass through linear layer
patches = patches.permute(0, 2, 1, 3).contiguous()
patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
x = self.token_pooling_layer(patches)
if self.conv_fusion_high_low_features:
x = self.fusion_layer(x[:, 0]) + x[:, 1]
if self.projector_type == "low_high_hybrid_split_mlp_gelu":
high_x, low_x = x[0], x[1]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.projector_type == "hybrid_split_feature_mlp_gelu":
high_x = x[..., : self.input_dim[0]]
low_x = x[..., self.input_dim[0] :]
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
if self.projector_type == "low_high_split_mlp_gelu":
high_x, low_x = x[0], x[1]
high_x = self.high_layers(high_x)
low_x = self.low_layers(low_x)
x = torch.concat([high_x, low_x], dim=-1)
return x
if (
self.projector_type == "downsample_mlp_gelu"
or self.projector_type == "normlayer_downsample_mlp_gelu"
):
bs, hw, input_dim = x.shape
h = w = int((hw) ** 0.5)
"""compute padding"""
if h % self.downsample_ratio:
pad = self.downsample_ratio - h % self.downsample_ratio
else:
pad = 0
x = x.reshape(bs, h, w, input_dim)
if pad > 0:
x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
"""4 to 1 concat"""
x = x.permute(0, 3, 1, 2) # B, C, H, W
x = F.unfold(
x,
kernel_size=self.downsample_ratio,
stride=self.downsample_ratio,
padding=0,
) # B, C*4, HW // 4
x = x.permute(0, 2, 1)
return self.layers(x)
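Hedged shape sketch for the projector as it is instantiated further down in DeepseekOCRForCausalLM (projector_type "linear", 2048 -> 1280); the batch and token counts below are arbitrary.

proj = MlpProjector(projector_type="linear", input_dim=2048, n_embed=1280)
feats = torch.randn(2, 256, 2048)   # e.g. concatenated SAM + CLIP features per view
print(proj(feats).shape)            # torch.Size([2, 256, 1280])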
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
def add_decomposed_rel_pos(
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
rel_h = rel_h.unsqueeze(-1)
rel_w = rel_w.unsqueeze(-2)
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
return rel_h, rel_w
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = (
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
rel_h, rel_w = None, None
if self.use_rel_pos:
rel_h, rel_w = add_decomposed_rel_pos(
q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
)
q = q.view(B, self.num_heads, H * W, -1)
k = k.view(B, self.num_heads, H * W, -1)
v = v.view(B, self.num_heads, H * W, -1)
if self.use_rel_pos:
rel_h = rel_h.view(
B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)
)
rel_w = rel_w.view(
B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)
)
attn_bias = (rel_h + rel_w).view(
B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
)
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_bias
)
# x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = (
x.view(B, self.num_heads, H, W, -1)
.permute(0, 2, 3, 1, 4)
.reshape(B, H, W, -1)
)
x = self.proj(x)
return x
def window_partition(
x: torch.Tensor, window_size: int
) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = (
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor,
window_size: int,
pad_hw: Tuple[int, int],
hw: Tuple[int, int],
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
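Quick self-contained check of the two helpers above (sizes chosen arbitrarily): a 20x20 map is padded to 21x21, split into nine 7x7 windows, and reassembled losslessly.

x = torch.randn(1, 20, 20, 8)
windows, pad_hw = window_partition(x, window_size=7)
print(windows.shape, pad_hw)  # torch.Size([9, 7, 7, 8]) (21, 21)
print(torch.equal(window_unpartition(windows, 7, pad_hw, (20, 20)), x))  # True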
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(
embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
def get_abs_pos_sam(abs_pos, tgt_size):
dtype = abs_pos.dtype
src_size = abs_pos.size(1)
if src_size != tgt_size:
old_pos_embed = abs_pos.permute(0, 3, 1, 2)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode="bicubic",
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
return new_pos_embed
else:
return abs_pos
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(
1, img_size // patch_size, img_size // patch_size, embed_dim
)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + get_abs_pos_sam(self.pos_embed, x.size(1))
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
x2 = self.net_2(x)
x3 = self.net_3(x2.clone())
return x3
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_encoder = ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
image_encoder.eval()
if checkpoint is not None:
state_dict = torch.load(checkpoint)
image_encoder.load_state_dict(
{k[30:]: v for k, v in state_dict.items() if "vision_tower_high" in k},
strict=True,
)
return image_encoder
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
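Hedged shape check for the SAM-style encoder built above (random init here; the real weights are loaded from the checkpoint elsewhere): a 1024px global view ends up as a 16x16 grid of 1024-d features after the neck and the two stride-2 convs.

sam = build_sam_vit_b()
with torch.no_grad():
    feats = sam(torch.randn(1, 3, 1024, 1024))  # global attention over 64x64 tokens; needs a few GB of RAM on CPU
print(feats.shape)                              # torch.Size([1, 1024, 16, 16])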
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
dim = abs_pos.size(-1)
abs_pos_new = abs_pos.squeeze(0)
cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
old_pos_embed = (
old_pos_embed.view(1, src_size, src_size, dim)
.permute(0, 3, 1, 2)
.contiguous()
)
old_pos_embed = old_pos_embed.to(torch.float32)
new_pos_embed = F.interpolate(
old_pos_embed,
size=(tgt_size, tgt_size),
mode="bicubic",
antialias=True,
align_corners=False,
).to(dtype)
new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
return vision_pos_embed
else:
return abs_pos
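Hedged example of the interpolation above, using the CLIP-L/14 layout assumed later in this file (a 16x16 patch grid plus one [CLS] slot): asking for 101 positions yields a 10x10 grid plus [CLS], while 257 positions are returned unchanged.

pos = torch.randn(1, 257, 1024)     # 16*16 + 1 learned positions
print(get_abs_pos(pos, 101).shape)  # torch.Size([1, 101, 1024]) -> 10*10 + 1
print(get_abs_pos(pos, 257).shape)  # torch.Size([1, 257, 1024]) -> unchanged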
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
super().__init__()
self.embed_dim = hidden_size
self.image_size = image_size
self.patch_size = patch_size
self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = torch.nn.Conv2d(
in_channels=num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids", torch.arange(self.num_positions).expand((1, -1))
)
def forward(self, pixel_values, patch_embeds):
batch_size = pixel_values.shape[0]
if patch_embeds is not None:
patch_embeds = patch_embeds
else:
patch_embeds = self.patch_embedding(pixel_values)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + get_abs_pos(
self.position_embedding(self.position_ids), embeddings.size(1)
)
return embeddings
class NoTPAttention(torch.nn.Module):
def __init__(self, cfg):
super().__init__()
self.num_heads = cfg["num_attention_heads"]
self.n_local_heads = cfg["num_attention_heads"]
self.head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
self.max_seq_len = cfg["seq_length"]
self.use_flash_attention = cfg["use_flash_attn"]
self.qkv_proj = torch.nn.Linear(
cfg["hidden_size"], cfg["hidden_size"] * 3, bias=True
)
self.out_proj = torch.nn.Linear(
cfg["hidden_size"], cfg["hidden_size"], bias=True
)
# self.core_attention = CoreAttention(cfg, AttnType.self_attn)
self.attn_drop = cfg["attention_dropout"]
def forward(
self,
x: torch.Tensor,
):
bsz, seqlen, _ = x.shape
xqkv = self.qkv_proj(x)
xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
if self.use_flash_attention:
xq, xk, xv = torch.split(xqkv, 1, dim=2)
xq = xq.squeeze(2)
xk = xk.squeeze(2)
xv = xv.squeeze(2)
# xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
# (B, num_head, S, head_size)
xq = xq.permute(0, 2, 1, 3)
xk = xk.permute(0, 2, 1, 3)
xv = xv.permute(0, 2, 1, 3)
output = torch.nn.functional.scaled_dot_product_attention(
xq, xk, xv, attn_mask=None
)
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
else:
xq, xk, xv = torch.split(xqkv, 1, dim=2)
xq = xq.squeeze(2)
xk = xk.squeeze(2)
xv = xv.squeeze(2)
xq = xq.permute(0, 2, 1, 3)
xk = xk.permute(0, 2, 1, 3)
xv = xv.permute(0, 2, 1, 3)
output = torch.nn.functional.scaled_dot_product_attention(
xq, xk, xv, attn_mask=None
)
output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
output = self.out_proj(output)
return output
@torch.jit.script
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
class NoTPFeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
):
super().__init__()
self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
def forward(self, x):
output = self.fc2(quick_gelu(self.fc1(x)))
return output
class LayerNormfp32(torch.nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class NoTPTransformerBlock(nn.Module):
def __init__(self, cfg, layer_id: int, multiple_of=256):
super().__init__()
self.n_heads = cfg["num_attention_heads"]
self.dim = cfg["hidden_size"]
self.head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
self.self_attn = NoTPAttention(cfg)
self.mlp = NoTPFeedForward(
cfg, dim=cfg["hidden_size"], hidden_dim=cfg["ffn_hidden_size"]
)
self.layer_id = layer_id
self.layer_norm1 = torch.nn.LayerNorm(
cfg["hidden_size"], eps=cfg["layernorm_epsilon"]
)
self.layer_norm2 = torch.nn.LayerNorm(
cfg["hidden_size"], eps=cfg["layernorm_epsilon"]
)
def forward(self, x: torch.Tensor):
residual = self.self_attn.forward(self.layer_norm1(x))
h = x + residual
out = h + self.mlp.forward(self.layer_norm2(h))
return out
class NoTPTransformer(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.num_layers = cfg["num_layers"]
self.layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
self.layers.append(
NoTPTransformerBlock(
cfg,
layer_id + 1,
)
)
def forward(
self,
hidden_states,
):
for layer in self.layers:
hidden_states = layer(hidden_states)
return hidden_states
class VitModel(nn.Module):
def __init__(self, cfg, freeze_embed=False, freeze_pre_norm=False) -> None:
super().__init__()
self.embeddings = CLIPVisionEmbeddings(
hidden_size=cfg["hidden_size"],
image_size=cfg["image_size"],
patch_size=cfg["patch_size"],
)
if freeze_embed:
for _, param in self.embeddings.named_parameters():
param.requires_grad = False
self.transformer = NoTPTransformer(cfg=cfg)
if cfg.get("fp32norm", False):
logger.info("Load fp32 layernorm for ViT.")
self.pre_layrnorm = LayerNormfp32(
cfg["hidden_size"],
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
else:
self.pre_layrnorm = torch.nn.LayerNorm(
cfg["hidden_size"],
eps=cfg.get("pre_layernorm_epsilon", 1e-5),
)
if freeze_pre_norm:
for _, param in self.pre_layrnorm.named_parameters():
param.requires_grad = False
for p in self.parameters():
p.micro_dp = True
@property
def dtype(self):
return next(self.parameters()).dtype
def set_input_tensor(self, input_tensor):
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
self.transformer.set_input_tensor(input_tensor[0])
def __str__(self) -> str:
return "open_clip"
def forward(self, x, patch_embeds):
x = self.embeddings(x, patch_embeds)
hidden_states = self.pre_layrnorm(x)
output = self.transformer(hidden_states)
return output
vit_model_cfg = dict(
num_layers=24,
hidden_size=1024,
num_heads=16,
num_attention_heads=16,
ffn_hidden_size=4096,
seq_length=256,
max_position_embeddings=256,
use_flash_attn=False,
understand_projector_stride=2,
hidden_dropout=0.0,
attention_dropout=0.0,
no_persist_layer_norm=False,
layernorm_epsilon=1e-5,
pre_layernorm_epsilon=1e-5,
image_size=224,
patch_size=14,
recompute_list=[],
)
def build_clip_l():
return VitModel(
cfg=vit_model_cfg,
freeze_embed=False,
freeze_pre_norm=False,
)
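Hedged end-to-end sketch of the vision path that _pixel_values_to_embedding below wires together: SAM features feed the CLIP tower as its patch embeddings, the two feature maps are concatenated to 2048 channels, and the linear projector maps them into the 1280-d language-model space (random weights, arbitrary input; a real run loads the checkpoint and uses bfloat16).

sam, clip = build_sam_vit_b(), build_clip_l()
proj = MlpProjector(projector_type="linear", input_dim=2048, n_embed=1280)
with torch.no_grad():
    view = torch.randn(1, 3, 1024, 1024)               # one padded global view; needs a few GB of RAM on CPU
    sam_feats = sam(view)                               # (1, 1024, 16, 16)
    clip_feats = clip(view, sam_feats)                  # (1, 257, 1024); index 0 is the class token
    fused = torch.cat((clip_feats[:, 1:], sam_feats.flatten(2).permute(0, 2, 1)), dim=-1)  # (1, 256, 2048)
    print(proj(fused).shape)                            # torch.Size([1, 256, 1280]) -> fed to the language model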
class DeepseekOCRForCausalLM(nn.Module):
def __init__(
self,
*,
config: DeepseekVLV2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.vision_config = config.vision_config
self.projector_config = config.projector_config
self.text_config = config.text_config
n_embed = 1280
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
# special token for image token sequence format
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
if self.tile_tag == "2D":
# <|view_separator|>, <|\n|>
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
)
if self.text_config.topk_method == "noaux_tc":
self.model = DeepseekV3ForCausalLM(
config=config.text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "language"),
)
elif not self.text_config.use_mla:
self.model = DeepseekForCausalLM(
config=config.text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "language"),
)
else:
self.model = DeepseekV2ForCausalLM(
config=config.text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "language"),
)
self.sam_model = build_sam_vit_b()
self.vision_model = build_clip_l()
n_embed = 1280
self.projector = MlpProjector(
projector_type="linear",
input_dim=2048,
n_embed=n_embed,
)
def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
if pixel_values is None or torch.sum(pixel_values).item() == 0:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(images_spatial_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image sizes. "
f"Got type: {type(images_spatial_crop)}"
)
if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image crop. " f"Got type: {type(images_crop)}"
)
return [pixel_values, images_crop, images_spatial_crop]
raise AssertionError("This line should be unreachable.")
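    # Encode the global view (and local crop tiles, when present) with SAM + CLIP, project
    # to the LM embedding width, and append newline / view-separator embeddings to mark
    # row and view boundaries in the image token sequence.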
def _pixel_values_to_embedding(
self,
pixel_values: torch.Tensor,
images_crop: torch.Tensor,
images_spatial_crop: torch.Tensor,
) -> NestedTensors:
# Pixel_values (global view): [n_image, batch_size, 3, height, width]
# images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]]
# images_crop (local view): [n_image, batch_size, num_patches, 3, h, w]
# Split the global view and the local crops per image; batch_size is always 1 here.
images_in_this_batch = []
with torch.no_grad():
for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx][0].to(torch.bfloat16)
image_ori = pixel_values[jdx]
crop_shape = images_spatial_crop[jdx][0]
if torch.sum(patches).item() != 0:
local_features_1 = self.sam_model(patches)
local_features_2 = self.vision_model(patches, local_features_1)
local_features = torch.cat(
(
local_features_2[:, 1:],
local_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
local_features = self.projector(local_features)
global_features_1 = self.sam_model(image_ori)
global_features_2 = self.vision_model(image_ori, global_features_1)
global_features = torch.cat(
(
global_features_2[:, 1:],
global_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
global_features = self.projector(global_features)
_, hw, n_dim = global_features.shape
h = w = int(hw**0.5)
_2, hw2, n_dim2 = local_features.shape
h2 = w2 = int(hw2**0.5)
width_crop_num, height_crop_num = int(crop_shape[0]), int(
crop_shape[1]
)
global_features = global_features.view(h, w, n_dim)
global_features = torch.cat(
[
global_features,
self.image_newline[None, None, :].expand(h, 1, n_dim),
],
dim=1,
)
global_features = global_features.view(-1, n_dim)
local_features = (
local_features.view(
height_crop_num, width_crop_num, h2, w2, n_dim2
)
.permute(0, 2, 1, 3, 4)
.reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
)
local_features = torch.cat(
[
local_features,
self.image_newline[None, None, :].expand(
height_crop_num * h2, 1, n_dim2
),
],
dim=1,
)
local_features = local_features.view(-1, n_dim2)
global_local_features = torch.cat(
[local_features, global_features, self.view_seperator[None, :]],
dim=0,
)
else:
global_features_1 = self.sam_model(image_ori)
global_features_2 = self.vision_model(image_ori, global_features_1)
global_features = torch.cat(
(
global_features_2[:, 1:],
global_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
global_features = self.projector(global_features)
_, hw, n_dim = global_features.shape
h = w = int(hw**0.5)
global_features = global_features.view(h, w, n_dim)
global_features = torch.cat(
[
global_features,
self.image_newline[None, None, :].expand(h, 1, n_dim),
],
dim=1,
)
global_features = global_features.view(-1, n_dim)
global_local_features = torch.cat(
[global_features, self.view_seperator[None, :]], dim=0
)
images_in_this_batch.append(global_local_features)
return images_in_this_batch
def _process_image_input(self, mm_items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = torch.stack([item.feature for item in mm_items], dim=0).type(
self.vision_model.dtype
)
images_crop = (
torch.stack([item.images_crop for item in mm_items], dim=0)
.type(torch.long)
.to(device=pixel_values.device)
)
images_spatial_crop = (
torch.cat([item.images_spatial_crop for item in mm_items], dim=0)
.type(torch.long)
.to(device=pixel_values.device)
)
assert images_crop.dim() == 6
assert images_spatial_crop.dim() == 3
vision_feature_lists = self._pixel_values_to_embedding(
pixel_values=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
)
vision_features = torch.cat(vision_feature_lists, dim=0).type(
self.vision_model.dtype
)
return vision_features
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object
) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, self.image_token_id
)
return inputs_embeds
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
vision_embeddings = self._process_image_input(items)
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs: object,
):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.model,
multimodal_model=self,
positions=positions,
)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
name = "model.lm_head.weight"
elif name.startswith("model."):
if (
"image_newline" in name
or ".projector" in name
or "vision_model" in name
or "sam_model" in name
or "view_seperator" in name
):
name = name[len("model.") :]
elif not (
".projector" in name
or "vision_model" in name
or "sam_model" in name
or "image_newline" in name
):
name = name.replace("model.", "model.model.")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (
"mlp.experts." in name or "mlp.shared_experts." in name
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (
"mlp.experts." in name or "mlp.shared_experts." in name
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
f"Some weights are not initialized from checkpoints: {unloaded_params}"
)
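# Exported entry point so the model loader can register this implementation.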
EntryClass = [DeepseekOCRForCausalLM]
@@ -200,7 +200,6 @@ _is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
logger = logging.getLogger(__name__)
...
@@ -178,6 +178,7 @@ class BaseMultimodalProcessor(ABC):
"image_attention_mask": Modality.IMAGE,
"image_emb_mask": Modality.IMAGE,
"images_spatial_crop": Modality.IMAGE,
"images_crop": Modality.IMAGE,
"tgt_size": Modality.IMAGE,
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
...
from typing import List, Union
from sglang.srt.models.deepseek_ocr import DeepseekOCRForCausalLM
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
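# Multimodal processor for DeepSeek-OCR: loads image data, expands <image> tokens,
# and returns the combined multimodal items together with the image token id.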
class DeepseekOCRProcessor(BaseMultimodalProcessor):
models = [DeepseekOCRForCausalLM]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
_processor.image_size = 640
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<image>", image_token_id=self._processor.image_token_id
).build(_processor)
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
base_output = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.mm_tokens,
image_data=image_data,
)
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
"im_token_id": self.mm_tokens.image_token_id,
}
@@ -838,6 +838,19 @@ register_conv_template(
)
)
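# Chat template for DeepSeek-OCR: no system prompt, roles, or separators; the prompt is
# passed through as-is with the <image> token, stopping at <|end▁of▁sentence|>.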
register_conv_template(
Conversation(
name="deepseek-ocr",
system_message="",
system_template="",
roles=("", ""),
sep="",
sep_style=SeparatorStyle.NO_COLON_SINGLE,
stop_str=["<|end▁of▁sentence|>"],
image_token="<image>",
)
)
register_conv_template(
Conversation(
name="deepseek-vl2",
@@ -981,6 +994,7 @@ MODEL_TYPE_TO_TEMPLATE = {
"phi4mm": "phi-4-mm",
"minicpmv": "minicpmv",
"minicpmo": "minicpmo",
"deepseek-ocr": "deepseek-ocr",
}
@@ -1057,3 +1071,11 @@ def match_phi_4_mm(model_path: str):
return "phi-4-mm"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
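# Route model paths containing "deepseek-ocr" to the deepseek-ocr template; otherwise
# fall back to the model-type mapping.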
@register_conv_template_matching_function
def match_deepseek_ocr(model_path: str):
if "deepseek-ocr" in model_path.lower():
return "deepseek-ocr"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
@@ -19,7 +19,7 @@ import os
import tempfile
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
import torch
from huggingface_hub import snapshot_download
@@ -51,26 +51,32 @@ from sglang.srt.configs import (
Qwen3NextConfig,
Step3VLConfig,
)
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
ChatGLMConfig,
DbrxConfig,
ExaoneConfig,
DeepseekVL2Config,
MultiModalityConfig,
KimiVLConfig,
InternVLChatConfig,
Step3VLConfig,
LongcatFlashConfig,
Olmo3Config,
Qwen3NextConfig,
FalconH1Config,
DotsVLMConfig,
DotsOCRConfig,
NemotronHConfig,
DeepseekVLV2Config,
]
_CONFIG_REGISTRY = {
config_cls.model_type: config_cls for config_cls in _CONFIG_REGISTRY
}
for name, cls in _CONFIG_REGISTRY.items():
@@ -191,6 +197,11 @@ def get_config(
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if "deepseek-ai/DeepSeek-OCR" in model:
config.model_type = "deepseek-ocr"
# For an unknown reason, Hugging Face's AutoConfig misidentifies the DeepSeek-OCR configuration as deepseekvl2.
# This is a temporary workaround that needs further cleanup.
except ValueError as e:
if not "deepseek_v32" in str(e):
raise e
@@ -213,7 +224,8 @@
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 26,
# The model originally has 27 layers; only the first 26 are needed for feature extraction.
"patch_size": 14,
}
config.vision_config = SiglipVisionConfig(**vision_config)
...
@@ -619,7 +619,6 @@ def popen_launch_server(
start_time = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start_time < timeout:
return_code = process.poll()
if return_code is not None:
# Server failed to start (non-zero exit code) or crashed
...
@@ -150,6 +150,62 @@ class TestQwen2AudioServer(AudioOpenAITestMixin):
model = "Qwen/Qwen2-Audio-7B-Instruct"
class TestDeepseekOCRServer(TestOpenAIMLLMServerBase):
model = "deepseek-ai/DeepSeek-OCR"
trust_remote_code = False
def verify_single_image_response_for_ocr(self, response):
"""Verify DeepSeek-OCR grounding output with coordinates"""
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
# DeepSeek-OCR uses the grounding format and outputs bounding-box coordinates
assert "text" in text.lower(), f"OCR text: {text}, should contain 'text'"
# Verify coordinate format [[x1, y1, x2, y2]]
import re
coord_pattern = r"\[\[[\d\s,]+\]\]"
assert re.search(
coord_pattern, text
), f"OCR text: {text}, should contain coordinate format [[x1, y1, x2, y2]]"
# Verify basic response fields
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
image_url = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/ocr-text.png"
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": image_url},
},
{
"type": "text",
"text": "<|grounding|>Convert the document to markdown.",
},
],
},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
self.verify_single_image_response_for_ocr(response)
if __name__ == "__main__":
del (
TestOpenAIMLLMServerBase,
...
@@ -32,6 +32,7 @@ class TestOpenAIMLLMServerBase(CustomTestCase):
model: str
extra_args: list = []
fixed_args: list = ["--trust-remote-code", "--enable-multimodal"]
trust_remote_code: bool = True
@classmethod
def setUpClass(cls):
@@ -42,7 +43,11 @@ class TestOpenAIMLLMServerBase(CustomTestCase):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=(
cls.extra_args + cls.fixed_args + ["--trust-remote-code"]
if cls.trust_remote_code
else []
),
)
cls.base_url += "/v1"
...