Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8430bfe3
Unverified
Commit
8430bfe3
authored
Jul 20, 2025
by
Xinyuan Tong
Committed by
GitHub
Jul 20, 2025
Browse files
[Refactor] simplify multimodal data processing (#8107)
Signed-off-by:
Xinyuan Tong
<
justinning0323@outlook.com
>
parent
c9e8613c
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
126 additions
and
69 deletions
+126
-69
python/sglang/srt/multimodal/processors/phi4mm.py
python/sglang/srt/multimodal/processors/phi4mm.py
+6
-7
python/sglang/srt/multimodal/processors/pixtral.py
python/sglang/srt/multimodal/processors/pixtral.py
+13
-34
python/sglang/srt/multimodal/processors/qwen_audio.py
python/sglang/srt/multimodal/processors/qwen_audio.py
+65
-0
python/sglang/srt/multimodal/processors/qwen_vl.py
python/sglang/srt/multimodal/processors/qwen_vl.py
+0
-2
python/sglang/srt/multimodal/processors/vila.py
python/sglang/srt/multimodal/processors/vila.py
+0
-2
test/srt/test_vision_openai_server_a.py
test/srt/test_vision_openai_server_a.py
+17
-16
test/srt/test_vision_openai_server_b.py
test/srt/test_vision_openai_server_b.py
+1
-0
test/srt/test_vision_openai_server_common.py
test/srt/test_vision_openai_server_common.py
+19
-3
test/srt/test_vlm_input_format.py
test/srt/test_vlm_input_format.py
+5
-5
No files found.
python/sglang/srt/multimodal/processors/phi4mm.py
View file @
8430bfe3
...
...
@@ -31,6 +31,7 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
for
hf_key
,
sglang_key
in
key_mapping
.
items
():
if
hf_key
in
result
:
result
[
sglang_key
]
=
result
[
hf_key
]
del
result
[
hf_key
]
# Filter out None or empty tensors from the result.
# This prevents the sglang function base_processor.collect_mm_items_from_processor_output()
...
...
@@ -58,7 +59,7 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
self
.
AUDIO_TOKEN_ID
=
200011
self
.
AUDIO_SAMPLE_RATE
=
16000
self
.
m
ultimodal
_tokens
=
MultimodalSpecialTokens
(
self
.
m
m
_tokens
=
MultimodalSpecialTokens
(
image_token
=
self
.
IMAGE_TOKEN
,
image_token_id
=
self
.
IM_TOKEN_ID
,
audio_token
=
self
.
AUDIO_TOKEN
,
...
...
@@ -71,15 +72,13 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
audio_data
,
input_text
,
request_obj
,
max_req_input_len
,
**
kwargs
,
):
base_output
=
self
.
load_mm_data
(
prompt
=
input_text
,
max_req_input_len
=
max_req_input_len
,
audio_data
=
audio_data
,
image_data
=
image_data
,
multimodal_tokens
=
self
.
m
ultimodal
_tokens
,
multimodal_tokens
=
self
.
m
m
_tokens
,
audio_sample_rate
=
self
.
AUDIO_SAMPLE_RATE
,
)
...
...
@@ -91,12 +90,12 @@ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
]
mm_items
,
input_ids
,
_
=
self
.
process_and_combine_mm_data
(
base_output
,
self
.
m
ultimodal
_tokens
base_output
,
self
.
m
m
_tokens
)
return
{
"input_ids"
:
input_ids
.
tolist
(),
"mm_items"
:
mm_items
,
"im_token_id"
:
self
.
IM_TOKEN_ID
,
"audio_token_id"
:
self
.
AUDIO_TOKEN_ID
,
"im_token_id"
:
self
.
mm_tokens
.
image_token_id
,
"audio_token_id"
:
self
.
mm_tokens
.
audio_token_id
,
}
python/sglang/srt/multimodal/processors/pixtral.py
View file @
8430bfe3
...
...
@@ -6,7 +6,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens
as
_get_pixtral_hf_num_image_tokens
,
)
from
sglang.srt.managers.schedule_batch
import
Modality
,
MultimodalDataItem
from
sglang.srt.models.pixtral
import
PixtralVisionModel
from
sglang.srt.multimodal.processors.base_processor
import
(
BaseMultimodalProcessor
,
...
...
@@ -45,7 +44,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
def
__init__
(
self
,
hf_config
,
server_args
,
_processor
):
super
().
__init__
(
hf_config
,
server_args
,
_processor
)
self
.
image_token_id
=
getattr
(
self
.
IM_TOKEN_ID
=
getattr
(
hf_config
,
"image_token_index"
,
PixtralVisionModel
.
DEFAULT_IMAGE_TOKEN_ID
)
# Instantiate the patcher logic helper using the class defined above
...
...
@@ -53,8 +52,9 @@ class PixtralProcessor(BaseMultimodalProcessor):
self
.
vision_config
=
hf_config
.
vision_config
self
.
image_size
=
self
.
vision_config
.
image_size
self
.
patch_size
=
self
.
vision_config
.
patch_size
self
.
multimodal_tokens
=
MultimodalSpecialTokens
(
image_token
=
_processor
.
image_token
self
.
mm_tokens
=
MultimodalSpecialTokens
(
image_token
=
_processor
.
image_token
,
image_token_id
=
self
.
IM_TOKEN_ID
,
).
build
(
_processor
)
_processor
.
tokenizer
.
add_special_tokens
(
{
...
...
@@ -80,42 +80,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
):
mm_data
=
self
.
load_mm_data
(
prompt
=
input_text
,
multimodal_tokens
=
self
.
multimodal_tokens
,
max_req_input_len
=
kwargs
.
get
(
"max_req_input_len"
,
4096
),
multimodal_tokens
=
self
.
mm_tokens
,
image_data
=
image_data
,
return_text
=
True
,
)
if
mm_data
.
images
:
resize_tasks
=
[
self
.
_resize
(
image
)
for
image
in
mm_data
.
images
]
mm_data
.
images
=
await
asyncio
.
gather
(
*
resize_tasks
)
processor_output
=
self
.
process_mm_data
(
input_text
=
mm_data
.
input_text
,
images
=
mm_data
.
images
,
mm_items
,
input_ids
,
_
=
self
.
process_and_combine_mm_data
(
mm_data
,
self
.
mm_tokens
)
if
"pixel_values"
in
processor_output
:
input_ids
=
processor_output
[
"input_ids"
].
view
(
-
1
)
image_offsets
=
self
.
get_mm_items_offset
(
input_ids
=
input_ids
,
mm_token_id
=
self
.
image_token_id
,
)
mm_items
=
[
MultimodalDataItem
(
feature
=
processor_output
[
"pixel_values"
],
image_sizes
=
processor_output
[
"image_sizes"
],
modality
=
Modality
.
IMAGE
,
offsets
=
image_offsets
,
)
]
input_ids
=
input_ids
.
tolist
()
processor_output
.
update
(
input_ids
=
input_ids
,
mm_items
=
mm_items
,
# there's no im_start_id for pixtral, only im_token and im_end_token
im_end_id
=
self
.
IMG_END_TOKEN_ID
,
im_token_id
=
self
.
image_token_id
,
)
return
processor_output
return
{
"mm_items"
:
mm_items
,
"input_ids"
:
input_ids
.
tolist
(),
"im_token_id"
:
self
.
IM_TOKEN_ID
,
"im_token"
:
self
.
_processor
.
image_token
,
}
python/sglang/srt/
managers/
multimodal
_
processors/qwen_audio.py
→
python/sglang/srt/multimodal
/
processors/qwen_audio.py
View file @
8430bfe3
import
re
from
typing
import
List
,
Union
import
torch
from
sglang.srt.
managers.
multimodal
_
processors.base_processor
import
(
from
sglang.srt.managers.schedule_batch
import
Modality
,
MultimodalDataItem
from
sglang.srt.models.qwen2_audio
import
Qwen2AudioForConditionalGeneration
from
sglang.srt.multimodal
.
processors.base_processor
import
(
BaseMultimodalProcessor
,
MultimodalSpecialTokens
,
)
from
sglang.srt.managers.schedule_batch
import
Modality
,
MultimodalDataItem
from
sglang.srt.models.qwen2_audio
import
Qwen2AudioForConditionalGeneration
class
Qwen2AudioMultimodalProcessor
(
BaseMultimodalProcessor
):
...
...
@@ -20,75 +17,49 @@ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
self
.
AUDIO_TOKEN_REGEX
=
re
.
compile
(
r
"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
)
# Collect special token ids
tokenizer
=
self
.
_processor
.
tokenizer
self
.
audio_start_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|audio_bos|>"
)
self
.
audio_token_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|AUDIO|>"
)
self
.
audio_end_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|audio_eos|>"
)
self
.
mm_tokens
=
MultimodalSpecialTokens
(
audio_token
=
self
.
AUDIO_TOKEN
,
audio_token_regex
=
self
.
AUDIO_TOKEN_REGEX
,
audio_token_id
=
self
.
audio_token_id
,
).
build
(
_processor
)
async
def
process_mm_data_async
(
self
,
image_data
:
List
[
Union
[
str
,
bytes
]]
,
audio_data
,
input_text
,
request_obj
,
max_req_input_len
,
**
kwargs
,
):
audio_data
=
request_obj
.
audio_data
if
not
isinstance
(
audio_data
,
list
):
audio_data
=
[
audio_data
]
base_output
=
self
.
load_mm_data
(
prompt
=
input_text
,
max_req_input_len
=
max_req_input_len
,
audio_data
=
audio_data
,
multimodal_tokens
=
MultimodalSpecialTokens
(
audio_token
=
self
.
AUDIO_TOKEN
,
audio_token_regex
=
self
.
AUDIO_TOKEN_REGEX
,
),
multimodal_tokens
=
self
.
mm_tokens
,
)
if
base_output
is
None
:
return
None
res
=
self
.
process_mm_data
(
input_text
=
base_output
.
input_text
,
audio
=
base_output
.
audios
,
mm_items
,
input_ids
,
ret
=
self
.
process_and_combine_mm_data
(
base_output
,
self
.
mm_tokens
)
# Collect special token ids
tokenizer
=
self
.
_processor
.
tokenizer
audio_start_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|audio_bos|>"
)
audio_token_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|AUDIO|>"
)
audio_end_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|audio_eos|>"
)
items
=
[]
input_ids
=
res
[
"input_ids"
].
flatten
()
if
(
"input_features"
in
res
and
res
[
"input_features"
]
is
not
None
and
len
(
res
[
"input_features"
])
!=
0
):
if
audio_start_id
is
not
None
and
audio_end_id
is
not
None
:
audio_offsets
=
self
.
get_mm_items_offset_by_pair
(
input_ids
=
input_ids
,
mm_start_id
=
audio_start_id
,
mm_end_id
=
audio_end_id
,
)
else
:
audio_offsets
=
None
input_lengths
=
res
[
"feature_attention_mask"
].
sum
(
dim
=-
1
)
input_lengths
=
(
input_lengths
-
1
)
//
2
+
1
output_lengths
=
(
input_lengths
-
2
)
//
2
+
1
assert
(
"feature_attention_mask"
in
ret
),
"feature_attention_mask not found in processor output"
input_lengths
=
ret
[
"feature_attention_mask"
].
sum
(
dim
=-
1
)
input_lengths
=
(
input_lengths
-
1
)
//
2
+
1
output_lengths
=
(
input_lengths
-
2
)
//
2
+
1
item
=
MultimodalDataItem
(
feature
=
res
[
"input_features"
],
audio_feature_lens
=
output_lengths
,
audio_offsets
=
audio_offsets
,
modality
=
Modality
.
AUDIO
,
)
items
+=
[
item
]
mm_items
[
0
].
model_specific_data
[
"audio_feature_lens"
]
=
output_lengths
return
{
"mm_items"
:
items
,
"mm_items"
:
mm_
items
,
"input_ids"
:
input_ids
.
tolist
(),
"audio_start_id"
:
audio_start_id
,
"audio_token_id"
:
audio_token_id
,
"audio_end_id"
:
audio_end_id
,
"audio_start_id"
:
self
.
audio_start_id
,
"audio_token_id"
:
self
.
audio_token_id
,
"audio_end_id"
:
self
.
audio_end_id
,
}
python/sglang/srt/multimodal/processors/qwen_vl.py
View file @
8430bfe3
...
...
@@ -227,7 +227,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_data
:
List
[
Union
[
str
,
bytes
]],
input_text
,
request_obj
,
max_req_input_len
,
*
args
,
**
kwargs
,
):
...
...
@@ -237,7 +236,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_data
=
image_data
,
video_data
=
request_obj
.
video_data
,
multimodal_tokens
=
self
.
mm_tokens
,
max_req_input_len
=
max_req_input_len
,
)
# Qwen-specific: resize images if they are raw Image objects
...
...
python/sglang/srt/multimodal/processors/vila.py
View file @
8430bfe3
...
...
@@ -47,13 +47,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
image_data
:
Optional
[
ImageDataInputItem
|
List
[
ImageDataInputItem
]],
input_text
:
str
|
List
[
int
],
request_obj
:
GenerateReqInput
|
EmbeddingReqInput
,
max_req_input_len
:
int
,
**
kwargs
,
)
->
Optional
[
Dict
[
str
,
Any
]]:
base_output
=
self
.
load_mm_data
(
prompt
=
input_text
,
multimodal_tokens
=
self
.
mm_tokens
,
max_req_input_len
=
max_req_input_len
,
image_data
=
image_data
,
)
...
...
test/srt/test_vision_openai_server_a.py
View file @
8430bfe3
...
...
@@ -116,22 +116,23 @@ class TestVLMContextLengthIssue(CustomTestCase):
)
class
TestMllamaServer
(
TestOpenAIVisionServer
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
"meta-llama/Llama-3.2-11B-Vision-Instruct"
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
api_key
=
"sk-123456"
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
)
cls
.
base_url
+=
"/v1"
def
test_video_chat_completion
(
self
):
pass
# Note(Xinyuan): mllama is not stable for now, skip for CI
# class TestMllamaServer(TestOpenAIVisionServer):
# @classmethod
# def setUpClass(cls):
# cls.model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# cls.base_url = DEFAULT_URL_FOR_TEST
# cls.api_key = "sk-123456"
# cls.process = popen_launch_server(
# cls.model,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# api_key=cls.api_key,
# )
# cls.base_url += "/v1"
# def test_video_chat_completion(self):
# pass
class
TestMinicpmvServer
(
TestOpenAIVisionServer
):
...
...
test/srt/test_vision_openai_server_b.py
View file @
8430bfe3
...
...
@@ -67,6 +67,7 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
"--trust-remote-code"
,
"--context-length"
,
"4096"
,
"--disable-cuda-graph"
,
],
)
cls
.
base_url
+=
"/v1"
...
...
test/srt/test_vision_openai_server_common.py
View file @
8430bfe3
...
...
@@ -308,19 +308,35 @@ class TestOpenAIVisionServer(CustomTestCase):
"iPod"
in
video_response
or
"device"
in
video_response
or
"microphone"
in
video_response
),
video_response
),
f
"""
====================== video_response =====================
{
video_response
}
===========================================================
should contain 'iPod' or 'device' or 'microphone'
"""
assert
(
"man"
in
video_response
or
"person"
in
video_response
or
"individual"
in
video_response
or
"speaker"
in
video_response
),
video_response
or
"Steve"
in
video_response
),
f
"""
====================== video_response =====================
{
video_response
}
===========================================================
should contain 'man' or 'person' or 'individual' or 'speaker'
"""
assert
(
"present"
in
video_response
or
"examine"
in
video_response
or
"display"
in
video_response
or
"hold"
in
video_response
)
),
f
"""
====================== video_response =====================
{
video_response
}
===========================================================
should contain 'present' or 'examine' or 'display' or 'hold'
"""
assert
"black"
in
video_response
or
"dark"
in
video_response
self
.
assertIsNotNone
(
video_response
)
self
.
assertGreater
(
len
(
video_response
),
0
)
...
...
test/srt/test_vlm_input_format.py
View file @
8430bfe3
...
...
@@ -104,15 +104,15 @@ class VLMInputTestBase:
)
self
.
verify_response
(
output
)
async
def
test_understands_precomputed_
feature
s
(
self
):
async
def
test_understands_precomputed_
embedding
s
(
self
):
req
=
self
.
get_completion_request
()
processor_output
=
self
.
get_processor_output
(
req
=
req
)
with
torch
.
inference_mode
():
precomputed_
feature
s
=
self
.
__class__
.
visual
(
processor_output
)
precomputed_
embedding
s
=
self
.
__class__
.
visual
(
processor_output
)
output
=
await
self
.
engine
.
async_generate
(
input_ids
=
processor_output
[
"input_ids"
][
0
].
detach
().
cpu
().
tolist
(),
image_data
=
[
self
.
_precomputed_image_data
(
processor_output
,
precomputed_
feature
s
)
self
.
_precomputed_image_data
(
processor_output
,
precomputed_
embedding
s
)
],
sampling_params
=
dict
(
temperature
=
0.0
),
)
...
...
@@ -128,11 +128,11 @@ class VLMInputTestBase:
)
self
.
verify_response
(
output
)
def
_precomputed_image_data
(
self
,
processor_output
,
precomputed_
feature
s
):
def
_precomputed_image_data
(
self
,
processor_output
,
precomputed_
embedding
s
):
"""This should not be overridden."""
return
dict
(
modality
=
"IMAGE"
,
precomputed_
feature
s
=
precomputed_
feature
s
,
precomputed_
embedding
s
=
precomputed_
embedding
s
,
)
def
_pixel_values_image_data
(
self
,
processor_output
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment