sglang · commit c1815a99

model support: Sarashina2VisionForCausalLM (#10632)

Authored by Chang Su on Sep 18, 2025; committed via GitHub on Sep 18, 2025.
Parent commit: 4e6c4923

Showing 6 changed files with 366 additions and 2 deletions (+366 -2).
examples/chat_template/vision_template_sarashina_vl.jinja        +9    -0
python/sglang/srt/configs/model_config.py                        +1    -0
python/sglang/srt/hf_transformers_utils.py                       +2    -2
python/sglang/srt/models/llama.py                                +4    -0
python/sglang/srt/models/sarashina2_vision.py                    +269  -0
python/sglang/srt/multimodal/processors/sarashina2_vision.py     +81   -0
examples/chat_template/vision_template_sarashina_vl.jinja (new file, mode 100644)
{#
In sglang, the default chat templates often assume message['content'] is a plain string.
That works fine for simple text conversations, but it ignores multimodal inputs (e.g. image_url, tool_call).
To align with the original model behavior and support richer content,
we iterate over message['content'] as a list of typed items and extract their values directly.
This way, both text and non-text inputs are preserved in the prompt.
Original template: https://huggingface.co/sbintuitions/sarashina2-vision-8b?chat_template=default
#}
{{ bos_token + '<|prefix|><|file|><|suffix|>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions.\n\n' }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Human: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\n' }}{% elif message['role'] == 'assistant' %}{{ '### Assistant: ' }}{%- if message['content'] is string %}{{ message['content'] }}{%- else %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% endif %}{% endfor %}{% endif %}{{ '\n' }}{% endif %}{% endfor %}{% if messages[-1]['role'] == 'user' %}{{ '### Assistant:' }}{% endif %}
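For reference, a minimal sketch (not part of the commit) of how this template expands a typed-content message list. It assumes jinja2 is installed and that the path below is relative to an sglang checkout; only the text items are emitted, while the single image placeholder comes from the fixed <|prefix|><|file|><|suffix|> header.

# Sketch only: render the Sarashina2 vision chat template outside of sglang.
from jinja2 import Template

with open("examples/chat_template/vision_template_sarashina_vl.jinja") as f:
    template = Template(f.read())

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "What is in this picture?"},
        ],
    }
]

print(template.render(messages=messages, bos_token="<s>"))
# Expected shape of the output:
# <s><|prefix|><|file|><|suffix|>A chat between a curious human and an artificial
# intelligence assistant. ...
# ### Human: What is in this picture?
# ### Assistant: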
python/sglang/srt/configs/model_config.py

@@ -756,6 +756,7 @@ multimodal_model_archs = [
     "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "DotsVLMForCausalLM",
+    "Sarashina2VisionForCausalLM",
 ]
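Registering the architecture string here is what lets sglang treat checkpoints whose config.json lists Sarashina2VisionForCausalLM as multimodal, so the vision processor path is used. A hedged illustration of the membership check this list backs (helper name and signature simplified, not the exact sglang code):

multimodal_model_archs = [
    # ... existing entries ...
    "Sarashina2VisionForCausalLM",
]

def is_multimodal(architectures):
    # True if any architecture declared in the checkpoint's config.json is a known multimodal arch
    return any(arch in multimodal_model_archs for arch in architectures)

assert is_multimodal(["Sarashina2VisionForCausalLM"])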
python/sglang/srt/hf_transformers_utils.py

@@ -374,8 +374,8 @@ def get_processor(
         **kwargs,
     )

-    # fix: for Qwen2-VL model, inject default 'size' if not provided.
-    if config.model_type in {"qwen2_vl"}:
+    # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
         if "size" not in kwargs:
             kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
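These defaults match the Qwen2-VL image processor, where the two keys are treated as minimum and maximum pixel counts rather than edge lengths. A quick arithmetic check (sketch only, not part of the commit; the interpretation for Sarashina2Vision is assumed to follow Qwen2-VL since it reuses the same vision encoder):

# Assumption: same pixel-count semantics as Qwen2-VL's image processor defaults.
min_pixels = 56 * 56          # 3136     -> "shortest_edge"
max_pixels = 1280 * 28 * 28   # 1003520  -> "longest_edge"
assert min_pixels == 3136
assert max_pixels == 1003520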
python/sglang/srt/models/llama.py

@@ -385,6 +385,10 @@ class LlamaModel(nn.Module):
                     "Self attention has no KV cache scaling "
                     "factor attribute!"
                 )

+    def get_input_embeddings(self) -> nn.Embedding:
+        """Get input embeddings from the model."""
+        return self.embed_tokens
+
 class LlamaForCausalLM(nn.Module):
     # BitandBytes specific attributes
python/sglang/srt/models/sarashina2_vision.py (new file, mode 100644)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Sarashina2Vision model compatible with HuggingFace weights."""
import logging
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers import LlamaConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultimodalDataItem,
    MultimodalInputs,
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.qwen2_vl import Qwen2VisionTransformer
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class Sarashina2VisionForCausalLM(nn.Module):
    """
    Sarashina2Vision model that combines:
    - Llama text backbone (sbintuitions/sarashina2-7b)
    - Qwen2VL vision encoder
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Extract text and vision configurations
        text_config = getattr(config, "text_config", config)
        vision_config = getattr(config, "vision_config", None)

        # Create vision transformer first (like original model)
        if vision_config is not None:
            self.visual = Qwen2VisionTransformer(
                vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-5),
                quant_config=quant_config,
                prefix=add_prefix("visual", prefix),
            )
        else:
            self.visual = None

        # Layer norm for vision outputs (matching original model)
        self.norm = nn.LayerNorm(text_config.hidden_size)

        # Create Llama text model (using 'llm' name to match original)
        if hasattr(text_config, "model_type") and text_config.model_type == "llama":
            llama_config = LlamaConfig(**text_config.__dict__)
            # Set vocab_size from main config if available
            if hasattr(config, "vocab_size"):
                llama_config.vocab_size = config.vocab_size
            self.llm = LlamaForCausalLM(
                llama_config,
                quant_config=quant_config,
                prefix=add_prefix("llm", prefix),
            )
        else:
            # Set vocab_size from main config if available
            if hasattr(config, "vocab_size"):
                config.vocab_size = config.vocab_size
            self.llm = LlamaForCausalLM(
                config,
                quant_config=quant_config,
                prefix=add_prefix("llm", prefix),
            )

        # Image token indices from config
        self.image_token_index = getattr(config, "image_token_index", 14)
        self.start_image_token_index = getattr(config, "start_image_token_index", 102397)
        self.end_image_token_index = getattr(config, "end_image_token_index", 102398)

        # Ensure vocabulary size matches
        if hasattr(config, "vocab_size"):
            self.llm.config.vocab_size = config.vocab_size

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad input tokens with multimodal data hashes for RadixAttention."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_input_embeddings(self):
        """Get input embeddings from the language model."""
        return self.llm.get_input_embeddings()

    def get_image_embeds(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Extract image embeddings using the vision transformer."""
        if self.visual is None:
            raise ValueError("Visual encoder not initialized")

        # Use the existing Qwen2VisionTransformer forward method
        hidden_states = self.visual(pixel_values, image_grid_thw)

        # Apply normalization layer
        return self.norm(hidden_states)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Extract image features for SGLang compatibility."""
        if self.visual is None:
            raise ValueError("Visual encoder not initialized")

        # Concatenate pixel values and grid_thw from all items
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.cat([item.image_grid_thw for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()

        # Use the get_image_embeds method
        return self.get_image_embeds(pixel_values, image_grid_thw)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        """Forward pass through the model."""
        # Handles token-to-feature mapping for expanded tokens
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm.model,
            multimodal_model=self,
            positions=positions,
        )

        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        else:
            return self.logits_processor(
                input_ids, hidden_states, self.llm.lm_head, forward_batch
            )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load model weights."""
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        # Collect weights that need to be fused
        qkv_weights = {}
        gate_up_weights = {}

        for name, loaded_weight in weights:
            # Handle weight name mappings
            # Map visual attention weights: qkv -> qkv_proj
            if ".attn.qkv." in name:
                mapped_name = name.replace(".attn.qkv.", ".attn.qkv_proj.")
                if mapped_name in params_dict:
                    param = params_dict[mapped_name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(mapped_name)
                continue

            # Handle Llama attention weights - need to fuse q, k, v into qkv
            if ".self_attn.q_proj.weight" in name:
                base = name.replace(".q_proj.weight", "")
                qkv_weights[base] = qkv_weights.get(base, {})
                qkv_weights[base]["q"] = loaded_weight
                continue
            elif ".self_attn.k_proj.weight" in name:
                base = name.replace(".k_proj.weight", "")
                qkv_weights[base] = qkv_weights.get(base, {})
                qkv_weights[base]["k"] = loaded_weight
                continue
            elif ".self_attn.v_proj.weight" in name:
                base = name.replace(".v_proj.weight", "")
                qkv_weights[base] = qkv_weights.get(base, {})
                qkv_weights[base]["v"] = loaded_weight
                continue

            # Handle Llama MLP weights - need to fuse gate and up projections
            if ".mlp.gate_proj.weight" in name:
                base = name.replace(".gate_proj.weight", "")
                gate_up_weights[base] = gate_up_weights.get(base, {})
                gate_up_weights[base]["gate"] = loaded_weight
                continue
            elif ".mlp.up_proj.weight" in name:
                base = name.replace(".up_proj.weight", "")
                gate_up_weights[base] = gate_up_weights.get(base, {})
                gate_up_weights[base]["up"] = loaded_weight
                continue

            # Direct mapping for other weights
            if name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        # Fuse QKV weights for Llama attention layers
        for base, weights_dict in qkv_weights.items():
            if "q" in weights_dict and "k" in weights_dict and "v" in weights_dict:
                qkv_name = f"{base}.qkv_proj.weight"
                if qkv_name in params_dict:
                    # Concatenate q, k, v weights
                    q, k, v = weights_dict["q"], weights_dict["k"], weights_dict["v"]
                    qkv = torch.cat([q, k, v], dim=0)
                    param = params_dict[qkv_name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, qkv)
                    loaded_params.add(qkv_name)

        # Fuse gate and up weights for Llama MLP layers
        for base, weights_dict in gate_up_weights.items():
            if "gate" in weights_dict and "up" in weights_dict:
                gate_up_name = f"{base}.gate_up_proj.weight"
                if gate_up_name in params_dict:
                    # Concatenate gate and up weights
                    gate, up = weights_dict["gate"], weights_dict["up"]
                    gate_up = torch.cat([gate, up], dim=0)
                    param = params_dict[gate_up_name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, gate_up)
                    loaded_params.add(gate_up_name)


# Register the model
EntryClass = Sarashina2VisionForCausalLM
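The q/k/v and gate/up fusion in load_weights simply stacks the per-projection checkpoint matrices along the output dimension so they match sglang's fused linear layers. A toy sketch of the shape bookkeeping (illustrative sizes only, not tied to the real checkpoint):

import torch

hidden = 16
num_q_rows, num_kv_rows = 16, 8  # GQA-style: fewer KV rows than Q rows (illustrative)
q = torch.randn(num_q_rows, hidden)
k = torch.randn(num_kv_rows, hidden)
v = torch.randn(num_kv_rows, hidden)

# Same concatenation as load_weights: fused qkv_proj rows are [q; k; v]
qkv = torch.cat([q, k, v], dim=0)
assert qkv.shape == (num_q_rows + 2 * num_kv_rows, hidden)

gate = torch.randn(2 * hidden, hidden)
up = torch.randn(2 * hidden, hidden)
gate_up = torch.cat([gate, up], dim=0)  # fused gate_up_proj rows are [gate; up]
assert gate_up.shape == (4 * hidden, hidden)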
python/sglang/srt/multimodal/processors/sarashina2_vision.py (new file, mode 100644)
from typing import List, Union

from sglang.srt.models.sarashina2_vision import Sarashina2VisionForCausalLM
from sglang.srt.multimodal.processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class Sarashina2VisionProcessor(BaseMultimodalProcessor):
    models = [Sarashina2VisionForCausalLM]

    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
        super().__init__(hf_config, server_args, _processor, *args, **kwargs)

        # Sarashina2Vision specific tokens (default is <|file|>)
        self.IMAGE_TOKEN = "<|file|>"
        self.IM_TOKEN_ID = getattr(hf_config, "image_token_index", 14)
        self.IM_START_ID = getattr(hf_config, "start_image_token_index", 102397)
        self.IM_END_ID = getattr(hf_config, "end_image_token_index", 102398)

        self.mm_tokens = MultimodalSpecialTokens(
            image_token=self.IMAGE_TOKEN,
            image_token_id=self.IM_TOKEN_ID,
        ).build(_processor)

        # Patch the processor's image processor to handle parameter compatibility
        if hasattr(_processor, "image_processor") and hasattr(
            _processor.image_processor, "_preprocess"
        ):
            original_preprocess = _processor.image_processor._preprocess

            def patched_preprocess(*args, **kwargs):
                # Filter kwargs to only include parameters that the custom _preprocess method accepts
                # Based on Sarashina2VisionImageProcessor._preprocess signature
                allowed_params = {
                    "do_resize",
                    "resample",
                    "do_rescale",
                    "rescale_factor",
                    "do_normalize",
                    "image_mean",
                    "image_std",
                    "do_convert_rgb",
                    "data_format",
                    "input_data_format",
                }
                filtered_kwargs = {
                    k: v for k, v in kwargs.items() if k in allowed_params
                }
                return original_preprocess(*args, **filtered_kwargs)

            _processor.image_processor._preprocess = patched_preprocess

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Process image data for Sarashina2Vision model using standard SGLang pattern."""
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,
        )

        mm_items, input_ids, ret = self.process_and_combine_mm_data(
            base_output=base_output,
            mm_tokens=self.mm_tokens,
        )

        return {
            "mm_items": mm_items,
            "input_ids": input_ids.tolist(),
            "im_token_id": self.mm_tokens.image_token_id,
            "im_start_id": self.IM_START_ID,
            "im_end_id": self.IM_END_ID,
        }
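Once a server is running with this model (e.g. something like `python -m sglang.launch_server --model-path sbintuitions/sarashina2-vision-8b --chat-template examples/chat_template/vision_template_sarashina_vl.jinja --trust-remote-code`), requests can go through the OpenAI-compatible endpoint. A sketch, assuming the openai Python client and a server listening on localhost:30000; the image URL is a placeholder:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="sbintuitions/sarashina2-vision-8b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "What is in this picture?"},
            ],
        }
    ],
    max_tokens=128,
)
print(response.choices[0].message.content)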