change / sglang / Commits / b5e3d603

Unverified commit b5e3d603, authored Jul 10, 2025 by Mick, committed by GitHub Jul 09, 2025
Parent: 4ed57807

vlm: support video as an input modality (#5888)

Changes: 42. This page (1 of 3) shows 20 changed files with 208 additions and 125 deletions (+208, -125).
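End to end, the commit threads a new `video_data` field from the OpenAI-compatible chat endpoint down to the model runners. A minimal sketch of what this enables, assuming a locally running sglang server on port 30000 serving a video-capable VLM (the base URL and model name are placeholders, not part of the commit):

```python
# Sketch: send a video to the OpenAI-compatible /v1/chat/completions endpoint.
# Assumes a local sglang server and a video-capable model; URLs are placeholders.
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                # The new content-part type added by this commit.
                {
                    "type": "video_url",
                    "video_url": {"url": "https://example.com/clip.mp4"},
                },
                {"type": "text", "text": "Describe what happens in this video."},
            ],
        }
    ],
)
print(response.choices[0].message.content)
```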
Files in this page:

python/sglang/srt/conversation.py                       +21   -2
python/sglang/srt/entrypoints/openai/protocol.py        +11   -0
python/sglang/srt/entrypoints/openai/serving_chat.py     +7   -0
python/sglang/srt/jinja_template_utils.py                +8   -0
python/sglang/srt/managers/io_struct.py                 +25   -2
python/sglang/srt/managers/mm_utils.py                  +55  -94
python/sglang/srt/managers/schedule_batch.py            +16   -5
python/sglang/srt/model_executor/forward_batch_info.py  +13   -1
python/sglang/srt/models/deepseek_janus_pro.py           +1   -1
python/sglang/srt/models/deepseek_vl2.py                 +1   -1
python/sglang/srt/models/gemma3_mm.py                    +1   -1
python/sglang/srt/models/gemma3n_mm.py                   +6   -3
python/sglang/srt/models/internvl.py                     +8   -2
python/sglang/srt/models/kimi_vl.py                      +8   -2
python/sglang/srt/models/llava.py                        +3   -1
python/sglang/srt/models/llavavid.py                     +1   -1
python/sglang/srt/models/minicpmo.py                     +1   -2
python/sglang/srt/models/minicpmv.py                     +1   -1
python/sglang/srt/models/mllama4.py                     +13   -4
python/sglang/srt/models/phi4mm.py                       +8   -2
python/sglang/srt/conversation.py (+21, -2)

```diff
@@ -88,9 +88,11 @@ class Conversation:
     stop_str: Union[str, List[str]] = None
     # The string that represents an image token in the prompt
     image_token: str = "<image>"
+    video_token: str = "<video>"
     audio_token: str = "<audio>"

     image_data: Optional[List[str]] = None
+    video_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
     stop_token_ids: Optional[int] = None
@@ -380,11 +382,15 @@ class Conversation:
         self.messages.append([role, message])

     def append_image(self, image: str):
-        """Append a new message."""
+        """Append a new image."""
         self.image_data.append(image)

+    def append_video(self, video: str):
+        """Append a new video."""
+        self.video_data.append(video)
+
     def append_audio(self, audio: str):
-        """Append a new message."""
+        """Append a new audio."""
         self.audio_data.append(audio)

     def update_last_message(self, message: str):
@@ -433,6 +439,7 @@ class Conversation:
             sep2=self.sep2,
             stop_str=self.stop_str,
             image_token=self.image_token,
+            video_token=self.video_token,
             audio_token=self.audio_token,
         )
@@ -495,8 +502,12 @@ def generate_embedding_convs(
             sep2=conv_template.sep2,
             stop_str=conv_template.stop_str,
             image_data=[],
+            video_data=[],
+            audio_data=[],
             modalities=[],
             image_token=conv_template.image_token,
+            video_token=conv_template.video_token,
+            audio_token=conv_template.audio_token,
         )
         real_content = ""
@@ -557,10 +568,12 @@ def generate_chat_conv(
         sep2=conv.sep2,
         stop_str=conv.stop_str,
         image_data=[],
+        video_data=[],
         audio_data=[],
         modalities=[],
         image_token=conv.image_token,
         audio_token=conv.audio_token,
+        video_token=conv.video_token,
     )

     if isinstance(request.messages, str):
@@ -602,6 +615,7 @@ def generate_chat_conv(
                 image_token = ""
                 audio_token = conv.audio_token
+                video_token = conv.video_token
                 for content in message.content:
                     if content.type == "text":
                         if num_image_url > 16:
@@ -614,6 +628,9 @@ def generate_chat_conv(
                     else:
                         real_content += image_token
                     conv.append_image(content.image_url.url)
+                elif content.type == "video_url":
+                    real_content += video_token
+                    conv.append_video(content.video_url.url)
                 elif content.type == "audio_url":
                     real_content += audio_token
                     conv.append_audio(content.audio_url.url)
@@ -810,6 +827,7 @@ register_conv_template(
         sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
         stop_str=["<|im_end|>"],
         image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        video_token="<|vision_start|><|video_pad|><|vision_end|>",
     )
 )
@@ -870,6 +888,7 @@ register_conv_template(
         sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
         stop_str=("<|im_end|>", "<|endoftext|>"),
         image_token="(<image>./</image>)",
+        video_token="(<video>./</video>)",
     )
 )
```
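In `generate_chat_conv`, a `video_url` content part now contributes the template's `video_token` to the rendered prompt and its URL to `conv.video_data`, mirroring the existing image and audio paths. A standalone sketch of that expansion using the Qwen-style token registered above (the helper below is illustrative, not sglang code):

```python
# Sketch of the video_url expansion in generate_chat_conv (illustrative mimic).
video_token = "<|vision_start|><|video_pad|><|vision_end|>"
video_data = []

def render_part(part: dict) -> str:
    """Emit the placeholder token and stash the URL for the processor to fetch."""
    if part["type"] == "video_url":
        video_data.append(part["video_url"]["url"])
        return video_token
    return part.get("text", "")

parts = [
    {"type": "video_url", "video_url": {"url": "file:///tmp/clip.mp4"}},
    {"type": "text", "text": " What is shown here?"},
]
prompt = "".join(render_part(p) for p in parts)
assert prompt.startswith(video_token)
assert video_data == ["file:///tmp/clip.mp4"]
```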
python/sglang/srt/entrypoints/openai/protocol.py (+11, -0)

```diff
@@ -267,6 +267,10 @@ class ChatCompletionMessageContentImageURL(BaseModel):
     detail: Optional[Literal["auto", "low", "high"]] = "auto"


+class ChatCompletionMessageContentVideoURL(BaseModel):
+    url: str
+
+
 class ChatCompletionMessageContentAudioURL(BaseModel):
     url: str
@@ -277,6 +281,11 @@ class ChatCompletionMessageContentImagePart(BaseModel):
     modalities: Optional[Literal["image", "multi-images", "video"]] = "image"


+class ChatCompletionMessageContentVideoPart(BaseModel):
+    type: Literal["video_url"]
+    video_url: ChatCompletionMessageContentVideoURL
+
+
 class ChatCompletionMessageContentAudioPart(BaseModel):
     type: Literal["audio_url"]
     audio_url: ChatCompletionMessageContentAudioURL
@@ -285,6 +294,7 @@ class ChatCompletionMessageContentAudioPart(BaseModel):
 ChatCompletionMessageContentPart = Union[
     ChatCompletionMessageContentTextPart,
     ChatCompletionMessageContentImagePart,
+    ChatCompletionMessageContentVideoPart,
     ChatCompletionMessageContentAudioPart,
 ]
@@ -629,6 +639,7 @@ class MessageProcessingResult:
     prompt_ids: Union[str, List[int]]
     image_data: Optional[Any]
     audio_data: Optional[Any]
+    video_data: Optional[Any]
     modalities: List[str]
     stop: List[str]
     tool_call_constraint: Optional[Any] = None
```
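The new pydantic parts validate incoming video content. A quick check of how a raw chunk parses under the added models (a self-contained sketch; pydantic v2 and the example URL are assumptions):

```python
# Sketch: validate a video content part with models shaped like those added above.
from typing import Literal
from pydantic import BaseModel

class ChatCompletionMessageContentVideoURL(BaseModel):
    url: str

class ChatCompletionMessageContentVideoPart(BaseModel):
    type: Literal["video_url"]
    video_url: ChatCompletionMessageContentVideoURL

part = ChatCompletionMessageContentVideoPart.model_validate(
    {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}}
)
assert part.video_url.url.endswith(".mp4")
```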
python/sglang/srt/entrypoints/openai/serving_chat.py (+7, -0)

```diff
@@ -82,6 +82,7 @@ class OpenAIServingChat(OpenAIServingBase):
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
+            video_data=processed_messages.video_data,
             audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
@@ -143,6 +144,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt_ids = []
         openai_compatible_messages = []
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
@@ -158,6 +160,7 @@ class OpenAIServingChat(OpenAIServingBase):
                     msg_dict,
                     template_content_format,
                     image_data,
+                    video_data,
                     audio_data,
                     modalities,
                 )
@@ -214,11 +217,13 @@ class OpenAIServingChat(OpenAIServingBase):
             stop = request.stop
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
+        video_data = video_data if video_data else None
         modalities = modalities if modalities else []
         return MessageProcessingResult(
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
@@ -260,6 +265,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt = conv.get_prompt()
         image_data = conv.image_data if conv.image_data else None
+        video_data = conv.video_data if conv.video_data else None
         audio_data = conv.audio_data if conv.audio_data else None
         modalities = conv.modalities if conv.modalities else []
         stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])
@@ -277,6 +283,7 @@ class OpenAIServingChat(OpenAIServingBase):
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
```
python/sglang/srt/jinja_template_utils.py (+8, -0)

```diff
@@ -110,6 +110,7 @@ def process_content_for_template_format(
     msg_dict: dict,
     content_format: str,
     image_data: list,
+    video_data: list,
     audio_data: list,
     modalities: list,
 ) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
         msg_dict: Message dictionary with content
         content_format: 'string' or 'openai' (detected via AST analysis)
         image_data: List to append extracted image URLs
+        video_data: List to append extracted video URLs
         audio_data: List to append extracted audio URLs
         modalities: List to append modalities
@@ -143,6 +145,12 @@ def process_content_for_template_format(
                     modalities.append(chunk.get("modalities"))
                 # Normalize to simple 'image' type for template compatibility
                 processed_content_parts.append({"type": "image"})
+            elif chunk_type == "video_url":
+                video_data.append(chunk["video_url"]["url"])
+                if chunk.get("modalities"):
+                    modalities.append(chunk.get("modalities"))
+                # Normalize to simple 'video' type for template compatibility
+                processed_content_parts.append({"type": "video"})
             elif chunk_type == "audio_url":
                 audio_data.append(chunk["audio_url"]["url"])
                 # Normalize to simple 'audio' type
```
python/sglang/srt/managers/io_struct.py (+25, -2)

```diff
@@ -65,6 +65,8 @@ class GenerateReqInput:
     ] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[List[str]], List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
     data_parallel_rank: Optional[int] = None

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def normalize_batch_and_arguments(self):
         """
@@ -232,6 +238,7 @@ class GenerateReqInput:
         self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
+        self._normalize_video_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
                 self.image_data = wrapped_images * self.parallel_sample_num
                 self.modalities = ["image"] * num

+    def _normalize_video_data(self, num):
+        """Normalize video data for batch processing."""
+        if self.video_data is None:
+            self.video_data = [None] * num
+        elif not isinstance(self.video_data, list):
+            self.video_data = [self.video_data] * num
+        elif isinstance(self.video_data, list):
+            self.video_data = self.video_data * self.parallel_sample_num
+
     def _normalize_audio_data(self, num):
         """Normalize audio data for batch processing."""
         if self.audio_data is None:
@@ -408,6 +424,7 @@ class GenerateReqInput:
                 self.input_embeds[i] if self.input_embeds is not None else None
             ),
             image_data=self.image_data[i],
+            video_data=self.video_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[str], str]] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
         return self.rid

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def __getitem__(self, i):
         if self.is_cross_encoder_request:
```
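`_normalize_video_data` mirrors the existing image/audio normalizers: a missing value fans out to a list of `None`s, a single value is repeated for the whole batch, and an existing list is repeated per parallel sample. A standalone sketch of the same rule (function name and defaults here are illustrative):

```python
# Standalone sketch of the _normalize_video_data rule added above.
def normalize_video_data(video_data, num, parallel_sample_num=1):
    if video_data is None:
        return [None] * num                    # no videos for any request
    if not isinstance(video_data, list):
        return [video_data] * num              # one video shared across the batch
    return video_data * parallel_sample_num    # already a list: repeat per sample

assert normalize_video_data(None, 3) == [None, None, None]
assert normalize_video_data("a.mp4", 2) == ["a.mp4", "a.mp4"]
assert normalize_video_data(["a.mp4", "b.mp4"], 2) == ["a.mp4", "b.mp4"]
```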
python/sglang/srt/managers/mm_utils.py (+55, -94)

```diff
@@ -4,7 +4,7 @@ Multi-modality utils

 import hashlib
 from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
+        print(f"{mm_inputs.mm_items=}")
         data_token_pairs = self.data_token_id_pairs
         mm_inputs.data_offsets = []
         if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         return ret_input_ids


-embedding_cache = None
+embedding_cache: Optional[MultiModalCache] = None


-def init_embedding_cache(max_size: int):
+def init_embedding_cache(max_size: int = 0):
     global embedding_cache
     embedding_cache = MultiModalCache(max_size)
@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
+        assert items_offset is not None, items_offset
         embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
     extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
+    multimodal_model: nn.Module = None,
+    data_embedding_func_mapping: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
+    ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
@@ -397,8 +397,6 @@ def embed_mm_inputs(
         extend_seq_lens: Sequence lengths for each request
         input_ids: Input token IDs tensor
         input_embedding: Embedding layer for text tokens
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
         placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

     Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

     embeddings, masks = [], []
     # 2. Get multimodal embedding separately
-    # TODO: make this more generic
-    # Try get image embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_image())
-        and image_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_image()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        items_offsets = []
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
-            items_size[i + 1] = len(image_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.image_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_image()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0).tolist()
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=image_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
-
-    # Try get audio embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_audio())
-        and audio_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_audio()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        items_offsets = []
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
-            items_size[i + 1] = len(audio_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.audio_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_audio()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0)
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=audio_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
+    # Try get mm embedding if any
+    for modality in Modality.all():
+        items = [
+            item for item in item_flatten_list if item.is_modality(modality=modality)
+        ]
+        embedder = (
+            None
+            if data_embedding_func_mapping is None
+            else data_embedding_func_mapping.get(modality, None)
+        )
+        if embedder is None:
+            # "image", "video", etc
+            modality_id = modality.name.lower()
+            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+        if len(items) != 0 and embedder is not None:
+            placeholder_tensor = torch.tensor(
+                [item.pad_value for item in items],
+                device=input_ids.device,
+            )
+            # calculate per request items length offset
+            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+            items_offsets = []
+            for i, mm_inputs in enumerate(mm_inputs_list):
+                mm_items = [
+                    item
+                    for item in mm_inputs.mm_items
+                    if item.is_modality(modality=modality)
+                ]
+                items_size[i + 1] = len(mm_items)
+                items_offsets.append(
+                    flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
+                )
+            items_size = torch.cumsum(items_size, dim=0).tolist()
+
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=embedder,
+                embedding_items=items,
+                placeholder_tensor=placeholder_tensor,
+                input_ids=input_ids,
+                items_size=items_size,
+                prefix_length=extend_prefix_lens,
+                extend_length=extend_seq_lens,
+                items_offset_list=items_offsets,
+            )
+            embeddings += [embedding]
+            masks += [mask]

     # 3. Get input embeddings
     vocab_size = input_embedding.num_embeddings
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
-    audio_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
+    ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             input_embedding=embed_tokens,
-            image_data_embedding_func=image_data_embedding_func,
-            audio_data_embedding_func=audio_data_embedding_func,
+            multimodal_model=multimodal_model,
+            data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
         )
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
```
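The refactor in `embed_mm_inputs` collapses the duplicated image/audio branches into one loop over `Modality.all()`. For each modality it resolves an embedder in two steps: an explicit entry in `data_embedding_func_mapping`, else a `get_<modality>_feature` attribute on `multimodal_model`. A condensed, self-contained sketch of just that resolution logic (the classes below are stand-ins, not sglang code):

```python
# Condensed sketch of the embedder-resolution order in the new embed_mm_inputs loop.
from enum import Enum, auto

class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()

def resolve_embedder(modality, mapping=None, multimodal_model=None):
    # 1) an explicit per-modality mapping wins (the gemma3n style below)
    embedder = None if mapping is None else mapping.get(modality)
    if embedder is None and multimodal_model is not None:
        # 2) fall back to a naming convention on the model:
        #    get_image_feature / get_video_feature / get_audio_feature
        embedder = getattr(
            multimodal_model, f"get_{modality.name.lower()}_feature", None
        )
    return embedder

class FakeModel:
    def get_video_feature(self, items):
        return "video-embedding"

assert resolve_embedder(Modality.VIDEO, None, FakeModel())(None) == "video-embedding"
assert resolve_embedder(Modality.AUDIO, None, FakeModel()) is None
```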
python/sglang/srt/managers/schedule_batch.py (+16, -5)

```diff
@@ -185,6 +185,10 @@ class Modality(Enum):
             f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
         )

+    @staticmethod
+    def all():
+        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
+

 @dataclasses.dataclass
 class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     image_sizes: Tuple[int, int] = None
-    image_offsets: Optional[list] = None
+    offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
                 self.hash = hash_feature(self.audio_features)
             elif self.input_features is not None:
                 self.hash = hash_feature(self.input_features)
+            elif self.is_video():
+                self.hash = hash_feature(self.pixel_values_videos)
             else:
                 self.hash = hash_feature(self.pixel_values)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)

+    def is_modality(self, modality: Modality) -> bool:
+        return self.modality == modality
+
     def is_audio(self):
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:
     def is_image(self):
         return (
-            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
         ) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:
     def is_video(self):
         return (self.modality == Modality.VIDEO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
         )

     def is_valid(self) -> bool:
@@ -351,6 +360,7 @@ class MultimodalInputs:
             "im_token_id",
             "im_start_id",
             "im_end_id",
+            "video_token_id",
             "slice_start_id",
             "slice_end_id",
             "audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
         return ret

     def contains_image_inputs(self) -> bool:
-        """ """
         return any(item.is_image() for item in self.mm_items)

+    def contains_video_inputs(self) -> bool:
+        return any(item.is_video() for item in self.mm_items)
+
     def contains_audio_inputs(self) -> bool:
-        """ """
         return any(item.is_audio() for item in self.mm_items)

     def contains_mm_input(self) -> bool:
```
python/sglang/srt/model_executor/forward_batch_info.py (+13, -1)

```diff
@@ -453,8 +453,20 @@ class ForwardBatch:
             for mm_input in self.mm_inputs
         )

+    def contains_video_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_video_inputs()
+            for mm_input in self.mm_inputs
+        )
+
     def contains_mm_inputs(self) -> bool:
-        return self.contains_audio_inputs() or self.contains_image_inputs()
+        return (
+            self.contains_audio_inputs()
+            or self.contains_video_inputs()
+            or self.contains_image_inputs()
+        )

     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
```
python/sglang/srt/models/deepseek_janus_pro.py (+1, -1)

```diff
@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.language_model,
             positions=positions,
         )
```
python/sglang/srt/models/deepseek_vl2.py (+1, -1)

```diff
@@ -227,7 +227,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.language_model,
         )
```
python/sglang/srt/models/gemma3_mm.py (+1, -1)

```diff
@@ -374,7 +374,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
```
python/sglang/srt/models/gemma3n_mm.py (+6, -3)

```diff
 import logging
 import re
 from functools import lru_cache
-from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
+from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union

 import torch
 from torch import nn
@@ -25,6 +25,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,
@@ -434,8 +435,10 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
             per_layer_inputs=per_layer_inputs,
         )
```
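gemma3n shows the explicit-mapping style (`data_embedding_funcs={...}`), while minicpmv/minicpmo below switch to `multimodal_model=self` and rely on the `get_<modality>_feature` naming convention. Under the resolution order in the mm_utils hunk above, an explicit mapping entry takes precedence over the attribute fallback; a tiny runnable sketch (names are illustrative):

```python
# Sketch: an explicit data_embedding_funcs entry overrides the model's
# get_<modality>_feature attribute (stand-in names, not sglang code).
class Model:
    def get_image_feature(self, items):
        return "attribute fallback"

def custom_image_embedder(items):
    return "explicit mapping"

model = Model()
mapping = {"image": custom_image_embedder}
embedder = mapping.get("image") or getattr(model, "get_image_feature", None)
assert embedder(None) == "explicit mapping"
```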
python/sglang/srt/models/internvl.py (+8, -2)

```diff
@@ -29,7 +29,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
@@ -523,7 +527,9 @@ class InternVLChatModel(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
```
python/sglang/srt/models/kimi_vl.py (+8, -2)

```diff
@@ -67,7 +67,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -168,7 +172,9 @@ class KimiVLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
```
python/sglang/srt/models/llava.py (+3, -1)

```diff
@@ -787,7 +787,9 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
             forward_batch=forward_batch,
             get_embedding=get_embedding,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             placeholder_tokens=None,  # using mm_item.pad_value
             positions=positions,
         )
```
python/sglang/srt/models/llavavid.py (+1, -1)

```diff
@@ -142,7 +142,7 @@ class LlavaVidForCausalLM(nn.Module):
             )
             image_offsets = [
                 flatten_nested_list(
-                    [item.image_offsets for item in image_inputs[i].mm_items]
+                    [item.offsets for item in image_inputs[i].mm_items]
                 )
                 for i in range(bs)
                 if need_vision[i]
```
python/sglang/srt/models/minicpmo.py (+1, -2)

```diff
@@ -1827,8 +1827,7 @@ class MiniCPMO(MiniCPMBaseModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            multimodal_model=self,
             positions=positions,
         )
         return hidden_states
```
python/sglang/srt/models/minicpmv.py (+1, -1)

```diff
@@ -573,7 +573,7 @@ class MiniCPMBaseModel(nn.Module):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.llm,
             positions=positions,
         )
```
python/sglang/srt/models/mllama4.py (+13, -4)

```diff
@@ -6,8 +6,11 @@ from typing import List, Optional, Set, Tuple

 import torch
 from torch import nn
-from transformers import Llama4Config, Llama4VisionModel
-from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
+from transformers import Llama4Config
+from transformers.models.llama4.modeling_llama4 import (
+    Llama4MultiModalProjector,
+    Llama4VisionModel,
+)

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -16,7 +19,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu
@@ -166,7 +173,9 @@ class Llama4ForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=image_embedding_func,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
```
python/sglang/srt/models/phi4mm.py (+8, -2)

```diff
@@ -31,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
@@ -439,7 +443,9 @@ class Phi4MMForCausalLM(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
```