Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
0d503090
"src/vscode:/vscode.git/clone" did not exist on "64abf1b19d8687601bd3be343066f9bb76223704"
Unverified
Commit
0d503090
authored
May 26, 2025
by
Lifu Huang
Committed by
GitHub
May 26, 2025
Browse files
Supported precomputed feature for Kimi VL (#6599)
parent
501efc3d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
93 additions
and
47 deletions
+93
-47
python/sglang/srt/managers/multimodal_processors/base_processor.py
...lang/srt/managers/multimodal_processors/base_processor.py
+15
-1
python/sglang/srt/managers/multimodal_processors/kimi_vl.py
python/sglang/srt/managers/multimodal_processors/kimi_vl.py
+41
-22
python/sglang/srt/managers/multimodal_processors/minicpm.py
python/sglang/srt/managers/multimodal_processors/minicpm.py
+2
-1
python/sglang/srt/managers/multimodal_processors/qwen_vl.py
python/sglang/srt/managers/multimodal_processors/qwen_vl.py
+6
-23
test/srt/test_vlm_input_format.py
test/srt/test_vlm_input_format.py
+29
-0
No files found.
python/sglang/srt/managers/multimodal_processors/base_processor.py
View file @
0d503090
...
@@ -5,7 +5,7 @@ import multiprocessing as mp
...
@@ -5,7 +5,7 @@ import multiprocessing as mp
import
os
import
os
import
re
import
re
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
...
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
)
)
return
ret
return
ret
@staticmethod
def _extract_processor_features(
    items: List[Any], attr_name: str
) -> Optional[torch.Tensor]:
    """
    Helper function to concat extracted attributes from processor output.

    Args:
        items: Objects that each expose an attribute called ``attr_name``.
        attr_name: Name of the attribute to read from every item.

    Returns:
        The ``torch.concat`` of every non-``None`` attribute value, in the
        order the items were given, or ``None`` when no item provides a
        value (including when ``items`` is empty).

    Raises:
        AttributeError: If an item has no attribute named ``attr_name``;
            the lookup deliberately has no default so a malformed item is
            not silently skipped.
    """
    # Read each attribute exactly once: the value that passes the None
    # filter is the same object that gets concatenated.
    values = [
        value
        for value in (getattr(item, attr_name) for item in items)
        if value is not None
    ]
    return torch.concat(values) if values else None
python/sglang/srt/managers/multimodal_processors/kimi_vl.py
View file @
0d503090
from
typing
import
List
,
Union
import
re
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
from
sglang.srt.managers.multimodal_processors.base_processor
import
(
from
sglang.srt.managers.multimodal_processors.base_processor
import
(
BaseMultimodalProcessor
as
SGLangBaseProcessor
,
BaseMultimodalProcessor
as
SGLangBaseProcessor
,
...
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
...
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
def
__init__
(
self
,
hf_config
,
server_args
,
_processor
):
def
__init__
(
self
,
hf_config
,
server_args
,
_processor
):
super
().
__init__
(
hf_config
,
server_args
,
_processor
)
super
().
__init__
(
hf_config
,
server_args
,
_processor
)
self
.
IMAGE_TOKEN
=
"<|media_pad|>"
self
.
IMAGE_TOKEN
=
"<|media_pad|>"
self
.
IMAGE_TOKEN_REGEX
=
re
.
compile
(
r
"(?:<\|media_pad\|>)+"
)
self
.
im_token_id
=
_processor
.
tokenizer
.
convert_tokens_to_ids
(
self
.
IMAGE_TOKEN
)
self
.
im_token_id
=
_processor
.
tokenizer
.
convert_tokens_to_ids
(
self
.
IMAGE_TOKEN
)
self
.
im_start
=
"<|media_start|>"
self
.
im_start_id
=
_processor
.
tokenizer
.
convert_tokens_to_ids
(
self
.
im_start
)
self
.
im_end
=
"<|media_end|>"
self
.
im_end_id
=
_processor
.
tokenizer
.
convert_tokens_to_ids
(
self
.
im_end
)
self
.
im_content
=
"<|media_content|>"
self
.
im_content_id
=
_processor
.
tokenizer
.
convert_tokens_to_ids
(
self
.
im_content
)
async
def
process_mm_data_async
(
async
def
process_mm_data_async
(
self
,
self
,
image_data
:
List
[
Union
[
str
,
bytes
]],
image_data
:
List
[
Union
[
str
,
bytes
,
Dict
]],
input_text
,
input_text
,
request_obj
,
request_obj
,
max_req_input_len
,
max_req_input_len
,
...
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
...
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
base_output
=
self
.
load_mm_data
(
base_output
=
self
.
load_mm_data
(
prompt
=
input_text
,
prompt
=
input_text
,
image_data
=
image_data
,
image_data
=
image_data
,
multimodal_tokens
=
MultimodalSpecialTokens
(
image_token
=
self
.
IMAGE_TOKEN
),
multimodal_tokens
=
MultimodalSpecialTokens
(
image_token
=
self
.
IMAGE_TOKEN
,
image_token_regex
=
self
.
IMAGE_TOKEN_REGEX
),
max_req_input_len
=
max_req_input_len
,
max_req_input_len
=
max_req_input_len
,
)
)
ret
=
self
.
process_mm_data
(
input_text
=
base_output
.
input_text
,
images_are_preprocessed
=
self
.
mm_inputs_are_preprocessed
(
base_output
.
images
)
images
=
base_output
.
images
,
if
not
images_are_preprocessed
:
)
ret
=
self
.
process_mm_data
(
input_ids
=
ret
[
"input_ids"
].
flatten
()
input_text
=
base_output
.
input_text
,
images
=
base_output
.
images
,
)
input_ids
=
ret
[
"input_ids"
].
flatten
()
image_grid_thws
=
ret
[
"image_grid_hws"
]
pixel_values
=
ret
[
"pixel_values"
]
precomputed_features
=
None
else
:
input_ids
=
self
.
_processor
.
tokenizer
(
base_output
.
input_text
,
return_tensors
=
"pt"
,
add_special_tokens
=
True
,
).
input_ids
.
flatten
()
image_grid_thws
=
self
.
_extract_processor_features
(
base_output
.
images
,
"image_grid_thws"
)
precomputed_features
=
self
.
_extract_processor_features
(
base_output
.
images
,
"precomputed_features"
)
pixel_values
=
self
.
_extract_processor_features
(
base_output
.
images
,
"pixel_values"
)
image_offsets
=
self
.
get_mm_items_offset
(
image_offsets
=
self
.
get_mm_items_offset
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
mm_token_id
=
self
.
im_token_id
,
mm_token_id
=
self
.
im_token_id
,
)
)
return
{
return
{
"input_ids"
:
input_ids
.
tolist
(),
"input_ids"
:
input_ids
.
tolist
(),
"mm_items"
:
[
"mm_items"
:
[
MultimodalDataItem
(
MultimodalDataItem
(
pixel_values
=
ret
[
"pixel_values"
],
pixel_values
=
pixel_values
,
image_grid_thws
=
ret
[
"image_grid_hws"
],
image_grid_thws
=
image_grid_thws
,
precomputed_features
=
precomputed_features
,
modality
=
Modality
.
IMAGE
,
modality
=
Modality
.
IMAGE
,
image_offsets
=
image_offsets
,
image_offsets
=
image_offsets
,
)
)
],
],
"im_token_id"
:
self
.
im_token_id
,
"im_token_id"
:
self
.
im_token_id
,
"im_start_id"
:
self
.
im_start_id
,
"im_end_id"
:
self
.
im_end_id
,
"im_content_id"
:
self
.
im_content_id
,
}
}
python/sglang/srt/managers/multimodal_processors/minicpm.py
View file @
0d503090
...
@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
...
@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data
=
audio_data
,
audio_data
=
audio_data
,
image_data
=
image_data
,
image_data
=
image_data
,
multimodal_tokens
=
MultimodalSpecialTokens
(
multimodal_tokens
=
MultimodalSpecialTokens
(
image_token
=
self
.
image_token
,
audio_token
=
self
.
audio_token
image_token
=
self
.
image_token
,
audio_token
=
self
.
audio_token
,
),
),
)
)
if
base_output
is
None
:
if
base_output
is
None
:
...
...
python/sglang/srt/managers/multimodal_processors/qwen_vl.py
View file @
0d503090
...
@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
...
@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if
base_output
.
images
:
if
base_output
.
images
:
if
images_are_preprocessed
:
if
images_are_preprocessed
:
all_image_grid_thws
=
[
image_grid_thw
=
self
.
_extract_processor_features
(
item
.
image_grid_thws
base_output
.
images
,
"image_grid_thws"
for
item
in
base_output
.
images
if
item
.
image_grid_thws
is
not
None
]
all_pixel_values
=
[
item
.
pixel_values
for
item
in
base_output
.
images
if
item
.
pixel_values
is
not
None
]
all_precomputed_features
=
[
item
.
precomputed_features
for
item
in
base_output
.
images
if
item
.
precomputed_features
is
not
None
]
image_grid_thw
=
(
torch
.
concat
(
all_image_grid_thws
)
if
all_image_grid_thws
else
None
)
)
p
ixel_values
=
(
p
recomputed_features
=
self
.
_extract_processor_features
(
torch
.
concat
(
all_pixel_values
)
if
all_pixel_values
else
None
base_output
.
images
,
"precomputed_features"
)
)
precomputed_features
=
(
pixel_values
=
self
.
_extract_processor_features
(
torch
.
concat
(
all_precomputed_features
)
base_output
.
images
,
"pixel_values"
if
all_precomputed_features
else
None
)
)
else
:
else
:
image_grid_thw
=
ret
[
"image_grid_thw"
]
image_grid_thw
=
ret
[
"image_grid_thw"
]
...
...
test/srt/test_vlm_input_format.py
View file @
0d503090
...
@@ -7,6 +7,7 @@ import requests
...
@@ -7,6 +7,7 @@ import requests
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
(
from
transformers
import
(
AutoModel
,
AutoProcessor
,
AutoProcessor
,
Gemma3ForConditionalGeneration
,
Gemma3ForConditionalGeneration
,
Qwen2_5_VLForConditionalGeneration
,
Qwen2_5_VLForConditionalGeneration
,
...
@@ -51,6 +52,7 @@ class VLMInputTestBase:
...
@@ -51,6 +52,7 @@ class VLMInputTestBase:
mem_fraction_static
=
0.8
,
mem_fraction_static
=
0.8
,
enable_multimodal
=
True
,
enable_multimodal
=
True
,
disable_cuda_graph
=
True
,
disable_cuda_graph
=
True
,
trust_remote_code
=
True
,
)
)
def
tearDown
(
self
):
def
tearDown
(
self
):
...
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCa
...
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCa
)
)
class TestKimiVLImageUnderstandsImage(
    VLMInputTestBase, unittest.IsolatedAsyncioTestCase
):
    """Kimi-VL variant of the shared VLM input-format test case."""

    model_path = "moonshotai/Kimi-VL-A3B-Instruct"
    chat_template = "kimi-vl"

    @classmethod
    def _init_visual(cls):
        """Load the HF model and publish its vision stack on the class.

        Sets ``cls.vision_tower``, ``cls.mm_projector`` and ``cls.visual``;
        the latter feeds the processor output through the vision tower and
        then through the multimodal projector.
        """
        hf_model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True)
        cls.vision_tower = hf_model.vision_tower.eval().to(cls.device)
        cls.mm_projector = hf_model.multi_modal_projector.eval().to(cls.device)

        def _encode(tokenizer_output):
            # Both modules are looked up on ``cls`` at call time, not
            # captured here, so later reassignment is still honoured.
            tower_output = cls.vision_tower(
                pixel_values=tokenizer_output["pixel_values"],
                grid_hws=tokenizer_output["image_grid_hws"],
            )
            return cls.mm_projector(tower_output)

        cls.visual = _encode

    def _pixel_values_image_data(self, processor_output):
        """Build the image payload that carries raw pixel values.

        The processor's ``image_grid_hws`` entry is passed on under the
        ``image_grid_thws`` key.
        """
        return {
            "modality": "IMAGE",
            "image_grid_thws": processor_output["image_grid_hws"],
            "pixel_values": processor_output["pixel_values"],
        }
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment