Unverified Commit fb367acf authored by qrskannbara, committed by GitHub

Support Dots.ocr model (#11071)

parent a6cc86df
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.dots_ocr import DotsOCRConfig
from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
@@ -28,4 +29,5 @@ __all__ = [
"Step3VisionEncoderConfig",
"Qwen3NextConfig",
"DotsVLMConfig",
"DotsOCRConfig",
]
from typing import Optional
from transformers import AutoProcessor, Qwen2_5_VLProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.models.qwen2 import Qwen2Config
from sglang.srt.configs.dots_vlm import DotsVisionConfig
class DotsOCRConfig(Qwen2Config):
model_type = "dots_ocr"
def __init__(
self,
image_token_id=151665,
video_token_id=151656,
vision_config: Optional[dict] = None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.vision_config = DotsVisionConfig(**(vision_config or {}))
def save_pretrained(self, save_directory, **kwargs):
# Clear _auto_class so the saved config does not write an auto_map entry
# pointing back at remote custom code.
self._auto_class = None
super().save_pretrained(save_directory, **kwargs)
# Image-only placeholder: Qwen2_5_VLProcessor expects a video processor, so this
# stub fills the slot and simply returns None.
class DummyVideoProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
def __call__(self, *args, **kwargs):
return None
class DotsVLProcessor(Qwen2_5_VLProcessor):
def __init__(
self,
image_processor=None,
tokenizer=None,
video_processor=None,
chat_template=None,
**kwargs
):
if video_processor is None:
video_processor = DummyVideoProcessor()
super().__init__(
image_processor, tokenizer, video_processor, chat_template=chat_template
)
self.image_token = (
"<|imgpad|>"
if not hasattr(tokenizer, "image_token")
else tokenizer.image_token
)
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None) is not None
else tokenizer.convert_tokens_to_ids(self.image_token)
)
# Route AutoProcessor resolution for DotsOCRConfig checkpoints to DotsVLProcessor.
AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)
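
As a quick illustration (not part of this commit), the snippet below exercises the config defaults defined above; it assumes DotsVisionConfig constructs cleanly with no arguments:

from sglang.srt.configs.dots_ocr import DotsOCRConfig

config = DotsOCRConfig()      # Qwen2Config defaults plus an empty DotsVisionConfig
print(config.model_type)      # "dots_ocr"
print(config.image_token_id)  # 151665
print(config.video_token_id)  # 151656

# With the AutoProcessor.register call above, AutoProcessor.from_pretrained on a
# checkpoint whose config resolves to DotsOCRConfig yields a DotsVLProcessor.
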
@@ -778,6 +778,7 @@ multimodal_model_archs = [
"VILAForConditionalGeneration",
"Step3VLForConditionalGeneration",
"DotsVLMForCausalLM",
"DotsOCRForCausalLM",
"Sarashina2VisionForCausalLM",
]
......
@@ -38,6 +38,7 @@ from sglang.srt.configs import (
ChatGLMConfig,
DbrxConfig,
DeepseekVL2Config,
DotsOCRConfig,
DotsVLMConfig,
ExaoneConfig,
KimiVLConfig,
@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
LongcatFlashConfig.model_type: LongcatFlashConfig,
Qwen3NextConfig.model_type: Qwen3NextConfig,
DotsVLMConfig.model_type: DotsVLMConfig,
DotsOCRConfig.model_type: DotsOCRConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
# coding=utf-8
# Adapted from Qwen2.5-VL SGLang implementation
import logging
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from sglang.srt.configs import DotsOCRConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
class DotsOCRForCausalLM(nn.Module):
def __init__(
self,
config: DotsOCRConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
# Initialize vision transformer
self.visual = DotsVisionTransformer(
config.vision_config,
)
# Initialize language model
self.model = Qwen2ForCausalLM(config, quant_config)
# Initialize LM head
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# Extract pixel values and grid information (following reference pattern)
pixel_values = torch.cat([item.feature for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thw = torch.concat(
[item.image_grid_thw for item in items], dim=0
).to(self.visual.device)
# Add dimension checks like in reference code
assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
# Process through vision tower
image_embeds = self.visual(pixel_values, image_grid_thw)
# Ensure a consistent dtype for FlashInfer compatibility: cast the image
# embeddings to the language model's embedding dtype.
if hasattr(self.model, "embed_tokens"):
target_dtype = self.model.embed_tokens.weight.dtype
if image_embeds.dtype != target_dtype:
image_embeds = image_embeds.to(target_dtype)
return image_embeds
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
num_dummy_heads = self.config.vision_config.num_dummy_heads
if num_dummy_heads == 0:
return loaded_weight
head_dim = self.config.vision_config.head_dim
if "attn.qkv_proj" in name:
wq, wk, wv = loaded_weight.chunk(3, dim=0)
if name.endswith(".weight"):
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
elif name.endswith(".bias"):
dummy_shape = [num_dummy_heads, head_dim]
else:
raise RuntimeError(f"Unsupported weight with name={name}")
pad_func = lambda x: torch.cat(
[x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
).flatten(0, 1)
wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
loaded_weight = torch.cat([wq, wk, wv], dim=0)
if "attn.proj.weight" in name:
padded_weight = loaded_weight.new_zeros(
loaded_weight.shape[0], head_dim * num_dummy_heads
)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
return loaded_weight
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
**kwargs: object,
) -> torch.Tensor:
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
multimodal_model=self,
language_model=self.model,
)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""Load weights for the model, separating vision and language weights"""
weights = list(weights)
# Separate vision tower weights and language model weights
vision_weights = []
language_weights = []
for name, loaded_weight in weights:
if name.startswith("vision_tower."):
# The checkpoint names the fused projection "attn.qkv"; the SGLang vision
# module calls it "attn.qkv_proj".
vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
vision_weights.append((vision_name, loaded_weight))
else:
# All other weights go to language model
language_weights.append((name, loaded_weight))
# Load vision tower weights
vision_state_dict = dict(vision_weights)
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in vision_state_dict.items():
name = name.replace("vision_tower", "visual")
if name not in params_dict:
raise ValueError(f"Weight {name} not found in params_dict")
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
weight_loader(param, loaded_weight)
if language_weights:
self.model.load_weights(language_weights)
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
EntryClass = [DotsOCRForCausalLM]
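
For intuition, here is a small self-contained sketch (not from this commit) of the reshaping that _pad_vit_attn_dummy_heads performs on one q/k/v chunk: the 2-D weight is viewed as (heads, head_dim, in_features), zero-filled dummy heads are appended, and the result is flattened back. All sizes are made up for illustration.

import torch

head_dim, num_heads, num_dummy_heads, in_features = 4, 2, 1, 8  # illustrative sizes
wq = torch.randn(num_heads * head_dim, in_features)

padded = torch.cat(
    [
        wq.unflatten(0, (-1, head_dim)),                       # (2, 4, 8): real heads
        wq.new_zeros(num_dummy_heads, head_dim, in_features),  # (1, 4, 8): zero dummy head
    ],
    dim=0,
).flatten(0, 1)                                                # back to a 2-D weight layout

assert padded.shape == ((num_heads + num_dummy_heads) * head_dim, in_features)
assert torch.equal(padded[: num_heads * head_dim], wq)         # real heads are unchanged
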
@@ -5,6 +5,7 @@ from typing import Dict, List, Union
from PIL import Image
from sglang.srt.models.dots_ocr import DotsOCRForCausalLM
from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
@@ -14,7 +15,7 @@ from sglang.srt.multimodal.processors.qwen_vl import resize_image_async
class DotsVLMImageProcessor(BaseMultimodalProcessor):
models = [DotsVLMForCausalLM]
models = [DotsVLMForCausalLM, DotsOCRForCausalLM]
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@@ -82,11 +83,9 @@ class DotsVLMImageProcessor(BaseMultimodalProcessor):
for image in base_output.images
]
base_output.images = await asyncio.gather(*resize_tasks)
combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
if combined_mm_item is None:
return None
......