Commit fb367acf (unverified)
Authored Oct 01, 2025 by qrskannbara; committed by GitHub Sep 30, 2025
Parent: a6cc86df

Support Dots.ocr model (#11071)

Showing 6 changed files with 244 additions and 3 deletions (+244, -3).
Changed files:
- python/sglang/srt/configs/__init__.py (+2, -0)
- python/sglang/srt/configs/dots_ocr.py (+64, -0)
- python/sglang/srt/configs/model_config.py (+1, -0)
- python/sglang/srt/hf_transformers_utils.py (+2, -0)
- python/sglang/srt/models/dots_ocr.py (+173, -0)
- python/sglang/srt/multimodal/processors/dots_vlm.py (+2, -3)
python/sglang/srt/configs/__init__.py

 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
+from sglang.srt.configs.dots_ocr import DotsOCRConfig
 from sglang.srt.configs.dots_vlm import DotsVLMConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig

@@ -28,4 +29,5 @@ __all__ = [
     "Step3VisionEncoderConfig",
     "Qwen3NextConfig",
     "DotsVLMConfig",
+    "DotsOCRConfig",
 ]
python/sglang/srt/configs/dots_ocr.py (new file, mode 100644)

from typing import Optional

from transformers import AutoProcessor, Qwen2_5_VLProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.models.qwen2 import Qwen2Config

from sglang.srt.configs.dots_vlm import DotsVisionConfig


class DotsOCRConfig(Qwen2Config):
    model_type = "dots_ocr"

    def __init__(
        self,
        image_token_id=151665,
        video_token_id=151656,
        vision_config: Optional[dict] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_config = DotsVisionConfig(**(vision_config or {}))

    def save_pretrained(self, save_directory, **kwargs):
        self._auto_class = None
        super().save_pretrained(save_directory, **kwargs)


class DummyVideoProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __call__(self, *args, **kwargs):
        return None


class DotsVLProcessor(Qwen2_5_VLProcessor):
    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        **kwargs,
    ):
        if video_processor is None:
            video_processor = DummyVideoProcessor()
        super().__init__(
            image_processor, tokenizer, video_processor, chat_template=chat_template
        )
        self.image_token = (
            "<|imgpad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        )
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None) is not None
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )


AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)
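
Note: the `image_token` / `image_token_id` fallbacks in DotsVLProcessor tolerate tokenizers that don't expose Qwen-style image-token attributes. A toy illustration of just that fallback logic (the stub tokenizer below is invented for this sketch, not part of the commit):

class _StubTokenizer:
    # No image_token / image_token_id attributes, like a plain Qwen2 tokenizer.
    def convert_tokens_to_ids(self, token):
        return {"<|imgpad|>": 151665}.get(token)

tok = _StubTokenizer()
image_token = "<|imgpad|>" if not hasattr(tok, "image_token") else tok.image_token
image_token_id = (
    tok.image_token_id
    if getattr(tok, "image_token_id", None) is not None
    else tok.convert_tokens_to_ids(image_token)
)
assert (image_token, image_token_id) == ("<|imgpad|>", 151665)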
python/sglang/srt/configs/model_config.py

@@ -778,6 +778,7 @@ multimodal_model_archs = [
     "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "DotsVLMForCausalLM",
+    "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
 ]
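
Note: listing the architecture in `multimodal_model_archs` is what makes SGLang treat dots.ocr checkpoints as multimodal. A hedged sketch of the kind of check this list feeds (the helper name is illustrative, not the exact function in model_config.py):

# Illustrative helper, not the exact code in model_config.py.
def is_multimodal(architectures: list) -> bool:
    # Architecture strings come from the checkpoint's config.json.
    return any(arch in multimodal_model_archs for arch in architectures)

assert is_multimodal(["DotsOCRForCausalLM"])  # True after this commit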
python/sglang/srt/hf_transformers_utils.py

@@ -38,6 +38,7 @@ from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
     DeepseekVL2Config,
+    DotsOCRConfig,
     DotsVLMConfig,
     ExaoneConfig,
     KimiVLConfig,

@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     LongcatFlashConfig.model_type: LongcatFlashConfig,
     Qwen3NextConfig.model_type: Qwen3NextConfig,
     DotsVLMConfig.model_type: DotsVLMConfig,
+    DotsOCRConfig.model_type: DotsOCRConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
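
Note: `_CONFIG_REGISTRY` lets SGLang map the `model_type` string in a checkpoint's config.json to its config class before any transformers AutoConfig fallback. A simplified sketch of that resolution (illustrative shape, not the exact code in hf_transformers_utils.py):

model_type = "dots_ocr"  # read from the checkpoint's config.json
if model_type in _CONFIG_REGISTRY:
    config_cls = _CONFIG_REGISTRY[model_type]          # -> DotsOCRConfig
    config = config_cls.from_pretrained(model_path)
else:
    config = AutoConfig.from_pretrained(model_path)    # transformers fallback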
python/sglang/srt/models/dots_ocr.py (new file, mode 100644)

# coding=utf-8
# Adapted from Qwen2.5-VL SGLang implementation
import logging
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn
from transformers.activations import ACT2FN

from sglang.srt.configs import DotsOCRConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class DotsOCRForCausalLM(nn.Module):
    def __init__(
        self,
        config: DotsOCRConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Initialize vision transformer
        self.visual = DotsVisionTransformer(
            config.vision_config,
        )

        # Initialize language model
        self.model = Qwen2ForCausalLM(config, quant_config)

        # Initialize LM head
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        # Extract pixel values and grid information (following reference pattern)
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.concat(
            [item.image_grid_thw for item in items], dim=0
        ).to(self.visual.device)

        # Add dimension checks like in reference code
        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"

        # Process through vision tower
        image_embeds = self.visual(pixel_values, image_grid_thw)

        # Ensure consistent dtype for FlashInfer compatibility
        # Force bfloat16 to match model's expected dtype
        if hasattr(self.model, "embed_tokens"):
            target_dtype = self.model.embed_tokens.weight.dtype
            if image_embeds.dtype != target_dtype:
                image_embeds = image_embeds.to(target_dtype)

        return image_embeds

    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
        """pad attn qkv weights for dummy heads"""
        num_dummy_heads = self.config.vision_config.num_dummy_heads
        if num_dummy_heads == 0:
            return loaded_weight
        head_dim = self.config.vision_config.head_dim

        if "attn.qkv_proj" in name:
            wq, wk, wv = loaded_weight.chunk(3, dim=0)
            if name.endswith(".weight"):
                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
            elif name.endswith(".bias"):
                dummy_shape = [num_dummy_heads, head_dim]
            else:
                raise RuntimeError(f"Unsupported weight with name={name}")
            pad_func = lambda x: torch.cat(
                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
            ).flatten(0, 1)
            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
            loaded_weight = torch.cat([wq, wk, wv], dim=0)
        if "attn.proj.weight" in name:
            padded_weight = loaded_weight.new_zeros(
                loaded_weight.shape[0], head_dim * num_dummy_heads
            )
            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
        return loaded_weight

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ) -> torch.Tensor:
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            multimodal_model=self,
            language_model=self.model,
        )
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model, separating vision and language weights"""
        weights = list(weights)

        # Separate vision tower weights and language model weights
        vision_weights = []
        language_weights = []
        for name, loaded_weight in weights:
            if name.startswith("vision_tower."):
                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
                vision_weights.append((vision_name, loaded_weight))
            else:
                # All other weights go to the language model
                language_weights.append((name, loaded_weight))

        # Load vision tower weights
        vision_state_dict = dict(vision_weights)
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in vision_state_dict.items():
            name = name.replace("vision_tower", "visual")
            if name not in params_dict:
                raise ValueError(f"Weight {name} not found in params_dict")
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
            weight_loader(param, loaded_weight)

        # Load language model weights
        if language_weights:
            self.model.load_weights(language_weights)

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight


EntryClass = [DotsOCRForCausalLM]
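
Note: `_pad_vit_attn_dummy_heads` appends zero-filled rows so vision attention checkpoints trained without the extra (dummy) heads still load into the padded layout. A toy run of the same qkv padding on invented shapes (2 real heads, head_dim 4, 1 dummy head); illustration only, not part of the commit:

import torch

num_heads, head_dim, hidden, num_dummy_heads = 2, 4, 8, 1
qkv_weight = torch.randn(3 * num_heads * head_dim, hidden)  # stacked [q; k; v] rows

wq, wk, wv = qkv_weight.chunk(3, dim=0)  # each (num_heads * head_dim, hidden)
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
pad = lambda x: torch.cat(
    [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
).flatten(0, 1)  # regroup rows by head, append zero heads, flatten back

padded = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)
# q, k, and v each grew by num_dummy_heads * head_dim zero rows:
assert padded.shape == (3 * (num_heads + num_dummy_heads) * head_dim, hidden)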
python/sglang/srt/multimodal/processors/dots_vlm.py

@@ -5,6 +5,7 @@ from typing import Dict, List, Union
 from PIL import Image

+from sglang.srt.models.dots_ocr import DotsOCRForCausalLM
 from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,

@@ -14,7 +15,7 @@ from sglang.srt.multimodal.processors.qwen_vl import resize_image_async

 class DotsVLMImageProcessor(BaseMultimodalProcessor):
-    models = [DotsVLMForCausalLM]
+    models = [DotsVLMForCausalLM, DotsOCRForCausalLM]

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)

@@ -82,11 +83,9 @@ class DotsVLMImageProcessor(BaseMultimodalProcessor):
             for image in base_output.images
         ]
         base_output.images = await asyncio.gather(*resize_tasks)

         combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
             base_output, self.mm_tokens
         )

         if combined_mm_item is None:
             return None
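
Note: with `DotsOCRForCausalLM` added to the processor's `models` list, the existing DotsVLM image pipeline handles dots.ocr inputs as well. A hedged end-to-end sketch using SGLang's offline engine (the model path, prompt template, and generate arguments below are assumptions for illustration, not taken from this commit):

import sglang as sgl

# Assumed checkpoint; any model whose config declares model_type "dots_ocr" should route here.
llm = sgl.Engine(model_path="rednote-hilab/dots.ocr", trust_remote_code=True)
out = llm.generate(
    prompt="<|imgpad|>Extract the text from this page.",  # prompt format is illustrative
    image_data="page.png",
    sampling_params={"max_new_tokens": 128},
)
print(out["text"])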