sglang, commit 5cb552b1 (unverified)

refactor: multimodal data (#4754)

Authored Apr 01, 2025 by Mick; committed via GitHub Mar 31, 2025.
Parent commit: c7457191
Showing 16 changed files with 425 additions and 476 deletions.
python/sglang/srt/models/deepseek_v2.py       +3    -0
python/sglang/srt/models/deepseek_vl2.py      +105  -104
python/sglang/srt/models/gemma3_mm.py         +14   -80
python/sglang/srt/models/llama.py             +3    -0
python/sglang/srt/models/llava.py             +31   -19
python/sglang/srt/models/llavavid.py          +16   -7
python/sglang/srt/models/minicpmo.py          +63   -147
python/sglang/srt/models/minicpmv.py          +17   -27
python/sglang/srt/models/mllama.py            +29   -14
python/sglang/srt/models/qwen2.py             +9    -6
python/sglang/srt/models/qwen2_5_vl.py        +21   -31
python/sglang/srt/models/qwen2_vl.py          +20   -21
python/sglang/srt/openai_api/adapter.py       +18   -6
python/sglang/srt/utils.py                    +40   -2
test/srt/test_vision_openai_server.py         +5    -7
test/srt/test_vlm_accuracy.py                 +31   -5
python/sglang/srt/models/deepseek_v2.py

@@ -1308,6 +1308,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
python/sglang/srt/models/deepseek_vl2.py

@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
 )
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM

@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
         return x


-# todo
 class DeepseekVL2ForCausalLM(nn.Module):

     def __init__(

@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
-        input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.contains_image_inputs()
-        ):
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
-            for idx, image in enumerate(forward_batch.mm_inputs):
-                if image is None:
-                    continue
-                start_idx = extend_start_loc_cpu[idx]
-                end_idx = start_idx + extend_seq_lens_cpu[idx]
-                images_emb_mask = image.images_emb_mask.to(device="cuda")
-                image_features = self.get_image_feature(image)
-                input_embeds[start_idx:end_idx] = input_embeds[
-                    start_idx:end_idx
-                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
-
-        outputs = self.language_model.forward(
+        hs = general_mm_embed_routine(
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=input_embeds,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
         )

-        return outputs
+        return hs

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

@@ -263,21 +249,34 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return input_ids
+        helper = MultiModalityDataPaddingPatternImageTokens(
+            image_token_id=image_inputs.im_token_id
+        )
+        return helper.pad_input_tokens(input_ids, image_inputs)

-    def get_image_feature(self, image_input: MultimodalInputs):
-        pixel_values = image_input.pixel_values.type(
-            next(self.vision.parameters()).dtype
-        ).to(device=next(self.vision.parameters()).device)
-        image_feature = self.vision.forward_features(pixel_values)
-        images_embeds = self.projector(image_feature)
-        _, hw, n_dim = images_embeds.shape
-        h = w = int(hw ** 0.5)
-        tile_index = 0
-        images_in_this_batch = []
-        images_spatial_crop = image_input.image_spatial_crop
-        for jdx in range(images_spatial_crop.shape[1]):
-            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+        images_spatial_crop = torch.cat(
+            [item.image_spatial_crop for item in items], dim=0
+        )
+        assert images_spatial_crop.dim() == 3
+
+        # TODO: can it be batched ?
+        images_in_this_batch = []
+        for item in items:
+            assert item.pixel_values.dim() == 4
+            image_feature = self.vision.forward_features(
+                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                    device=next(self.vision.parameters()).device
+                )
+            )
+            images_embeds = self.projector(image_feature)
+            _, hw, n_dim = images_embeds.shape
+            h = w = int(hw ** 0.5)
+            tile_index = 0
+            for jdx in range(item.image_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
                 if num_width_tiles == 0 or num_height_tiles == 0:
                     break
                 num_tiles_in_image = num_width_tiles * num_height_tiles

@@ -300,7 +299,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
             new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)

             # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
-            global_features = torch.cat([global_features, new_lines_in_global], dim=1)
+            global_features = torch.cat(
+                [global_features, new_lines_in_global], dim=1
+            )

             # [h, w + 1, D] -> [h * (w + 1), D]
             global_features = global_features.view(-1, n_dim)
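The removed forward() above spliced image features into the token embeddings by hand; general_mm_embed_routine now owns that step. For readers unfamiliar with Tensor.masked_scatter, a minimal standalone sketch of the semantics the old code relied on (toy shapes, no sglang imports):

import torch

seq_len, dim = 6, 4
input_embeds = torch.zeros(seq_len, dim)
mask = torch.tensor([False, True, True, False, True, False])  # image-token rows
image_features = torch.arange(3 * dim, dtype=torch.float32).reshape(3, dim)

# unsqueeze(-1) broadcasts the per-token mask across the embedding dim,
# mirroring images_emb_mask.unsqueeze(-1) in the removed code; the three
# image-feature rows fill the True rows in order.
out = input_embeds.masked_scatter(mask.unsqueeze(-1), image_features)
print(out[1], out[2], out[4])  # the three scattered feature rows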
python/sglang/srt/models/gemma3_mm.py

@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 import torch
 from torch import nn
-from transformers import (
-    AutoModel,
-    BatchFeature,
-    Gemma3Config,
-    Gemma3Processor,
-    PreTrainedModel,
-)
-from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+from transformers import AutoModel, Gemma3Config, PreTrainedModel

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm

@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         """
         return self.language_model.get_attention_sliding_window_size()

-    def get_image_feature(self, image_input: MultimodalInputs):
+    def get_image_feature(self, items: List[MultimodalDataItem]):
         """
         Projects the last hidden state from the vision model into language model space.

         Args:
             pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = image_input.pixel_values
+        pixel_values = torch.stack(
+            flatten_nested_list([item.pixel_values for item in items]), dim=0
+        )
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features

-    def embed_mm_inputs(
-        self,
-        input_ids: torch.Tensor,
-        forward_batch: ForwardBatch,
-        image_input: MultimodalInputs,
-    ) -> torch.Tensor:
-        if input_ids is None:
-            raise ValueError("Unimplemented")
-        # boolean-masking image tokens
-        special_image_mask = torch.isin(
-            input_ids,
-            torch.tensor(image_input.pad_values, device=input_ids.device),
-        ).unsqueeze(-1)
-        num_image_tokens_in_input_ids = special_image_mask.sum()
-        inputs_embeds = None
-        if num_image_tokens_in_input_ids == 0:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-            return inputs_embeds
-        else:
-            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_feature(image_input.pixel_values)
-            # print(f"image tokens from image embeddings: {image_features.numel()}")
-            num_image_tokens_in_embedding = (
-                image_features.shape[0] * image_features.shape[1]
-            )
-            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
-                image_features = image_features[:num_image, :]
-                logger.warning(
-                    f"Number of images does not match number of special image tokens in the input text. "
-                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                    "tokens from image embeddings."
-                )
-            # Important: clamp after extracting original image boundaries
-            input_ids.clamp_(min=0, max=self.vocab_size - 1)
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-                inputs_embeds.device
-            )
-            image_features = image_features.to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                special_image_mask, image_features
-            )
-            return inputs_embeds
-
     @torch.no_grad()
     def forward(
         self,

@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids

-        inputs_embeds = general_mm_embed_routine(
+        hs = general_mm_embed_routine(
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        outputs = self.language_model(
-            input_ids=None,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
             **kwargs,
         )

-        return outputs
+        return hs

     def tie_weights(self):
         return self.language_model.tie_weights()
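The rewritten get_image_feature methods in this commit all share one shape: gather per-item tensors from a List[MultimodalDataItem], then stack or cat them into a single batch for the vision tower. A self-contained sketch of the pattern (FakeItem is a hypothetical stand-in for MultimodalDataItem, illustration only):

import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class FakeItem:
    # hypothetical stand-in for sglang's MultimodalDataItem
    pixel_values: torch.Tensor  # one (channels, height, width) tensor per image


def batch_pixel_values(items: List[FakeItem]) -> torch.Tensor:
    # Stack each item's tensor into one (num_images, C, H, W) batch so the
    # vision tower runs a single forward pass over all images.
    return torch.stack([item.pixel_values for item in items], dim=0)


items = [FakeItem(torch.randn(3, 8, 8)) for _ in range(2)]
print(batch_pixel_values(items).shape)  # torch.Size([2, 3, 8, 8])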
python/sglang/srt/models/llama.py

@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
python/sglang/srt/models/llava.py

@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,

@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+        image_sizes = flatten_nested_list(
+            [item.image_sizes for item in image_inputs.mm_items]
+        )
+        pad_values = [item.pad_value for item in image_inputs.mm_items]

         # hardcode for spatial_unpad + anyres
-        if image_inputs.modalities is not None and (
-            "multi-images" in image_inputs.modalities
-            or "video" in image_inputs.modalities
+        if any(
+            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
+            for item in image_inputs.mm_items
         ):
             image_aspect_ratio = "pad"
         else:

@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 math.ceil(self.image_size / self.patch_size / 2) ** 2
             )
         else:
-            new_image_feature_len = self.image_feature_len  # multiimage
+            new_image_feature_len = self.image_feature_len  # multi-image

         height = width = self.num_patches_per_side
         if "anyres" in image_aspect_ratio:

@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                + [pad_values[image_idx]] * new_image_feature_len
+                + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)

@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
         modalities_list = []
         max_image_offset = []
         for im in image_inputs:
-            if im and im.modalities is not None:
-                modalities_list.extend(im.modalities)
+            if im:
+                modalities_list.extend([item.modality for item in im.mm_items])
             if im and im.image_offsets:
                 max_image_offset.append(
                     np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
                 )

@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
         if need_vision.any():
             bs = forward_batch.batch_size
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
-            image_sizes = [
-                image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
+            image_sizes = [
+                flatten_nested_list(
+                    [item.image_sizes for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
+            ]

             ########## Encode Image ########

@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_image_features = []
                 height = width = self.num_patches_per_side
                 for image_idx, image_feature in enumerate(image_features):
-                    if modalities_list[image_idx] == "image":
+                    if modalities_list[image_idx] == Modality.IMAGE:
                         image_aspect_ratio = (
                             self.config.image_aspect_ratio
                         )  # single image
                     elif (
-                        modalities_list[image_idx] == "multi-images"
-                        or modalities_list[image_idx] == "video"
+                        modalities_list[image_idx] == Modality.MULTI_IMAGES
+                        or modalities_list[image_idx] == Modality.VIDEO
                     ):
                         image_aspect_ratio = "pad"  # multi image
                         # image_aspect_ratio = (

@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     if (
                         image_feature.shape[0] > 1
                         and "anyres" in image_aspect_ratio
-                        and modalities_list[image_idx] == "image"
+                        and modalities_list[image_idx] == Modality.IMAGE
                     ):
                         base_image_feature = image_feature[0]
                         image_feature = image_feature[1:]

@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
                         )
                         image_feature = image_feature.unsqueeze(0)
                     else:
-                        if modalities_list[image_idx] == "video":  # video
+                        if modalities_list[image_idx] == Modality.VIDEO:  # video
                             # 2x2 pooling
                             num_of_frames = image_feature.shape[0]
                             image_feature = image_feature.view(
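The index change in pad_input_ids above, pad_values[image_idx % len(pad_values)], guards against prompts that reference more images than there are per-item pad values. A toy illustration with made-up token ids (the real pad values come from each mm_item):

pad_values = [151646, 151647]  # hypothetical per-item pad token ids
new_image_feature_len = 4
IMAGE_TOKEN = -100  # placeholder marking an image slot in this toy example

input_ids = [1, 2, IMAGE_TOKEN, 3, IMAGE_TOKEN, 4, IMAGE_TOKEN, 5]
out, image_idx = [], 0
for tok in input_ids:
    if tok == IMAGE_TOKEN:
        # wraps around instead of raising IndexError once image_idx >= len(pad_values)
        out += [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
        image_idx += 1
    else:
        out.append(tok)
print(out)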
python/sglang/srt/models/llavavid.py

@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM

@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
         )

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pad_values = image_inputs.pad_values
+        pad_values = [item.pad_value for item in image_inputs.mm_items]

         new_image_feature_len = self.image_feature_len

         pad_ids = pad_values * (

@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
         need_vision = start_positions <= np.array(max_image_offset)

         if need_vision.any():
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
-            image_offsets = [
-                image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
+            image_offsets = [
+                flatten_nested_list(
+                    [item.image_offsets for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
+            ]

             ########## Encode Image ########

@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.vision_tower.vision_tower": "vision_tower",
+            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
             "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
python/sglang/srt/models/minicpmo.py

@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
-    embed_mm_inputs,
-    get_multimodal_data_bounds,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.minicpmv import (
     Idefics2VisionTransformer,
-    MiniCPMVBaseModel,
+    MiniCPMBaseModel,
     Resampler2_5,
 )
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
         return hidden_states


-class MiniCPMO(MiniCPMVBaseModel):
+class MiniCPMO(MiniCPMBaseModel):
     def __init__(
         self,
         config: PretrainedConfig,

@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):
         return input_lengths_after_cnn, input_lengths_after_pooling

-    def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
+    def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
         r"""
         Extract audio embeddings in a streaming manner using cached key-value pairs.

@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
         for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
         for streaming scenarios.

         Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.

         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
-        )
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
+        )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
-        )
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
+        )

         # exist audio

@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
                 ret[i, start:ending] = True
         return ret

-    def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
+    def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
         r"""
         Extract full audio embeddings with optional chunk-based attention.

@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
         not use key-value caching and is suitable for non-streaming inference.

         Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
             chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
                 attention (>0) during embedding computation.

         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
-        # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
-        )
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
+        )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
-        )
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
+        )

         final_audio_embeds = []

+        assert isinstance(wavforms, list)
+        assert isinstance(wavforms[0], torch.Tensor)
+
         # exist audio
         for wavform in wavforms:
             if len(wavform) > 0:

@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
             final_audio_embeds.append(target_audio_embeds)
         return final_audio_embeds

+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        embedding = self.get_omni_embedding(
+            items=items,
+            chunk_length=self.config.audio_chunk_length,
+            stream_input=False,
+        )
+        return embedding
+
     def get_omni_embedding(
         self,
-        input_ids,
-        multimodal_input: MultimodalInputs,
-        input_embeds: torch.Tensor,
-        forward_mode: ForwardMode,
+        items: List[MultimodalDataItem],
         chunk_length=-1,
         stream_input=False,
     ):
         """
         Args:
-            multimodal_input:
-            input_embeds:
             chunk_length: whisper use full attention or chunk attention
             stream_input: use streaming audio embedding
         Returns:
             final embeddings with audio feature
         """
-        input_embeds = input_embeds.unsqueeze(0)
-        if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
-            audio_bounds = get_multimodal_data_bounds(
-                input_ids=input_ids,
-                pad_values=multimodal_input.pad_values,
-                token_pairs=[
-                    (multimodal_input.audio_start_id, multimodal_input.audio_end_id)
-                ],
-            )
-            if audio_bounds.numel() == 0:
-                input_embeds = input_embeds.squeeze(0)
-                # TODO
-                logger.warn("Unimplemented logic. Please try disabling chunked prefill")
-                return input_embeds
-            audio_bounds = audio_bounds.unsqueeze(0)
-            bs = len(input_embeds)
-
-            if stream_input:
-                audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
-            else:
-                audio_embeddings = self.get_audio_embedding(
-                    multimodal_input, chunk_length
-                )
-            # batch size
-            assert len(audio_embeddings) == len(input_embeds)
-            if len(audio_embeddings) > 0:
-                if self.config.chunk_input:
-                    for i in range(bs):
-                        audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
-                            device=input_embeds.device, dtype=input_embeds.dtype
-                        )
-                        audio_start_pos = 0
-                        for bound in audio_bounds[i]:
-                            audio_len = bound[1] - bound[0] + 1
-                            input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
-                                audio_start_pos : audio_start_pos + audio_len, :
-                            ]
-                            audio_start_pos += audio_len
-                else:
-                    for i in range(bs):
-                        audio_embs = audio_embeddings[i]
-                        bounds = audio_bounds[i]
-                        for embs, bound in zip(audio_embs, bounds):
-                            audio_indices = torch.arange(
-                                bound[0], bound[1], dtype=torch.long
-                            ).to(input_embeds.device)
-                            if embs.shape[0] != len(audio_indices):
-                                raise ValueError(
-                                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
-                                    f"to input indices of length {len(audio_indices)}"
-                                )
-                            input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
-        input_embeds = input_embeds.squeeze(0)
-        return input_embeds
-
-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
-        pixel_values = image_inputs.pixel_values
-        tgt_sizes = image_inputs.tgt_sizes
+        if stream_input:
+            audio_embeddings = self.get_audio_embedding_streaming(items)
+        else:
+            audio_embeddings = self.get_audio_embedding(items, chunk_length)
+        bs = len(audio_embeddings)  # batch size
+        audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
+        return audio_embs
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
         all_pixel_values_lst = [

@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):
         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

         patch_attn_mask = torch.zeros(

@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = None
-        # TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and forward_batch.contains_image_inputs()
-        ):
-            mm_inputs = forward_batch.merge_mm_inputs()
-            inputs_embeds = embed_mm_inputs(
-                mm_input=mm_inputs,
-                input_ids=input_ids,
-                input_embedding=self.get_input_embeddings(),
-                mm_data_embedding_func=self.get_image_features,
-                placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
-            )
-            input_ids = input_ids.clamp(
-                min=0, max=self.get_input_embeddings().num_embeddings - 1
-            )
-        if inputs_embeds is None:
-            inputs_embeds = self.llm.get_input_embeddings(input_ids)
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and self.config.init_audio
-            and forward_batch.contains_audio_inputs()
-        ):
-            mm_input = forward_batch.merge_mm_inputs()
-            inputs_embeds = self.get_omni_embedding(
-                input_ids=input_ids,
-                multimodal_input=mm_input,
-                input_embeds=inputs_embeds,
-                forward_mode=forward_batch.forward_mode,
-                chunk_length=self.config.audio_chunk_length,
-                stream_input=False,
-            )
-        forward_batch.mm_inputs = None
-
-        hidden_states = self.llm.model(
-            input_ids=None,
-            positions=positions,
-            input_embeds=inputs_embeds,
+        mm_input = forward_batch.merge_mm_inputs()
+        placeholder_token_ids = (
+            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
+            if forward_batch.contains_mm_inputs()
+            else []
+        )
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
             forward_batch=forward_batch,
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            audio_data_embedding_func=self.get_audio_feature,
+            placeholder_token_ids=placeholder_token_ids,
+            positions=positions,
         )

-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
-        )
+        return hidden_states

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
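The get_image_feature path above relies on torch.nn.utils.rnn.pad_sequence to turn variable-length per-image patch sequences into one dense batch before the permute/reshape. Its behavior on toy shapes:

import torch
from torch.nn.utils.rnn import pad_sequence

# Three images with 2, 5 and 3 patches, each patch a 3-vector.
slices = [torch.ones(2, 3), torch.ones(5, 3), torch.ones(3, 3)]
batch = pad_sequence(slices, batch_first=True, padding_value=0.0)
print(batch.shape)      # torch.Size([3, 5, 3]); shorter items are zero-padded
print(batch[0, 2:, :])  # the padding rows for the first image are all zeros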
python/sglang/srt/models/minicpmv.py

@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list

 RawImageType = Union[Image.Image, torch.Tensor]

@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     return tuple(int(x) for x in version_str.split("."))


-class MiniCPMVBaseModel(nn.Module):
+class MiniCPMBaseModel(nn.Module):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.

@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
         return vlm_embedding, vision_hidden_states

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.llm.get_input_embedding()
+        return self.llm.get_input_embeddings()

     def forward(
         self,

@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_features,
-        )
-
-        hidden_states = self.llm.model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.llm,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
-        )
+        return hidden_states

     def init_llm(
         self,

@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         raise NotImplementedError


-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMBaseModel):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding

-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
-        pixel_values = image_inputs.pixel_values
-        tgt_sizes = image_inputs.tgt_sizes
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]

         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype

@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

         patch_attn_mask = torch.zeros(
python/sglang/srt/models/mllama.py

@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pixel_values = image_inputs.pixel_values
-        pad_values = image_inputs.pad_values
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pixel_values = torch.cat(
+            [item.pixel_values for item in mm_inputs.mm_items], dim=0
+        )
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
         num_patches = self.vision_model.num_patches
         image_len = num_concurrent_media * num_tiles * num_patches
-        image_inputs.num_image_tokens = image_len
+        mm_inputs.num_image_tokens = image_len

         pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):
         # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
         max_num_images = max_num_tiles = bs = 0
-        for i, im in enumerate(forward_batch.mm_inputs):
-            if not forward_batch.encoder_cached[i] and im is not None:
-                max_num_images = max(max_num_images, im.pixel_values.shape[1])
-                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
+        for i, mm_input in enumerate(forward_batch.mm_inputs):
+            if not forward_batch.encoder_cached[i] and mm_input is not None:
+                pixel_values = torch.cat(
+                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                )
+                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
+                max_num_images = max(max_num_images, pixel_values.shape[1])
+                max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
                 bs += 1

         if max_num_images * max_num_tiles * bs == 0:

@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
         )

         i = 0
         encoder_lens_need = []
-        for k, im in enumerate(forward_batch.mm_inputs):
-            if forward_batch.encoder_cached[k] or im is None:
+        for k, mm_input in enumerate(forward_batch.mm_inputs):
+            if forward_batch.encoder_cached[k] or mm_input is None:
                 continue

             encoder_lens_need.append(forward_batch.encoder_lens[k])
-            for j in range(im.pixel_values.shape[1]):
-                img = im.pixel_values[0, j]
+            pixel_values = torch.cat(
+                [item.pixel_values for item in mm_input.mm_items], dim=0
+            )
+            for j in range(pixel_values.shape[1]):
+                img = pixel_values[0, j]
                 num_tiles = img.shape[0]
                 batched_images[i, j, :num_tiles] = img
-                batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
-                batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
+                batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
+                batched_ar_mask[i, j, :num_tiles] = (
+                    mm_input.mm_items[0].aspect_ratio_mask[0, j]
+                )
             i += 1

         return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
python/sglang/srt/models/qwen2.py

@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
-            return self.embed_tokens(input_ids) * self.config.scale_emb
+            return self.get_input_embeddings()(input_ids) * self.config.scale_emb
         else:
-            return self.embed_tokens(input_ids)
+            return self.get_input_embeddings()(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens

     def forward(
         self,

@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)

-    def get_input_embedding(self) -> nn.Embedding:
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

     @torch.no_grad()
python/sglang/srt/models/qwen2_5_vl.py

@@ -30,22 +30,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoModel, Qwen2VLConfig
+from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
-from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLForConditionalGeneration,
-)
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear

@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model

@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = image_inputs.im_start_id
-        im_end_id: int = image_inputs.im_end_id
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id

         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat(
+            [item.image_grid_thws for item in items], dim=0
+        )
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:

@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 f"(3, seq_len) positions, but got {positions.size()}"
             )

-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

         if not get_embedding:

@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]
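pad_input_ids above delegates to MultiModalityDataPaddingPatternTokenPairs, built from (im_start_id, im_end_id) pairs. As an illustration only, not sglang's implementation, a pattern like this boils down to locating the spans delimited by those token pairs:

from typing import List, Tuple

def find_media_spans(
    input_ids: List[int], token_pairs: List[Tuple[int, int]]
) -> List[Tuple[int, int]]:
    # Map each start token to its matching end token, then scan once.
    starts = {s: e for s, e in token_pairs}
    spans, open_at, want_end = [], None, None
    for i, tok in enumerate(input_ids):
        if open_at is None and tok in starts:
            open_at, want_end = i, starts[tok]
        elif open_at is not None and tok == want_end:
            spans.append((open_at, i))
            open_at = want_end = None
    return spans

# <im_start>=100, <im_end>=101 (hypothetical ids); two image spans in a toy prompt
print(find_media_spans([1, 100, 7, 7, 101, 2, 100, 7, 101], [(100, 101)]))
# [(1, 4), (6, 8)]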
python/sglang/srt/models/qwen2_vl.py

@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model

@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     # Use grid_t * grid_w * grid_h to pad tokens for each image
     # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = multi_modal_inputs.im_start_id
-        im_end_id: int = multi_modal_inputs.im_end_id
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id

         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

-        return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat(
+            [item.image_grid_thws for item in items], dim=0
+        )
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:

@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 "multimodal section rotary embedding requires "
                 f"(3, seq_len) positions, but got {positions.size()}"
             )

-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

-        if not get_embedding:
+        if get_embedding:
+            return self.pooler(hidden_states, forward_batch)
+        else:
             return self.logits_processor(
                 input_ids, hidden_states, self.lm_head, forward_batch
             )
-        else:
-            return self.pooler(hidden_states, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
python/sglang/srt/openai_api/adapter.py

@@ -897,6 +897,7 @@ def v1_chat_generate_request(
     request_ids: List[str] = None,
 ):
     input_ids = []
+    prompts = []
     sampling_params_list = []
     image_data_list = []
     audio_data_list = []

@@ -916,6 +917,7 @@ def v1_chat_generate_request(
         # - audio_data: None or a list of audio strings (URLs).
         # None skips any image processing in GenerateReqInput.
        strict_tag = None
+        prompt = ""
        if not isinstance(request.messages, str):
            # Apply chat template and its stop strings.
            tools = None

@@ -1005,11 +1007,13 @@ def v1_chat_generate_request(
             image_data = None
             audio_data = None
             modalities = []
+            prompt = request.messages
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs or 0)
         lora_paths.append(request.lora_path)
+        prompts.append(prompt)

         sampling_params = {
             "temperature": request.temperature,

@@ -1063,6 +1067,10 @@ def v1_chat_generate_request(
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
     if len(all_requests) == 1:
-        if isinstance(input_ids[0], str):
-            prompt_kwargs = {"text": input_ids[0]}
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts[0]}
+        else:
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids[0]}
         else:

@@ -1075,6 +1083,10 @@ def v1_chat_generate_request(
         top_logprobs_nums = top_logprobs_nums[0]
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
     else:
-        if isinstance(input_ids[0], str):
-            prompt_kwargs = {"text": input_ids}
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts}
+        else:
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids}
python/sglang/srt/utils.py

@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Common utilities."""

-import base64
 import builtins
 import ctypes

@@ -54,6 +53,7 @@ import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
+from decord import VideoReader, cpu
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from PIL import Image

@@ -513,13 +513,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     import soundfile as sf
     from scipy.signal import resample

-    # print(f"loading {audio_file}")
     # Load audio data
     if isinstance(audio_file, bytes):
         audio, original_sr = sf.read(BytesIO(audio_file))
+    elif audio_file.startswith("data:"):
+        audio_file = audio_file.split(",")[1]
+        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+    elif audio_file.startswith("http://") or audio_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
+        response = requests.get(audio_file, stream=True, timeout=timeout)
+        audio_file = BytesIO(response.content)
+        response.close()
+        audio, original_sr = sf.read(audio_file)
     elif isinstance(audio_file, str):
         audio, original_sr = sf.read(audio_file)
     else:

@@ -537,6 +542,30 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio


+def encode_video(video_path, frame_count_limit=None):
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
     image = image_size = None

@@ -1796,3 +1825,12 @@ def retry(
             traceback.print_exc()
             time.sleep(delay)


+def flatten_nested_list(nested_list):
+    if isinstance(nested_list, list):
+        return [
+            item for sublist in nested_list for item in flatten_nested_list(sublist)
+        ]
+    else:
+        return [nested_list]
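flatten_nested_list, added at the end of utils.py and used throughout this commit, is small enough to exercise standalone; copying it verbatim with a sample ragged input:

def flatten_nested_list(nested_list):
    if isinstance(nested_list, list):
        return [
            item for sublist in nested_list for item in flatten_nested_list(sublist)
        ]
    else:
        # a non-list leaf is wrapped so the recursion always returns a flat list
        return [nested_list]

print(flatten_nested_list([[1, 2], [3, [4, 5]], 6]))  # [1, 2, 3, 4, 5, 6]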
test/srt/test_vision_openai_server.py

@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase):
                     "content": [
                         {
                             "type": "image_url",
-                            "image_url": {
-                                "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
-                            },
+                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
                             "modalities": "multi-images",
                         },
                         {

@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase):
                 {
                     "role": "user",
                     "content": [
-                        {
-                            "type": "text",
-                            "text": prompt,
-                        },
                         {
                             "type": "audio_url",
                             "audio_url": {"url": f"{audio_file_name}"},
                         },
+                        {
+                            "type": "text",
+                            "text": prompt,
+                        },
                     ],
                 }
             ]
test/srt/test_vlm_accuracy.py

@@ -3,6 +3,7 @@
 import unittest
 from io import BytesIO
+from typing import List

 import numpy as np
 import requests

@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.managers.mm_utils import embed_mm_inputs
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 from sglang.srt.server_args import ServerArgs

@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         # sglang
         model = self.get_sglang_model()
         input_ids = inputs["input_ids"].to(self.device).flatten()

+        pixel_values = inputs["pixel_values"]
+        tgt_sizes = inputs["tgt_sizes"]
+        pixel_values_flat: List[torch.Tensor] = []
+        tgt_sizes_flat: List[torch.Tensor] = []
+        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+            # per image
+            if len(pixel_b) != len(tgt_b):
+                raise ValueError(
+                    "Inconsistent N lengths, found: "
+                    f"{len(pixel_b)} vs {len(tgt_b)}"
+                )
+            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                pixel_values_flat += [pixel_n]
+                tgt_sizes_flat += [tgt_n]
+
         sglang_output = embed_mm_inputs(
-            mm_input=MultimodalInputs(
-                pixel_values=inputs["pixel_values"][0],
-                tgt_sizes=inputs["tgt_sizes"][0],
+            mm_inputs=MultimodalInputs(
+                mm_items=[
+                    MultimodalDataItem(
+                        pixel_values=pixel_values_flat,
+                        tgt_size=tgt_sizes_flat,
+                        modality=Modality.IMAGE,
+                        pad_value=self.processor.tokenizer.unk_token_id,
+                    )
+                ]
             ),
             input_ids=input_ids,
             input_embedding=model.get_input_embeddings(),
-            mm_data_embedding_func=model.get_image_features,
+            image_data_embedding_func=model.get_image_feature,
             placeholder_token_ids=[
                 self.processor.tokenizer.unk_token_id,
             ],