Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1e86457c
Unverified
Commit
1e86457c
authored
Mar 25, 2025
by
Mick
Committed by
GitHub
Mar 24, 2025
Browse files
model: Minicpmo (#3023)
parent
64129fa6
Changes
38
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
2265 additions
and
125 deletions
+2265
-125
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+10
-9
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+38
-25
python/sglang/srt/models/deepseek_janus_pro.py
python/sglang/srt/models/deepseek_janus_pro.py
+4
-5
python/sglang/srt/models/deepseek_vl2.py
python/sglang/srt/models/deepseek_vl2.py
+4
-4
python/sglang/srt/models/gemma3_mm.py
python/sglang/srt/models/gemma3_mm.py
+6
-7
python/sglang/srt/models/llama.py
python/sglang/srt/models/llama.py
+1
-1
python/sglang/srt/models/llava.py
python/sglang/srt/models/llava.py
+3
-3
python/sglang/srt/models/llavavid.py
python/sglang/srt/models/llavavid.py
+3
-3
python/sglang/srt/models/minicpmo.py
python/sglang/srt/models/minicpmo.py
+1995
-0
python/sglang/srt/models/minicpmv.py
python/sglang/srt/models/minicpmv.py
+11
-23
python/sglang/srt/models/mllama.py
python/sglang/srt/models/mllama.py
+4
-4
python/sglang/srt/models/qwen2_5_vl.py
python/sglang/srt/models/qwen2_5_vl.py
+4
-5
python/sglang/srt/models/qwen2_vl.py
python/sglang/srt/models/qwen2_vl.py
+7
-8
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+9
-1
python/sglang/srt/openai_api/protocol.py
python/sglang/srt/openai_api/protocol.py
+12
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+31
-4
test/srt/test_vision_openai_server.py
test/srt/test_vision_openai_server.py
+108
-10
test/srt/test_vlm_accuracy.py
test/srt/test_vlm_accuracy.py
+15
-12
No files found.
python/sglang/srt/managers/tokenizer_manager.py
View file @
1e86457c
...
...
@@ -16,7 +16,6 @@
import
asyncio
import
copy
import
dataclasses
import
json
import
logging
import
os
import
pickle
...
...
@@ -52,10 +51,6 @@ from sglang.srt.configs.model_config import ModelConfig
from
sglang.srt.disaggregation.conn
import
KVBootstrapServer
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.managers.image_processor
import
(
get_dummy_image_processor
,
get_image_processor
,
)
from
sglang.srt.managers.io_struct
import
(
AbortReq
,
BatchEmbeddingOut
,
...
...
@@ -93,6 +88,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput
,
UpdateWeightsFromTensorReqOutput
,
)
from
sglang.srt.managers.multimodal_processor
import
(
get_dummy_processor
,
get_mm_processor
,
import_processors
,
)
from
sglang.srt.metrics.collector
import
TokenizerMetricsCollector
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
...
@@ -171,6 +171,7 @@ class TokenizerManager:
self
.
image_token_id
=
self
.
model_config
.
image_token_id
if
self
.
model_config
.
is_multimodal
:
import_processors
()
_processor
=
get_processor
(
server_args
.
tokenizer_path
,
tokenizer_mode
=
server_args
.
tokenizer_mode
,
...
...
@@ -179,9 +180,9 @@ class TokenizerManager:
)
# We want to parallelize the image pre-processing so we create an executor for it
# We create
image
_processor for any skip_tokenizer_init to make sure we still encode
# We create
mm
_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False.
self
.
image
_processor
=
get_
image
_processor
(
self
.
mm
_processor
=
get_
mm
_processor
(
self
.
model_config
.
hf_config
,
server_args
,
_processor
)
...
...
@@ -192,7 +193,7 @@ class TokenizerManager:
self
.
tokenizer
=
self
.
processor
.
tokenizer
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
else
:
self
.
image
_processor
=
get_dummy_
image_
processor
()
self
.
mm
_processor
=
get_dummy_processor
()
if
server_args
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
processor
=
None
...
...
@@ -389,7 +390,7 @@ class TokenizerManager:
)
input_ids
=
self
.
tokenizer
.
encode
(
input_text
)
image_inputs
:
Dict
=
await
self
.
image
_processor
.
process_
images
_async
(
image_inputs
:
Dict
=
await
self
.
mm
_processor
.
process_
mm_data
_async
(
obj
.
image_data
,
input_text
or
input_ids
,
obj
,
self
.
max_req_input_len
)
if
image_inputs
and
"input_ids"
in
image_inputs
:
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
1e86457c
...
...
@@ -43,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
if
TYPE_CHECKING
:
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.managers.schedule_batch
import
ImageInputs
,
ModelWorkerBatch
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
,
MultimodalInputs
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
...
...
@@ -176,7 +176,7 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu
:
Optional
[
torch
.
Tensor
]
=
None
# For multimodal
image
_inputs
:
Optional
[
List
[
Image
Inputs
]]
=
None
mm
_inputs
:
Optional
[
List
[
Multimodal
Inputs
]]
=
None
# Encoder-decoder
encoder_cached
:
Optional
[
List
[
bool
]]
=
None
...
...
@@ -242,7 +242,7 @@ class ForwardBatch:
req_pool_indices
=
batch
.
req_pool_indices
,
seq_lens
=
batch
.
seq_lens
,
out_cache_loc
=
batch
.
out_cache_loc
,
image
_inputs
=
batch
.
image
_inputs
,
mm
_inputs
=
batch
.
multimodal
_inputs
,
encoder_cached
=
batch
.
encoder_cached
,
encoder_lens
=
batch
.
encoder_lens
,
encoder_lens_cpu
=
batch
.
encoder_lens_cpu
,
...
...
@@ -332,42 +332,53 @@ class ForwardBatch:
return
ret
def
merge_
image
_inputs
(
self
)
->
Optional
[
Image
Inputs
]:
def
merge_
mm
_inputs
(
self
)
->
Optional
[
Multimodal
Inputs
]:
"""
Merge all image inputs in the batch into a single
Image
Inputs object.
Merge all image inputs in the batch into a single
MultiModal
Inputs object.
Returns:
if none, current batch contains no image input
"""
if
not
self
.
image
_inputs
or
all
(
x
is
None
for
x
in
self
.
image
_inputs
):
if
not
self
.
mm
_inputs
or
all
(
x
is
None
for
x
in
self
.
mm
_inputs
):
return
None
# Filter out None values
valid_inputs
=
[
x
for
x
in
self
.
image
_inputs
if
x
is
not
None
]
valid_inputs
=
[
x
for
x
in
self
.
mm
_inputs
if
x
is
not
None
]
# Start with the first valid image input
merged
=
valid_inputs
[
0
]
# Merge remaining inputs
for
img
_input
in
valid_inputs
[
1
:]:
merged
.
merge
(
img
_input
)
for
mm
_input
in
valid_inputs
[
1
:]:
merged
.
merge
(
mm
_input
)
if
isinstance
(
merged
.
pixel_values
,
np
.
ndarray
):
merged
.
pixel_values
=
torch
.
from_numpy
(
merged
.
pixel_values
)
if
isinstance
(
merged
.
audio_features
,
np
.
ndarray
):
merged
.
audio_features
=
torch
.
from_numpy
(
merged
.
audio_features
)
return
merged
def
contains_image_inputs
(
self
)
->
bool
:
""" """
if
self
.
image_inputs
is
None
:
return
True
if
self
.
mm_inputs
is
None
:
return
False
return
any
(
image_input
.
pixel_values
is
not
None
and
image_input
.
pixel_values
is
not
[]
for
image_input
in
self
.
image_inputs
if
image_input
is
not
None
mm_input
is
not
None
and
mm_input
.
contains_image_inputs
()
for
mm_input
in
self
.
mm_inputs
)
def
contains_audio_inputs
(
self
)
->
bool
:
if
self
.
mm_inputs
is
None
:
return
False
return
any
(
mm_input
is
not
None
and
mm_input
.
contains_audio_inputs
()
for
mm_input
in
self
.
mm_inputs
)
def
contains_mm_inputs
(
self
)
->
bool
:
return
self
.
contains_audio_inputs
()
or
self
.
contains_image_inputs
()
def
_compute_mrope_positions
(
self
,
model_runner
:
ModelRunner
,
batch
:
ModelWorkerBatch
):
...
...
@@ -378,8 +389,8 @@ class ForwardBatch:
for
i
,
_
in
enumerate
(
mrope_positions_list
):
mrope_position_delta
=
(
0
if
batch
.
image
_inputs
[
i
]
is
None
else
batch
.
image
_inputs
[
i
].
mrope_position_delta
if
batch
.
multimodal
_inputs
[
i
]
is
None
else
batch
.
multimodal
_inputs
[
i
].
mrope_position_delta
)
mrope_positions_list
[
i
]
=
MRotaryEmbedding
.
get_next_input_positions
(
mrope_position_delta
,
...
...
@@ -388,13 +399,13 @@ class ForwardBatch:
)
elif
self
.
forward_mode
.
is_extend
():
extend_start_loc_cpu
=
self
.
extend_start_loc
.
cpu
().
numpy
()
for
i
,
image
_inputs
in
enumerate
(
batch
.
image
_inputs
):
for
i
,
multimodal
_inputs
in
enumerate
(
batch
.
multimodal
_inputs
):
extend_start_loc
,
extend_seq_len
,
extend_prefix_len
=
(
extend_start_loc_cpu
[
i
],
batch
.
extend_seq_lens
[
i
],
batch
.
extend_prefix_lens
[
i
],
)
if
image
_inputs
is
None
:
if
multimodal
_inputs
is
None
:
# text only
mrope_positions
=
[
[
...
...
@@ -411,20 +422,22 @@ class ForwardBatch:
input_tokens
=
self
.
input_ids
[
extend_start_loc
:
extend_start_loc
+
extend_seq_len
],
image_grid_thw
=
image
_inputs
.
image_grid_thws
,
video_grid_thw
=
image
_inputs
.
video_grid_thws
,
image_token_id
=
image
_inputs
.
im_token_id
,
video_token_id
=
image
_inputs
.
video_token_id
,
image_grid_thw
=
multimodal
_inputs
.
image_grid_thws
,
video_grid_thw
=
multimodal
_inputs
.
video_grid_thws
,
image_token_id
=
multimodal
_inputs
.
im_token_id
,
video_token_id
=
multimodal
_inputs
.
video_token_id
,
vision_start_token_id
=
hf_config
.
vision_start_token_id
,
vision_end_token_id
=
hf_config
.
vision_end_token_id
,
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
,
context_len
=
0
,
seq_len
=
len
(
self
.
input_ids
),
second_per_grid_ts
=
image
_inputs
.
second_per_grid_ts
,
second_per_grid_ts
=
multimodal
_inputs
.
second_per_grid_ts
,
tokens_per_second
=
hf_config
.
vision_config
.
tokens_per_second
,
)
)
batch
.
image_inputs
[
i
].
mrope_position_delta
=
mrope_position_delta
batch
.
multimodal_inputs
[
i
].
mrope_position_delta
=
(
mrope_position_delta
)
mrope_positions_list
[
i
]
=
mrope_positions
self
.
mrope_positions
=
torch
.
cat
(
...
...
python/sglang/srt/models/deepseek_janus_pro.py
View file @
1e86457c
...
...
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs
,
general_mm_embed_routine
,
)
from
sglang.srt.managers.schedule_batch
import
ImageInputs
from
sglang.srt.managers.schedule_batch
import
MultimodalInputs
,
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.llama
import
LlamaForCausalLM
...
...
@@ -1959,7 +1959,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
def
get_image_feature
(
self
,
image_input
:
Image
Inputs
)
->
torch
.
Tensor
:
def
get_image_feature
(
self
,
image_input
:
Multimodal
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
image_input
.
pixel_values
bs
,
n
=
pixel_values
.
shape
[
0
:
2
]
pixel_values
=
pixel_values
.
to
(
...
...
@@ -1988,10 +1988,9 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
inputs_embeds
=
general_mm_embed_routine
(
input_ids
=
input_ids
,
positions
=
positions
,
forward_batch
=
forward_batch
,
embed_tokens
=
self
.
get_input_embeddings
(),
image
_embedding_func
=
self
.
get_image_feature
,
mm_data
_embedding_func
=
self
.
get_image_feature
,
)
return
self
.
language_model
(
...
...
@@ -2005,7 +2004,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def
prepare_gen_img_embeds
(
self
,
image_ids
:
torch
.
LongTensor
):
return
self
.
gen_aligner
(
self
.
gen_embed
(
image_ids
))
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
):
im_start_id
=
image_inputs
.
im_start_id
im_end_id
=
image_inputs
.
im_end_id
media_token_pairs
=
[(
im_start_id
,
im_end_id
)]
...
...
python/sglang/srt/models/deepseek_vl2.py
View file @
1e86457c
...
...
@@ -11,7 +11,7 @@ from sglang.srt.configs.deepseekvl2 import (
)
from
sglang.srt.layers.linear
import
ReplicatedLinear
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.deepseek_v2
import
DeepseekV2ForCausalLM
...
...
@@ -222,7 +222,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
):
extend_start_loc_cpu
=
forward_batch
.
extend_start_loc
.
cpu
().
numpy
()
extend_seq_lens_cpu
=
forward_batch
.
extend_seq_lens
.
cpu
().
numpy
()
for
idx
,
image
in
enumerate
(
forward_batch
.
image
_inputs
):
for
idx
,
image
in
enumerate
(
forward_batch
.
mm
_inputs
):
if
image
is
None
:
continue
start_idx
=
extend_start_loc_cpu
[
idx
]
...
...
@@ -262,10 +262,10 @@ class DeepseekVL2ForCausalLM(nn.Module):
weights_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weights_loader
(
param
,
loaded_weight
)
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
):
return
input_ids
def
get_image_feature
(
self
,
image_input
:
Image
Inputs
):
def
get_image_feature
(
self
,
image_input
:
Multimodal
Inputs
):
pixel_values
=
image_input
.
pixel_values
.
type
(
next
(
self
.
vision
.
parameters
()).
dtype
).
to
(
device
=
next
(
self
.
vision
.
parameters
()).
device
)
...
...
python/sglang/srt/models/gemma3_mm.py
View file @
1e86457c
...
...
@@ -38,7 +38,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs
,
general_mm_embed_routine
,
)
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
(
default_weight_loader
,
...
...
@@ -185,7 +185,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
self
.
post_init
()
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
)
->
List
[
int
]:
"""Pad input IDs with image tokens."""
# Get special token IDs
...
...
@@ -268,7 +268,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
def
get_input_embeddings
(
self
)
->
nn
.
Embedding
:
return
self
.
language_model
.
get_input_embeddings
()
def
get_image_feature
(
self
,
image_input
:
Image
Inputs
):
def
get_image_feature
(
self
,
image_input
:
Multimodal
Inputs
):
"""
Projects the last hidden state from the vision model into language model space.
...
...
@@ -286,11 +286,11 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
image_features
=
self
.
multi_modal_projector
(
vision_outputs
)
return
image_features
def
embed_
image
_inputs
(
def
embed_
mm
_inputs
(
self
,
input_ids
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
image_input
:
Image
Inputs
,
image_input
:
Multimodal
Inputs
,
)
->
torch
.
Tensor
:
if
input_ids
is
None
:
raise
ValueError
(
"Unimplemented"
)
...
...
@@ -401,10 +401,9 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
inputs_embeds
=
general_mm_embed_routine
(
input_ids
=
llm_input_ids
,
positions
=
positions
,
forward_batch
=
forward_batch
,
embed_tokens
=
self
.
get_input_embeddings
(),
image
_embedding_func
=
self
.
get_image_feature
,
mm_data
_embedding_func
=
self
.
get_image_feature
,
)
outputs
=
self
.
language_model
(
...
...
python/sglang/srt/models/llama.py
View file @
1e86457c
...
...
@@ -17,7 +17,7 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import
logging
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
from
torch
import
nn
...
...
python/sglang/srt/models/llava.py
View file @
1e86457c
...
...
@@ -31,7 +31,7 @@ from transformers import (
from
transformers.models.llava.modeling_llava
import
LlavaMultiModalProjector
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.mm_utils
import
(
get_anyres_image_grid_shape
,
unpad_image
,
...
...
@@ -46,7 +46,7 @@ from sglang.srt.utils import add_prefix
class
LlavaBaseForCausalLM
(
nn
.
Module
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
):
image_sizes
,
pad_values
=
image_inputs
.
image_sizes
,
image_inputs
.
pad_values
# hardcode for spatial_unpad + anyres
...
...
@@ -134,7 +134,7 @@ class LlavaBaseForCausalLM(nn.Module):
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
torch
.
Tensor
:
image_inputs
=
forward_batch
.
image
_inputs
image_inputs
=
forward_batch
.
mm
_inputs
if
forward_batch
.
forward_mode
.
is_extend
():
# Clamp input ids. This is because the input_ids for the image tokens are
...
...
python/sglang/srt/models/llavavid.py
View file @
1e86457c
...
...
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
from
transformers.models.llava.modeling_llava
import
LlavaMultiModalProjector
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.llama
import
LlamaForCausalLM
...
...
@@ -57,7 +57,7 @@ class LlavaVidForCausalLM(nn.Module):
torch
.
empty
(
config
.
text_config
.
hidden_size
,
dtype
=
torch
.
float16
)
)
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
):
pad_values
=
image_inputs
.
pad_values
new_image_feature_len
=
self
.
image_feature_len
...
...
@@ -112,7 +112,7 @@ class LlavaVidForCausalLM(nn.Module):
positions
:
torch
.
Tensor
,
forward_batch
:
ForwardBatch
,
)
->
torch
.
Tensor
:
image_inputs
=
forward_batch
.
image
_inputs
image_inputs
=
forward_batch
.
mm
_inputs
if
forward_batch
.
forward_mode
.
is_extend
():
bs
=
forward_batch
.
batch_size
...
...
python/sglang/srt/models/minicpmo.py
0 → 100644
View file @
1e86457c
This diff is collapsed.
Click to expand it.
python/sglang/srt/models/minicpmv.py
View file @
1e86457c
...
...
@@ -52,9 +52,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.managers.mm_utils
import
(
MultiModalityDataPaddingPatternTokenPairs
,
embed_image_inputs
,
general_mm_embed_routine
,
)
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.utils
import
set_default_torch_dtype
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -862,24 +862,12 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch
:
ForwardBatch
,
**
kwargs
:
Any
,
)
->
torch
.
Tensor
:
if
(
forward_batch
.
forward_mode
.
is_decode
()
or
not
forward_batch
.
contains_image_inputs
()
):
inputs_embeds
:
torch
.
Tensor
=
self
.
llm
.
get_input_embeddings
(
input_ids
)
else
:
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
image_inputs
=
forward_batch
.
merge_image_inputs
()
inputs_embeds
=
embed_image_inputs
(
image_input
=
image_inputs
,
input_ids
=
input_ids
,
input_embedding
=
self
.
get_input_embeddings
(),
image_embedding_func
=
self
.
get_image_features
,
placeholder_token_ids
=
[
image_inputs
.
im_token_id
]
+
image_inputs
.
pad_values
,
)
inputs_embeds
=
general_mm_embed_routine
(
input_ids
=
input_ids
,
forward_batch
=
forward_batch
,
embed_tokens
=
self
.
get_input_embeddings
(),
mm_data_embedding_func
=
self
.
get_image_features
,
)
hidden_states
=
self
.
llm
.
model
(
input_ids
=
None
,
...
...
@@ -925,7 +913,7 @@ class MiniCPMVBaseModel(nn.Module):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
get_image_features
(
self
,
image_inputs
:
Image
Inputs
)
->
torch
.
Tensor
:
def
get_image_features
(
self
,
image_inputs
:
Multimodal
Inputs
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -1037,7 +1025,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
def
get_image_features
(
self
,
image_inputs
:
Image
Inputs
,
image_inputs
:
Multimodal
Inputs
,
)
->
torch
.
Tensor
:
# list of tensors
pixel_values
=
image_inputs
.
pixel_values
...
...
@@ -1075,7 +1063,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
)
return
self
.
resampler
(
vision_embedding
,
tgt_sizes
)
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
):
# Get all special token IDs
im_start_id
:
int
=
image_inputs
.
im_start_id
im_end_id
:
int
=
image_inputs
.
im_end_id
...
...
python/sglang/srt/models/mllama.py
View file @
1e86457c
...
...
@@ -32,7 +32,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.llama
import
LlamaDecoderLayer
,
LlamaMLP
...
...
@@ -796,7 +796,7 @@ class MllamaForConditionalGeneration(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
)
self
.
capture_mode
=
False
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
):
pixel_values
=
image_inputs
.
pixel_values
pad_values
=
image_inputs
.
pad_values
...
...
@@ -815,7 +815,7 @@ class MllamaForConditionalGeneration(nn.Module):
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
max_num_images
=
max_num_tiles
=
bs
=
0
for
i
,
im
in
enumerate
(
forward_batch
.
image
_inputs
):
for
i
,
im
in
enumerate
(
forward_batch
.
mm
_inputs
):
if
not
forward_batch
.
encoder_cached
[
i
]
and
im
is
not
None
:
max_num_images
=
max
(
max_num_images
,
im
.
pixel_values
.
shape
[
1
])
max_num_tiles
=
max
(
max_num_tiles
,
im
.
pixel_values
.
shape
[
2
])
...
...
@@ -842,7 +842,7 @@ class MllamaForConditionalGeneration(nn.Module):
)
i
=
0
encoder_lens_need
=
[]
for
k
,
im
in
enumerate
(
forward_batch
.
image
_inputs
):
for
k
,
im
in
enumerate
(
forward_batch
.
mm
_inputs
):
if
forward_batch
.
encoder_cached
[
k
]
or
im
is
None
:
continue
...
...
python/sglang/srt/models/qwen2_5_vl.py
View file @
1e86457c
...
...
@@ -57,7 +57,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs
,
general_mm_embed_routine
,
)
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.qwen2
import
Qwen2Model
...
...
@@ -513,7 +513,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
)
self
.
pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Multimodal
Inputs
):
# Get all special token IDs
im_start_id
:
int
=
image_inputs
.
im_start_id
im_end_id
:
int
=
image_inputs
.
im_end_id
...
...
@@ -523,7 +523,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return
pattern
.
pad_input_tokens
(
input_ids
,
image_inputs
)
def
get_image_feature
(
self
,
image_input
:
Image
Inputs
)
->
torch
.
Tensor
:
def
get_image_feature
(
self
,
image_input
:
Multimodal
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
image_input
.
pixel_values
.
type
(
self
.
visual
.
dtype
)
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
image_input
.
image_grid_thws
)
return
image_embeds
...
...
@@ -572,10 +572,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
inputs_embeds
=
general_mm_embed_routine
(
input_ids
=
input_ids
,
positions
=
positions
,
forward_batch
=
forward_batch
,
embed_tokens
=
self
.
get_input_embeddings
(),
image
_embedding_func
=
self
.
get_image_feature
,
mm_data
_embedding_func
=
self
.
get_image_feature
,
)
hidden_states
=
self
.
model
(
...
...
python/sglang/srt/models/qwen2_vl.py
View file @
1e86457c
...
...
@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs
,
general_mm_embed_routine
,
)
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.qwen2
import
Qwen2Model
...
...
@@ -472,16 +472,16 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
image_inputs
:
Image
Inputs
):
def
pad_input_ids
(
self
,
input_ids
:
List
[
int
],
multi_modal_inputs
:
Multimodal
Inputs
):
# Get all special token IDs
im_start_id
:
int
=
image
_inputs
.
im_start_id
im_end_id
:
int
=
image
_inputs
.
im_end_id
im_start_id
:
int
=
multi_modal
_inputs
.
im_start_id
im_end_id
:
int
=
multi_modal
_inputs
.
im_end_id
media_token_pairs
=
[(
im_start_id
,
im_end_id
)]
pattern
=
MultiModalityDataPaddingPatternTokenPairs
(
media_token_pairs
)
return
pattern
.
pad_input_tokens
(
input_ids
,
image
_inputs
)
return
pattern
.
pad_input_tokens
(
input_ids
,
multi_modal
_inputs
)
def
get_image_feature
(
self
,
image_input
:
Image
Inputs
)
->
torch
.
Tensor
:
def
get_image_feature
(
self
,
image_input
:
Multimodal
Inputs
)
->
torch
.
Tensor
:
pixel_values
=
image_input
.
pixel_values
.
type
(
self
.
visual
.
dtype
)
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
image_input
.
image_grid_thws
)
return
image_embeds
...
...
@@ -530,10 +530,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
inputs_embeds
=
general_mm_embed_routine
(
input_ids
=
input_ids
,
positions
=
positions
,
forward_batch
=
forward_batch
,
embed_tokens
=
self
.
get_input_embeddings
(),
image
_embedding_func
=
self
.
get_image_feature
,
mm_data
_embedding_func
=
self
.
get_image_feature
,
)
hidden_states
=
self
.
model
(
...
...
python/sglang/srt/openai_api/adapter.py
View file @
1e86457c
...
...
@@ -899,6 +899,7 @@ def v1_chat_generate_request(
input_ids
=
[]
sampling_params_list
=
[]
image_data_list
=
[]
audio_data_list
=
[]
return_logprobs
=
[]
logprob_start_lens
=
[]
top_logprobs_nums
=
[]
...
...
@@ -912,6 +913,7 @@ def v1_chat_generate_request(
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
if
not
isinstance
(
request
.
messages
,
str
):
# Apply chat template and its stop strings.
...
...
@@ -956,7 +958,7 @@ def v1_chat_generate_request(
)
except
:
# This except branch will be triggered when the chosen model
# has a different tools input format that is not compati
a
ble
# has a different tools input format that is not compatible
# with openAI's apply_chat_template tool_call format, like Mistral.
tools
=
[
t
if
"function"
in
t
else
{
"function"
:
t
}
for
t
in
tools
]
prompt_ids
=
tokenizer_manager
.
tokenizer
.
apply_chat_template
(
...
...
@@ -976,11 +978,13 @@ def v1_chat_generate_request(
prompt_ids
+=
encoded
stop
=
request
.
stop
image_data
=
None
audio_data
=
None
modalities
=
[]
else
:
conv
=
generate_chat_conv
(
request
,
chat_template_name
)
prompt
=
conv
.
get_prompt
()
image_data
=
conv
.
image_data
audio_data
=
conv
.
audio_data
modalities
=
conv
.
modalities
stop
=
conv
.
stop_str
or
[]
if
request
.
stop
:
...
...
@@ -994,6 +998,7 @@ def v1_chat_generate_request(
prompt_ids
=
request
.
messages
stop
=
request
.
stop
image_data
=
None
audio_data
=
None
modalities
=
[]
input_ids
.
append
(
prompt_ids
)
return_logprobs
.
append
(
request
.
logprobs
)
...
...
@@ -1034,6 +1039,7 @@ def v1_chat_generate_request(
sampling_params_list
.
append
(
sampling_params
)
image_data_list
.
append
(
image_data
)
audio_data_list
.
append
(
audio_data
)
modalities_list
.
append
(
modalities
)
if
len
(
all_requests
)
==
1
:
if
isinstance
(
input_ids
[
0
],
str
):
...
...
@@ -1042,6 +1048,7 @@ def v1_chat_generate_request(
prompt_kwargs
=
{
"input_ids"
:
input_ids
[
0
]}
sampling_params_list
=
sampling_params_list
[
0
]
image_data_list
=
image_data_list
[
0
]
audio_data_list
=
audio_data_list
[
0
]
return_logprobs
=
return_logprobs
[
0
]
logprob_start_lens
=
logprob_start_lens
[
0
]
top_logprobs_nums
=
top_logprobs_nums
[
0
]
...
...
@@ -1056,6 +1063,7 @@ def v1_chat_generate_request(
adapted_request
=
GenerateReqInput
(
**
prompt_kwargs
,
image_data
=
image_data_list
,
audio_data
=
audio_data_list
,
sampling_params
=
sampling_params_list
,
return_logprob
=
return_logprobs
,
logprob_start_len
=
logprob_start_lens
,
...
...
python/sglang/srt/openai_api/protocol.py
View file @
1e86457c
...
...
@@ -227,14 +227,25 @@ class ChatCompletionMessageContentImageURL(BaseModel):
detail
:
Optional
[
Literal
[
"auto"
,
"low"
,
"high"
]]
=
"auto"
class
ChatCompletionMessageContentAudioURL
(
BaseModel
):
url
:
str
class
ChatCompletionMessageContentImagePart
(
BaseModel
):
type
:
Literal
[
"image_url"
]
image_url
:
ChatCompletionMessageContentImageURL
modalities
:
Optional
[
Literal
[
"image"
,
"multi-images"
,
"video"
]]
=
"image"
class
ChatCompletionMessageContentAudioPart
(
BaseModel
):
type
:
Literal
[
"audio_url"
]
audio_url
:
ChatCompletionMessageContentAudioURL
ChatCompletionMessageContentPart
=
Union
[
ChatCompletionMessageContentTextPart
,
ChatCompletionMessageContentImagePart
ChatCompletionMessageContentTextPart
,
ChatCompletionMessageContentImagePart
,
ChatCompletionMessageContentAudioPart
,
]
...
...
python/sglang/srt/utils.py
View file @
1e86457c
...
...
@@ -55,14 +55,13 @@ import triton
import
zmq
from
fastapi.responses
import
ORJSONResponse
from
packaging
import
version
as
pkg_version
from
packaging.version
import
Version
,
pars
e
from
PIL
import
Imag
e
from
starlette.routing
import
Mount
from
torch
import
nn
from
torch.func
import
functional_call
from
torch.library
import
Library
from
torch.profiler
import
ProfilerActivity
,
profile
,
record_function
from
torch.utils._contextlib
import
_DecoratorContextManager
from
torch.utils.cpp_extension
import
CUDA_HOME
from
triton.runtime.cache
import
(
FileCacheManager
,
default_cache_dir
,
...
...
@@ -507,9 +506,37 @@ def decode_video_base64(video_base64):
)
# Return an empty array and size tuple if no frames were found
def
load_image
(
image_file
:
Union
[
str
,
bytes
]):
from
PIL
import
Image
def
load_audio
(
audio_file
:
str
,
sr
:
int
=
16000
,
mono
:
bool
=
True
)
->
np
.
ndarray
:
# Use soundfile here, since librosa use it under the hood,
# and librosa will not support audio loading in the future
import
soundfile
as
sf
from
scipy.signal
import
resample
# print(f"loading {audio_file}")
# Load audio data
if
isinstance
(
audio_file
,
bytes
):
audio
,
original_sr
=
sf
.
read
(
BytesIO
(
audio_file
))
elif
audio_file
.
startswith
(
"data:"
):
audio_file
=
audio_file
.
split
(
","
)[
1
]
audio
,
original_sr
=
sf
.
read
(
BytesIO
(
base64
.
b64decode
(
audio_file
)))
elif
isinstance
(
audio_file
,
str
):
audio
,
original_sr
=
sf
.
read
(
audio_file
)
else
:
raise
ValueError
(
f
"Invalid audio format:
{
audio_file
}
"
)
# Resample audio if the original sample rate is different from the desired sample rate
if
original_sr
!=
sr
:
num_samples
=
int
(
len
(
audio
)
*
float
(
sr
)
/
original_sr
)
audio
=
resample
(
audio
,
num_samples
)
# Convert to mono if requested and audio is stereo
if
mono
and
len
(
audio
.
shape
)
>
1
:
audio
=
np
.
mean
(
audio
,
axis
=
1
)
return
audio
def
load_image
(
image_file
:
Union
[
str
,
bytes
])
->
tuple
[
Image
,
tuple
[
int
,
int
]]:
image
=
image_size
=
None
if
isinstance
(
image_file
,
bytes
):
...
...
test/srt/test_vision_openai_server.py
View file @
1e86457c
...
...
@@ -87,7 +87,8 @@ class TestOpenAIVisionServer(unittest.TestCase):
# `driver` is for gemma-3-it
assert
"man"
in
text
or
"person"
or
"driver"
in
text
,
text
assert
"cab"
in
text
or
"taxi"
in
text
or
"SUV"
in
text
,
text
assert
"iron"
in
text
,
text
# MiniCPMO fails to recognize `iron`, but `hanging`
assert
"iron"
in
text
or
"hang"
in
text
,
text
assert
response
.
id
assert
response
.
created
assert
response
.
usage
.
prompt_tokens
>
0
...
...
@@ -177,7 +178,9 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert
response
.
choices
[
0
].
message
.
role
==
"assistant"
text
=
response
.
choices
[
0
].
message
.
content
assert
isinstance
(
text
,
str
)
print
(
f
"LLM response:
{
text
}
"
)
print
(
"-"
*
30
)
print
(
f
"Multi images response:
\n
{
text
}
"
)
print
(
"-"
*
30
)
assert
"man"
in
text
or
"cab"
in
text
or
"SUV"
in
text
or
"taxi"
in
text
,
text
assert
"logo"
in
text
or
'"S"'
in
text
or
"SG"
in
text
,
text
assert
response
.
id
...
...
@@ -272,21 +275,18 @@ class TestOpenAIVisionServer(unittest.TestCase):
# messages = self.prepare_video_messages_video_direct(file_path)
messages
=
self
.
prepare_video_messages
(
file_path
)
video_request
=
client
.
chat
.
completions
.
create
(
response
=
client
.
chat
.
completions
.
create
(
model
=
"default"
,
messages
=
messages
,
temperature
=
0
,
max_tokens
=
1024
,
stream
=
Tru
e
,
stream
=
Fals
e
,
)
video_response
=
response
.
choices
[
0
].
message
.
content
print
(
"-"
*
30
)
video_response
=
""
for
chunk
in
video_request
:
if
chunk
.
choices
[
0
].
delta
.
content
is
not
None
:
content
=
chunk
.
choices
[
0
].
delta
.
content
video_response
+=
content
print
(
content
,
end
=
""
,
flush
=
True
)
print
(
f
"Video response:
\n
{
video_response
}
"
)
print
(
"-"
*
30
)
# Add assertions to validate the video response
...
...
@@ -308,6 +308,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
self
.
assertGreater
(
len
(
video_response
),
0
)
def
test_regex
(
self
):
return
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
regex
=
(
...
...
@@ -392,6 +393,77 @@ class TestOpenAIVisionServer(unittest.TestCase):
with
ThreadPoolExecutor
(
4
)
as
executor
:
list
(
executor
.
map
(
self
.
run_decode_with_image
,
image_ids
))
def
prepare_audio_messages
(
self
,
prompt
,
audio_file_name
):
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
prompt
,
},
{
"type"
:
"audio_url"
,
"audio_url"
:
{
"url"
:
f
"
{
audio_file_name
}
"
},
},
],
}
]
return
messages
def
get_audio_response
(
self
,
url
:
str
,
prompt
,
category
):
audio_file_path
=
self
.
get_or_download_file
(
url
)
client
=
openai
.
Client
(
api_key
=
"sk-123456"
,
base_url
=
self
.
base_url
)
messages
=
self
.
prepare_audio_messages
(
prompt
,
audio_file_path
)
response
=
client
.
chat
.
completions
.
create
(
model
=
"default"
,
messages
=
messages
,
temperature
=
0
,
max_tokens
=
128
,
stream
=
False
,
)
audio_response
=
response
.
choices
[
0
].
message
.
content
print
(
"-"
*
30
)
print
(
f
"audio
{
category
}
response:
\n
{
audio_response
}
"
)
print
(
"-"
*
30
)
audio_response
=
audio_response
.
lower
()
self
.
assertIsNotNone
(
audio_response
)
self
.
assertGreater
(
len
(
audio_response
),
0
)
return
audio_response
def
_test_audio_speech_completion
(
self
):
# a fragment of Trump's speech
audio_response
=
self
.
get_audio_response
(
AUDIO_TRUMP_SPEECH_URL
,
"I have an audio sample. Please repeat the person's words"
,
category
=
"speech"
,
)
assert
"thank you"
in
audio_response
assert
"it's a privilege to be here"
in
audio_response
assert
"leader"
in
audio_response
assert
"science"
in
audio_response
assert
"art"
in
audio_response
def
_test_audio_ambient_completion
(
self
):
# bird song
audio_response
=
self
.
get_audio_response
(
AUDIO_BIRD_SONG_URL
,
"Please listen to the audio snippet carefully and transcribe the content."
,
"ambient"
,
)
assert
"bird"
in
audio_response
def
test_audio_chat_completion
(
self
):
pass
class
TestQwen2VLServer
(
TestOpenAIVisionServer
):
@
classmethod
...
...
@@ -535,6 +607,32 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
cls
.
base_url
+=
"/v1"
class
TestMinicpmoServer
(
TestOpenAIVisionServer
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
"openbmb/MiniCPM-o-2_6"
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
api_key
=
"sk-123456"
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
"--trust-remote-code"
,
"--chat-template"
,
"minicpmo"
,
"--mem-fraction-static"
,
"0.7"
,
"--tp=2"
,
],
)
cls
.
base_url
+=
"/v1"
def
test_audio_chat_completion
(
self
):
self
.
_test_audio_speech_completion
()
self
.
_test_audio_ambient_completion
()
class
TestDeepseekVL2Server
(
TestOpenAIVisionServer
):
@
classmethod
def
setUpClass
(
cls
):
...
...
test/srt/test_vlm_accuracy.py
View file @
1e86457c
...
...
@@ -13,8 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.conversation
import
generate_chat_conv
from
sglang.srt.managers.mm_utils
import
embed_
image
_inputs
from
sglang.srt.managers.schedule_batch
import
Image
Inputs
from
sglang.srt.managers.mm_utils
import
embed_
mm
_inputs
from
sglang.srt.managers.schedule_batch
import
Multimodal
Inputs
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.openai_api.protocol
import
ChatCompletionRequest
from
sglang.srt.server_args
import
ServerArgs
...
...
@@ -136,7 +136,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
return
inputs
def
get_sglang_model
(
self
):
model_runner
=
ModelRunner
(
self
.
model_runner
=
ModelRunner
(
model_config
=
ModelConfig
(
self
.
model_path
,
model_override_args
=
"{}"
),
mem_fraction_static
=
0.8
,
gpu_id
=
0
,
...
...
@@ -148,7 +148,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
disable_cuda_graph
=
True
,
),
)
return
model_runner
.
model
return
self
.
model_runner
.
model
class
TestMiniCPMVLogits
(
VisionLLMLogitsBase
):
...
...
@@ -165,10 +165,13 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
cls
.
chat_template
=
"minicpmv"
cls
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
cls
.
model
=
AutoModel
.
from_pretrained
(
cls
.
model_path
,
torch_dtype
=
torch
.
bfloat16
,
trust_remote_code
=
True
).
eval
()
cls
.
model
.
to
(
cls
.
device
)
cls
.
hf_model
=
(
AutoModel
.
from_pretrained
(
cls
.
model_path
,
torch_dtype
=
torch
.
bfloat16
,
trust_remote_code
=
True
)
.
eval
()
.
to
(
cls
.
device
)
)
async
def
test_vlm_embedding_output
(
self
):
"""
...
...
@@ -184,7 +187,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
"pixel_values"
:
inputs
.
pixel_values
,
"tgt_sizes"
:
inputs
.
tgt_sizes
,
}
(
hf_output
,
_
)
=
self
.
model
.
get_vllm_embedding
(
(
hf_output
,
_
)
=
self
.
hf_
model
.
get_vllm_embedding
(
model_inputs
,
)
hf_output
=
hf_output
.
squeeze
(
0
)
...
...
@@ -192,14 +195,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
# sglang
model
=
self
.
get_sglang_model
()
input_ids
=
inputs
[
"input_ids"
].
to
(
self
.
device
).
flatten
()
sglang_output
=
embed_
image
_inputs
(
image
_input
=
Image
Inputs
(
sglang_output
=
embed_
mm
_inputs
(
mm
_input
=
Multimodal
Inputs
(
pixel_values
=
inputs
[
"pixel_values"
][
0
],
tgt_sizes
=
inputs
[
"tgt_sizes"
][
0
],
),
input_ids
=
input_ids
,
input_embedding
=
model
.
get_input_embeddings
(),
image
_embedding_func
=
model
.
get_image_features
,
mm_data
_embedding_func
=
model
.
get_image_features
,
placeholder_token_ids
=
[
self
.
processor
.
tokenizer
.
unk_token_id
,
],
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment