sglang · Commits · 5cb552b1

Commit 5cb552b1 (unverified), authored Apr 01, 2025 by Mick, committed via GitHub Mar 31, 2025
parent c7457191

refactor: multimodal data (#4754)

Changes: 36 files in total; this page shows 16 changed files with 425 additions and 476 deletions (+425, -476).
Changed files on this page:

python/sglang/srt/models/deepseek_v2.py      +3    -0
python/sglang/srt/models/deepseek_vl2.py     +105  -104
python/sglang/srt/models/gemma3_mm.py        +14   -80
python/sglang/srt/models/llama.py            +3    -0
python/sglang/srt/models/llava.py            +31   -19
python/sglang/srt/models/llavavid.py         +16   -7
python/sglang/srt/models/minicpmo.py         +63   -147
python/sglang/srt/models/minicpmv.py         +17   -27
python/sglang/srt/models/mllama.py           +29   -14
python/sglang/srt/models/qwen2.py            +9    -6
python/sglang/srt/models/qwen2_5_vl.py       +21   -31
python/sglang/srt/models/qwen2_vl.py         +20   -21
python/sglang/srt/openai_api/adapter.py      +18   -6
python/sglang/srt/utils.py                   +40   -2
test/srt/test_vision_openai_server.py        +5    -7
test/srt/test_vlm_accuracy.py                +31   -5
python/sglang/srt/models/deepseek_v2.py

@@ -1308,6 +1308,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
python/sglang/srt/models/deepseek_vl2.py

@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
 )
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM

@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
         return x


-# todo
 class DeepseekVL2ForCausalLM(nn.Module):

     def __init__(

@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
-        input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.contains_image_inputs()
-        ):
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
-            for idx, image in enumerate(forward_batch.mm_inputs):
-                if image is None:
-                    continue
-                start_idx = extend_start_loc_cpu[idx]
-                end_idx = start_idx + extend_seq_lens_cpu[idx]
-                images_emb_mask = image.images_emb_mask.to(device="cuda")
-                image_features = self.get_image_feature(image)
-                input_embeds[start_idx:end_idx] = input_embeds[
-                    start_idx:end_idx
-                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
-
-        outputs = self.language_model.forward(
+        hs = general_mm_embed_routine(
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=input_embeds,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
         )
-        return outputs
+        return hs

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return input_ids
-
-    def get_image_feature(self, image_input: MultimodalInputs):
-        pixel_values = image_input.pixel_values.type(
-            next(self.vision.parameters()).dtype
-        ).to(device=next(self.vision.parameters()).device)
-        image_feature = self.vision.forward_features(pixel_values)
-        images_embeds = self.projector(image_feature)
-        _, hw, n_dim = images_embeds.shape
-        h = w = int(hw**0.5)
-        tile_index = 0
-        images_in_this_batch = []
-        images_spatial_crop = image_input.image_spatial_crop
-        for jdx in range(images_spatial_crop.shape[1]):
-            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
-            if num_width_tiles == 0 or num_height_tiles == 0:
-                break
-            num_tiles_in_image = num_width_tiles * num_height_tiles
-
-            # [hw, D]
-            global_features = images_embeds[tile_index]
-
-            # [num_height_tiles * num_width_tiles, hw, D]
-            local_features = images_embeds[
-                tile_index + 1 : tile_index + 1 + num_tiles_in_image
-            ]
-            tile_index += num_tiles_in_image + 1
-
-            # format global and local features
-            # ----------------- global view add newline -----------------
-            # [hw, D] -> [h, w, D]
-            global_features = global_features.view(h, w, n_dim)
-            # [D] -> [h, 1, D]
-            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
-            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
-            global_features = torch.cat([global_features, new_lines_in_global], dim=1)
-            # [h, w + 1, D] -> [h * (w + 1), D]
-            global_features = global_features.view(-1, n_dim)
-
-            # ----------------- local view add newline -----------------
-            # [num_height_tiles * num_width_tiles, h * w, D] ->
-            # [num_height_tiles * h, num_width_tiles * w, D]
-            local_features = rearrange(
-                local_features,
-                "(th tw) (h w) d -> (th h) (tw w) d",
-                th=num_height_tiles,
-                tw=num_width_tiles,
-                h=h,
-                w=w,
-            )
-
-            # [D] -> [num_height_tiles * h, 1, D]
-            new_lines_in_local = repeat(
-                self.image_newline,
-                "d -> (th h) 1 d",
-                th=num_height_tiles,
-                h=h,
-            )
-
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            local_features = torch.cat([local_features, new_lines_in_local], dim=1)
-
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
-            local_features = local_features.view(-1, n_dim)
-
-            # merge global and local tiles
-            if self.global_view_pos == "head":
-                global_local_features = torch.cat(
-                    [
-                        global_features,
-                        self.view_seperator[None, :],
-                        local_features,
-                    ]
-                )
-            else:
-                global_local_features = torch.cat(
-                    [
-                        local_features,
-                        self.view_seperator[None, :],
-                        global_features,
-                    ]
-                )
-
-            images_in_this_batch.append(global_local_features)
+        helper = MultiModalityDataPaddingPatternImageTokens(
+            image_token_id=image_inputs.im_token_id
+        )
+        return helper.pad_input_tokens(input_ids, image_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+        images_spatial_crop = torch.cat(
+            [item.image_spatial_crop for item in items], dim=0
+        )
+        assert images_spatial_crop.dim() == 3
+
+        # TODO: can it be batched ?
+        images_in_this_batch = []
+        for item in items:
+            assert item.pixel_values.dim() == 4
+            image_feature = self.vision.forward_features(
+                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                    device=next(self.vision.parameters()).device
+                )
+            )
+            images_embeds = self.projector(image_feature)
+            _, hw, n_dim = images_embeds.shape
+            h = w = int(hw**0.5)
+            tile_index = 0
+            for jdx in range(item.image_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+                if num_width_tiles == 0 or num_height_tiles == 0:
+                    break
+                num_tiles_in_image = num_width_tiles * num_height_tiles
+
+                # [hw, D]
+                global_features = images_embeds[tile_index]
+
+                # [num_height_tiles * num_width_tiles, hw, D]
+                local_features = images_embeds[
+                    tile_index + 1 : tile_index + 1 + num_tiles_in_image
+                ]
+                tile_index += num_tiles_in_image + 1
+
+                # format global and local features
+                # ----------------- global view add newline -----------------
+                # [hw, D] -> [h, w, D]
+                global_features = global_features.view(h, w, n_dim)
+                # [D] -> [h, 1, D]
+                new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
+                # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
+                global_features = torch.cat(
+                    [global_features, new_lines_in_global], dim=1
+                )
+                # [h, w + 1, D] -> [h * (w + 1), D]
+                global_features = global_features.view(-1, n_dim)

+                # ----------------- local view add newline -----------------
+                # [num_height_tiles * num_width_tiles, h * w, D] ->
+                # [num_height_tiles * h, num_width_tiles * w, D]
+                local_features = rearrange(
+                    local_features,
+                    "(th tw) (h w) d -> (th h) (tw w) d",
+                    th=num_height_tiles,
+                    tw=num_width_tiles,
+                    h=h,
+                    w=w,
+                )
+
+                # [D] -> [num_height_tiles * h, 1, D]
+                new_lines_in_local = repeat(
+                    self.image_newline,
+                    "d -> (th h) 1 d",
+                    th=num_height_tiles,
+                    h=h,
+                )
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                local_features = torch.cat(
+                    [local_features, new_lines_in_local], dim=1
+                )
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
+                local_features = local_features.view(-1, n_dim)
+
+                # merge global and local tiles
+                if self.global_view_pos == "head":
+                    global_local_features = torch.cat(
+                        [
+                            global_features,
+                            self.view_seperator[None, :],
+                            local_features,
+                        ]
+                    )
+                else:
+                    global_local_features = torch.cat(
+                        [
+                            local_features,
+                            self.view_seperator[None, :],
+                            global_features,
+                        ]
+                    )
+
+                images_in_this_batch.append(global_local_features)

         return torch.cat(images_in_this_batch, dim=0)
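Reviewer note: the per-batch loop deleted above and the new per-item get_image_feature implement the same DeepSeek-VL2 tile layout — a low-resolution global view plus a grid of local tiles, with a learned newline embedding closing each row and a separator joining the two views. The standalone sketch below replays that assembly on random tensors; the sizes and the newline/separator weights are made up for illustration and are not the model's real parameters.

    # Toy replay of the tile-assembly logic (not the sglang module itself).
    import torch
    from einops import rearrange, repeat

    h = w = 3                                    # int(hw ** 0.5) of the projector output
    n_dim = 8                                    # embedding dim D
    num_height_tiles, num_width_tiles = 2, 2
    num_tiles = num_height_tiles * num_width_tiles

    images_embeds = torch.randn(1 + num_tiles, h * w, n_dim)   # [global + tiles, hw, D]
    image_newline = torch.randn(n_dim)                         # stand-in learned newline token
    view_seperator = torch.randn(n_dim)                        # stand-in learned separator token

    # global view: [hw, D] -> [h, w, D], append a newline column, flatten
    global_features = images_embeds[0].view(h, w, n_dim)
    new_lines_in_global = repeat(image_newline, "d -> h 1 d", h=h)
    global_features = torch.cat([global_features, new_lines_in_global], dim=1).view(-1, n_dim)

    # local view: stitch the tile grid back together, then append newlines per row
    local_features = rearrange(
        images_embeds[1:],
        "(th tw) (h w) d -> (th h) (tw w) d",
        th=num_height_tiles, tw=num_width_tiles, h=h, w=w,
    )
    new_lines_in_local = repeat(image_newline, "d -> (th h) 1 d", th=num_height_tiles, h=h)
    local_features = torch.cat([local_features, new_lines_in_local], dim=1).reshape(-1, n_dim)

    # global view first ("head"), separated from the local tiles
    global_local = torch.cat([global_features, view_seperator[None, :], local_features])
    print(global_local.shape)   # [h*(w+1) + 1 + th*h*(tw*w+1), D] = [55, 8]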
python/sglang/srt/models/gemma3_mm.py

@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 import torch
 from torch import nn
-from transformers import (
-    AutoModel,
-    BatchFeature,
-    Gemma3Config,
-    Gemma3Processor,
-    PreTrainedModel,
-)
-from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+from transformers import AutoModel, Gemma3Config, PreTrainedModel

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm

@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         """
         return self.language_model.get_attention_sliding_window_size()

-    def get_image_feature(self, image_input: MultimodalInputs):
+    def get_image_feature(self, items: List[MultimodalDataItem]):
         """
         Projects the last hidden state from the vision model into language model space.

         Args:
             pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                The tensors corresponding to the input images.
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = image_input.pixel_values
+        pixel_values = torch.stack(
+            flatten_nested_list([item.pixel_values for item in items]), dim=0
+        )
+
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features

-    def embed_mm_inputs(
-        self,
-        input_ids: torch.Tensor,
-        forward_batch: ForwardBatch,
-        image_input: MultimodalInputs,
-    ) -> torch.Tensor:
-        if input_ids is None:
-            raise ValueError("Unimplemented")
-        # boolean-masking image tokens
-        special_image_mask = torch.isin(
-            input_ids,
-            torch.tensor(image_input.pad_values, device=input_ids.device),
-        ).unsqueeze(-1)
-        num_image_tokens_in_input_ids = special_image_mask.sum()
-        inputs_embeds = None
-        if num_image_tokens_in_input_ids == 0:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-            return inputs_embeds
-        else:
-            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_feature(image_input.pixel_values)
-
-            # print(f"image tokens from image embeddings: {image_features.numel()}")
-            num_image_tokens_in_embedding = (
-                image_features.shape[0] * image_features.shape[1]
-            )
-
-            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
-                image_features = image_features[:num_image, :]
-                logger.warning(
-                    f"Number of images does not match number of special image tokens in the input text. "
-                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                    "tokens from image embeddings."
-                )
-
-            # Important: clamp after extracting original image boundaries
-            input_ids.clamp_(min=0, max=self.vocab_size - 1)
-
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-                inputs_embeds.device
-            )
-
-            image_features = image_features.to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                special_image_mask, image_features
-            )
-
-        return inputs_embeds
-
     @torch.no_grad()
     def forward(
         self,

@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids

-        inputs_embeds = general_mm_embed_routine(
+        hs = general_mm_embed_routine(
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        outputs = self.language_model(
-            input_ids=None,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
             **kwargs,
         )

-        return outputs
+        return hs

     def tie_weights(self):
         return self.language_model.tie_weights()
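Reviewer note: the embed_mm_inputs method deleted from Gemma3 above is the pattern that general_mm_embed_routine now provides for every model — mask the image-placeholder tokens, clamp the ids into the vocabulary, embed the text, and masked_scatter the projected image features into the placeholder positions. A minimal sketch of that merge step, assuming a single flat sequence and ignoring the routine's audio and caching handling:

    import torch
    import torch.nn as nn

    def merge_image_embeddings(input_ids, pad_values, embed_tokens, image_features):
        # input_ids: [seq_len]; pad_values: placeholder ids written by pad_input_ids
        # image_features: [num_images, tokens_per_image, hidden]
        special_image_mask = torch.isin(
            input_ids, torch.tensor(pad_values, device=input_ids.device)
        ).unsqueeze(-1)
        # placeholder ids are often large hashes, so clamp before embedding
        safe_ids = input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1)
        inputs_embeds = embed_tokens(safe_ids)
        if special_image_mask.sum() == 0:
            return inputs_embeds
        mask = special_image_mask.expand_as(inputs_embeds)
        flat_features = image_features.to(inputs_embeds.dtype).reshape(-1, inputs_embeds.shape[-1])
        return inputs_embeds.masked_scatter(mask, flat_features)

    # toy usage: three placeholder tokens (id 7) filled by one 3-token image
    emb = nn.Embedding(100, 16)
    ids = torch.tensor([1, 2, 7, 7, 7, 3])
    out = merge_image_embeddings(ids, [7], emb, torch.randn(1, 3, 16))
    print(out.shape)   # torch.Size([6, 16])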
python/sglang/srt/models/llama.py

@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
python/sglang/srt/models/llava.py

@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,

@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+        image_sizes = flatten_nested_list(
+            [item.image_sizes for item in image_inputs.mm_items]
+        )
+        pad_values = [item.pad_value for item in image_inputs.mm_items]

         # hardcode for spatial_unpad + anyres
-        if image_inputs.modalities is not None and (
-            "multi-images" in image_inputs.modalities
-            or "video" in image_inputs.modalities
+        if any(
+            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
+            for item in image_inputs.mm_items
         ):
             image_aspect_ratio = "pad"
         else:

@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 math.ceil(self.image_size / self.patch_size / 2) ** 2
             )
         else:
-            new_image_feature_len = self.image_feature_len  # multiimage
+            new_image_feature_len = self.image_feature_len  # multi-image

         height = width = self.num_patches_per_side
         if "anyres" in image_aspect_ratio:

@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 # old_len + pad_len - 1, because we need to remove image_token_id
                 input_ids = (
                     input_ids[:offset]
-                    + [pad_values[image_idx]] * new_image_feature_len
+                    + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                     + input_ids[offset + 1 :]
                 )
                 offset_list.append(offset)

@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
         modalities_list = []
         max_image_offset = []
         for im in image_inputs:
-            if im and im.modalities is not None:
-                modalities_list.extend(im.modalities)
+            if im:
+                modalities_list.extend([item.modality for item in im.mm_items])
             if im and im.image_offsets:
                 max_image_offset.append(
                     np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))

@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
         if need_vision.any():
             bs = forward_batch.batch_size
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_sizes = [
-                image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                flatten_nested_list(
+                    [item.image_sizes for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]

             ########## Encode Image ########

@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_image_features = []
                 height = width = self.num_patches_per_side
                 for image_idx, image_feature in enumerate(image_features):
-                    if modalities_list[image_idx] == "image":
+                    if modalities_list[image_idx] == Modality.IMAGE:
                         image_aspect_ratio = (
                             self.config.image_aspect_ratio
                         )  # single image
                     elif (
-                        modalities_list[image_idx] == "multi-images"
-                        or modalities_list[image_idx] == "video"
+                        modalities_list[image_idx] == Modality.MULTI_IMAGES
+                        or modalities_list[image_idx] == Modality.VIDEO
                     ):
                         image_aspect_ratio = "pad"  # multi image
                         # image_aspect_ratio = (

@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     if (
                         image_feature.shape[0] > 1
                         and "anyres" in image_aspect_ratio
-                        and modalities_list[image_idx] == "image"
+                        and modalities_list[image_idx] == Modality.IMAGE
                     ):
                         base_image_feature = image_feature[0]
                         image_feature = image_feature[1:]

@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
                         )
                         image_feature = image_feature.unsqueeze(0)
                     else:
-                        if modalities_list[image_idx] == "video":  # video
+                        if modalities_list[image_idx] == Modality.VIDEO:  # video
                             # 2x2 pooling
                             num_of_frames = image_feature.shape[0]
                             image_feature = image_feature.view(
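Reviewer note: llava.py (and llavavid.py below) now key the aspect-ratio decision off a per-item Modality enum instead of string matching on MultimodalInputs.modalities. A hedged sketch of how that check reads, with a stand-in enum and item class (the real ones live in sglang.srt.managers.schedule_batch and may carry more members and fields):

    from dataclasses import dataclass
    from enum import Enum, auto

    class Modality(Enum):            # stand-in for the real enum
        IMAGE = auto()
        MULTI_IMAGES = auto()
        VIDEO = auto()

    @dataclass
    class FakeItem:                  # stand-in for MultimodalDataItem
        modality: Modality

    mm_items = [FakeItem(Modality.IMAGE), FakeItem(Modality.VIDEO)]

    image_aspect_ratio = (
        "pad"
        if any(item.modality in (Modality.MULTI_IMAGES, Modality.VIDEO) for item in mm_items)
        else "anyres"                # placeholder for self.config.image_aspect_ratio
    )
    print(image_aspect_ratio)        # pad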
python/sglang/srt/models/llavavid.py

@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM

@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
         )

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pad_values = image_inputs.pad_values
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
         new_image_feature_len = self.image_feature_len

         pad_ids = pad_values * (

@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
         need_vision = start_positions <= np.array(max_image_offset)

         if need_vision.any():
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_offsets = [
-                image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                flatten_nested_list(
+                    [item.image_offsets for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]

             ########## Encode Image ########

@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.vision_tower.vision_tower": "vision_tower",
+            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
             "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
python/sglang/srt/models/minicpmo.py

@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
-    embed_mm_inputs,
-    get_multimodal_data_bounds,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.minicpmv import (
     Idefics2VisionTransformer,
-    MiniCPMVBaseModel,
+    MiniCPMBaseModel,
     Resampler2_5,
 )
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
         return hidden_states


-class MiniCPMO(MiniCPMVBaseModel):
+class MiniCPMO(MiniCPMBaseModel):
     def __init__(
         self,
         config: PretrainedConfig,

@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):
         return input_lengths_after_cnn, input_lengths_after_pooling

-    def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
+    def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
         r"""
         Extract audio embeddings in a streaming manner using cached key-value pairs.

@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
         for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
         for streaming scenarios.

-        Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
-
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
-        )
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
+        )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
-        )
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
+        )

         # exist audio

@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
                 ret[i, start:ending] = True
         return ret

-    def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
+    def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
         r"""
         Extract full audio embeddings with optional chunk-based attention.

@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
         not use key-value caching and is suitable for non-streaming inference.

         Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
             chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
                 attention (>0) during embedding computation.

         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
-        # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
-        )
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
+        )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
-        )
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
+        )

         final_audio_embeds = []

+        assert isinstance(wavforms, list)
+        assert isinstance(wavforms[0], torch.Tensor)
+
         # exist audio
         for wavform in wavforms:
             if len(wavform) > 0:

@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
             final_audio_embeds.append(target_audio_embeds)
         return final_audio_embeds

+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        embedding = self.get_omni_embedding(
+            items=items,
+            chunk_length=self.config.audio_chunk_length,
+            stream_input=False,
+        )
+        return embedding
+
     def get_omni_embedding(
         self,
-        input_ids,
-        multimodal_input: MultimodalInputs,
-        input_embeds: torch.Tensor,
-        forward_mode: ForwardMode,
+        items: List[MultimodalDataItem],
         chunk_length=-1,
         stream_input=False,
     ):
         """
         Args:
-            multimodal_input:
-            input_embeds:
             chunk_length: whisper use full attention or chunk attention
             stream_input: use streaming audio embedding
         Returns:
             final embeddings with audio feature
         """
-        input_embeds = input_embeds.unsqueeze(0)
-        if (
-            not forward_mode.is_decode()
-            and multimodal_input.contains_audio_inputs()
-        ):
-            audio_bounds = get_multimodal_data_bounds(
-                input_ids=input_ids,
-                pad_values=multimodal_input.pad_values,
-                token_pairs=[
-                    (multimodal_input.audio_start_id, multimodal_input.audio_end_id)
-                ],
-            )
-            if audio_bounds.numel() == 0:
-                input_embeds = input_embeds.squeeze(0)
-                # TODO
-                logger.warn(
-                    "Unimplemented logic. Please try disabling chunked prefill"
-                )
-                return input_embeds
-            audio_bounds = audio_bounds.unsqueeze(0)
-            bs = len(input_embeds)
-
-            if stream_input:
-                audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
-            else:
-                audio_embeddings = self.get_audio_embedding(
-                    multimodal_input, chunk_length
-                )
-            # batch size
-            assert len(audio_embeddings) == len(input_embeds)
-            if len(audio_embeddings) > 0:
-                if self.config.chunk_input:
-                    for i in range(bs):
-                        audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
-                            device=input_embeds.device, dtype=input_embeds.dtype
-                        )
-                        audio_start_pos = 0
-                        for bound in audio_bounds[i]:
-                            audio_len = bound[1] - bound[0] + 1
-                            input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
-                                audio_start_pos : audio_start_pos + audio_len, :
-                            ]
-                            audio_start_pos += audio_len
-                else:
-                    for i in range(bs):
-                        audio_embs = audio_embeddings[i]
-                        bounds = audio_bounds[i]
-                        for embs, bound in zip(audio_embs, bounds):
-                            audio_indices = torch.arange(
-                                bound[0], bound[1], dtype=torch.long
-                            ).to(input_embeds.device)
-
-                            if embs.shape[0] != len(audio_indices):
-                                raise ValueError(
-                                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
-                                    f"to input indices of length {len(audio_indices)}"
-                                )
-                            input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
-        input_embeds = input_embeds.squeeze(0)
-        return input_embeds
-
-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
-        pixel_values = image_inputs.pixel_values
-        tgt_sizes = image_inputs.tgt_sizes
+        if stream_input:
+            audio_embeddings = self.get_audio_embedding_streaming(items)
+        else:
+            audio_embeddings = self.get_audio_embedding(items, chunk_length)
+        bs = len(audio_embeddings)  # batch size
+        audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
+
+        return audio_embs
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]

         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
         all_pixel_values_lst = [

@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):
         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)

         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )

         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
         patch_attn_mask = torch.zeros(

@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = None
-        # TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and forward_batch.contains_image_inputs()
-        ):
-            mm_inputs = forward_batch.merge_mm_inputs()
-            inputs_embeds = embed_mm_inputs(
-                mm_input=mm_inputs,
-                input_ids=input_ids,
-                input_embedding=self.get_input_embeddings(),
-                mm_data_embedding_func=self.get_image_features,
-                placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
-            )
-            input_ids = input_ids.clamp(
-                min=0, max=self.get_input_embeddings().num_embeddings - 1
-            )
-
-        if inputs_embeds is None:
-            inputs_embeds = self.llm.get_input_embeddings(input_ids)
-
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and self.config.init_audio
-            and forward_batch.contains_audio_inputs()
-        ):
-            mm_input = forward_batch.merge_mm_inputs()
-            inputs_embeds = self.get_omni_embedding(
-                input_ids=input_ids,
-                multimodal_input=mm_input,
-                input_embeds=inputs_embeds,
-                forward_mode=forward_batch.forward_mode,
-                chunk_length=self.config.audio_chunk_length,
-                stream_input=False,
-            )
-
-        forward_batch.mm_inputs = None
-
-        hidden_states = self.llm.model(
-            input_ids=None,
-            positions=positions,
+        mm_input = forward_batch.merge_mm_inputs()
+        placeholder_token_ids = (
+            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
+            if forward_batch.contains_mm_inputs()
+            else []
+        )
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
             forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-        )
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            audio_data_embedding_func=self.get_audio_feature,
+            placeholder_token_ids=placeholder_token_ids,
+            positions=positions,
         )

+        return hidden_states
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
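Reviewer note: after this change the MiniCPM-o audio path no longer scatters embeddings into input_embeds itself; get_omni_embedding just flattens the per-item, per-segment embeddings and concatenates them, and general_mm_embed_routine places the result. A small self-contained sketch of that flatten-and-concat step (shapes are made up):

    import torch

    def flatten_nested_list(nested_list):          # same definition as added in utils.py below
        if isinstance(nested_list, list):
            return [x for sub in nested_list for x in flatten_nested_list(sub)]
        return [nested_list]

    # two requests: one with two audio segments, one with a single segment
    audio_embeddings = [
        [torch.randn(5, 16), torch.randn(3, 16)],
        [torch.randn(4, 16)],
    ]
    audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
    print(audio_embs.shape)   # torch.Size([12, 16]); later placed by general_mm_embed_routine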
python/sglang/srt/models/minicpmv.py

@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list

 RawImageType = Union[Image.Image, torch.Tensor]

@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     return tuple(int(x) for x in version_str.split("."))


-class MiniCPMVBaseModel(nn.Module):
+class MiniCPMBaseModel(nn.Module):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.

@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
         return vlm_embedding, vision_hidden_states

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.llm.get_input_embedding()
+        return self.llm.get_input_embeddings()

     def forward(
         self,

@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_features,
-        )
-
-        hidden_states = self.llm.model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.llm,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
-        )
+        return hidden_states

     def init_llm(
         self,

@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         raise NotImplementedError


-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMBaseModel):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding

-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
-        pixel_values = image_inputs.pixel_values
-        tgt_sizes = image_inputs.tgt_sizes
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]

         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype

@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)

         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )

         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
         patch_attn_mask = torch.zeros(
python/sglang/srt/models/mllama.py

@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pixel_values = image_inputs.pixel_values
-        pad_values = image_inputs.pad_values
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pixel_values = torch.cat(
+            [item.pixel_values for item in mm_inputs.mm_items], dim=0
+        )
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
         num_patches = self.vision_model.num_patches
         image_len = num_concurrent_media * num_tiles * num_patches
-        image_inputs.num_image_tokens = image_len
+        mm_inputs.num_image_tokens = image_len

         pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):
         # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
         max_num_images = max_num_tiles = bs = 0
-        for i, im in enumerate(forward_batch.mm_inputs):
-            if not forward_batch.encoder_cached[i] and im is not None:
-                max_num_images = max(max_num_images, im.pixel_values.shape[1])
-                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
+        for i, mm_input in enumerate(forward_batch.mm_inputs):
+            if not forward_batch.encoder_cached[i] and mm_input is not None:
+                pixel_values = torch.cat(
+                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                )
+                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
+                max_num_images = max(max_num_images, pixel_values.shape[1])
+
+                max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
                 bs += 1

         if max_num_images * max_num_tiles * bs == 0:

@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
         )
         i = 0
         encoder_lens_need = []
-        for k, im in enumerate(forward_batch.mm_inputs):
-            if forward_batch.encoder_cached[k] or im is None:
+
+        for k, mm_input in enumerate(forward_batch.mm_inputs):
+            if forward_batch.encoder_cached[k] or mm_input is None:
                 continue

             encoder_lens_need.append(forward_batch.encoder_lens[k])
-            for j in range(im.pixel_values.shape[1]):
-                img = im.pixel_values[0, j]
+            pixel_values = torch.cat(
+                [item.pixel_values for item in mm_input.mm_items], dim=0
+            )
+            for j in range(pixel_values.shape[1]):
+                img = pixel_values[0, j]
                 num_tiles = img.shape[0]
                 batched_images[i, j, :num_tiles] = img
-                batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
-                batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
+                batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
+
+                batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[0].aspect_ratio_mask[
+                    0, j
+                ]
             i += 1

         return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
python/sglang/srt/models/qwen2.py

@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
-            return self.embed_tokens(input_ids) * self.config.scale_emb
+            return self.get_input_embeddings()(input_ids) * self.config.scale_emb
         else:
-            return self.embed_tokens(input_ids)
+            return self.get_input_embeddings()(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens

     def forward(
         self,

@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)

-    def get_input_embedding(self) -> nn.Embedding:
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

     @torch.no_grad()
python/sglang/srt/models/qwen2_5_vl.py

@@ -30,22 +30,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoModel, Qwen2VLConfig
+from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
-from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLForConditionalGeneration,
-)
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear

@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model

@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = image_inputs.im_start_id
-        im_end_id: int = image_inputs.im_end_id
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id

         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:

@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 f"(3, seq_len) positions, but got {positions.size()}"
             )

-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

         if not get_embedding:

@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]
python/sglang/srt/models/qwen2_vl.py

@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model

@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     # Use grid_t * grid_w * grid_h to pad tokens for each image
     # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = multi_modal_inputs.im_start_id
-        im_end_id: int = multi_modal_inputs.im_end_id
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id

         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-        return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:

@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 "multimodal section rotary embedding requires "
                 f"(3, seq_len) positions, but got {positions.size()}"
             )

-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )

-        if not get_embedding:
+        if get_embedding:
+            return self.pooler(hidden_states, forward_batch)
+        else:
             return self.logits_processor(
                 input_ids, hidden_states, self.lm_head, forward_batch
             )
-        else:
-            return self.pooler(hidden_states, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
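Reviewer note: for the Qwen-VL family, get_image_feature now receives a list of MultimodalDataItem and simply concatenates along dim 0, relying on every item carrying 2-D pixel_values (flattened patches) and an [n, 3] image_grid_thws. A sketch with stand-in items; the patch dimension below is illustrative only:

    import torch

    class FakeItem:                                  # stand-in for MultimodalDataItem
        def __init__(self, n_patches, n_images):
            self.pixel_values = torch.randn(n_patches, 1176)            # [patches, patch_dim]
            self.image_grid_thws = torch.tensor([[1, 4, 4]] * n_images)  # [n_images, 3]

    items = [FakeItem(16, 1), FakeItem(32, 2)]
    pixel_values = torch.cat([item.pixel_values for item in items], dim=0)
    image_grid_thws = torch.cat([item.image_grid_thws for item in items], dim=0)
    assert pixel_values.dim() == 2 and image_grid_thws.dim() == 2
    print(pixel_values.shape, image_grid_thws.shape)   # [48, 1176] and [3, 3]
    # both would then go to self.visual(pixel_values, grid_thw=image_grid_thws)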
python/sglang/srt/openai_api/adapter.py

@@ -897,6 +897,7 @@ def v1_chat_generate_request(
     request_ids: List[str] = None,
 ):
     input_ids = []
+    prompts = []
     sampling_params_list = []
     image_data_list = []
     audio_data_list = []

@@ -916,6 +917,7 @@ def v1_chat_generate_request(
         #   - audio_data: None or a list of audio strings (URLs).
         # None skips any image processing in GenerateReqInput.
         strict_tag = None
+        prompt = ""
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             tools = None

@@ -1005,11 +1007,13 @@ def v1_chat_generate_request(
             image_data = None
             audio_data = None
             modalities = []
+            prompt = request.messages
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs or 0)
         lora_paths.append(request.lora_path)
+        prompts.append(prompt)

         sampling_params = {
             "temperature": request.temperature,

@@ -1063,10 +1067,14 @@ def v1_chat_generate_request(
         audio_data_list.append(audio_data)
         modalities_list.append(modalities)
     if len(all_requests) == 1:
-        if isinstance(input_ids[0], str):
-            prompt_kwargs = {"text": input_ids[0]}
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts[0]}
         else:
-            prompt_kwargs = {"input_ids": input_ids[0]}
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids[0]}
+            else:
+                prompt_kwargs = {"input_ids": input_ids[0]}
         sampling_params_list = sampling_params_list[0]
         image_data_list = image_data_list[0]
         audio_data_list = audio_data_list[0]

@@ -1076,10 +1084,14 @@ def v1_chat_generate_request(
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
     else:
-        if isinstance(input_ids[0], str):
-            prompt_kwargs = {"text": input_ids}
+        if tokenizer_manager.model_config.is_multimodal:
+            # processor will need text input
+            prompt_kwargs = {"text": prompts}
         else:
-            prompt_kwargs = {"input_ids": input_ids}
+            if isinstance(input_ids[0], str):
+                prompt_kwargs = {"text": input_ids}
+            else:
+                prompt_kwargs = {"input_ids": input_ids}

     adapted_request = GenerateReqInput(
         **prompt_kwargs,
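Reviewer note: the adapter change keeps the templated prompt string alongside the token ids so that multimodal models can hand raw text to their processor. A hedged sketch of the selection logic for the single-request branch, with the function name invented for illustration:

    def select_prompt_kwargs(is_multimodal, prompts, input_ids):
        if is_multimodal:
            # the processor re-tokenizes the text together with the image tokens
            return {"text": prompts[0]}
        if isinstance(input_ids[0], str):
            return {"text": input_ids[0]}
        return {"input_ids": input_ids[0]}

    print(select_prompt_kwargs(True, ["<image> describe this"], [[1, 2, 3]]))
    # {'text': '<image> describe this'}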
python/sglang/srt/utils.py

@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Common utilities."""
-
 import base64
 import builtins
 import ctypes

@@ -54,6 +53,7 @@ import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
+from decord import VideoReader, cpu
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from PIL import Image

@@ -513,13 +513,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     import soundfile as sf
     from scipy.signal import resample

-    # print(f"loading {audio_file}")
     # Load audio data
     if isinstance(audio_file, bytes):
         audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
         audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+    elif audio_file.startswith("http://") or audio_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
+        response = requests.get(audio_file, stream=True, timeout=timeout)
+        audio_file = BytesIO(response.content)
+        response.close()
+        audio, original_sr = sf.read(audio_file)
     elif isinstance(audio_file, str):
         audio, original_sr = sf.read(audio_file)
     else:

@@ -537,6 +542,30 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio


+def encode_video(video_path, frame_count_limit=None):
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
     image = image_size = None

@@ -1796,3 +1825,12 @@ def retry(
             traceback.print_exc()

             time.sleep(delay)
+
+
+def flatten_nested_list(nested_list):
+    if isinstance(nested_list, list):
+        return [
+            item for sublist in nested_list for item in flatten_nested_list(sublist)
+        ]
+    else:
+        return [nested_list]
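Reviewer note: flatten_nested_list is the workhorse of this refactor — it collapses the per-request / per-item / per-image nesting into a flat list before tensors are stacked or concatenated, recursing into lists and wrapping any non-list leaf. Quick check (requires sglang installed; otherwise paste the nine-line definition from the hunk above):

    from sglang.srt.utils import flatten_nested_list

    assert flatten_nested_list([[1, 2], [3, [4, [5]]], 6]) == [1, 2, 3, 4, 5, 6]
    assert flatten_nested_list(7) == [7]
    assert flatten_nested_list([]) == []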
test/srt/test_vision_openai_server.py

@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase):
                     "content": [
                         {
                             "type": "image_url",
-                            "image_url": {
-                                "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
-                            },
+                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
                             "modalities": "multi-images",
                         },
                         {

@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase):
                 {
                     "role": "user",
                     "content": [
-                        {
-                            "type": "text",
-                            "text": prompt,
-                        },
                         {
                             "type": "audio_url",
                             "audio_url": {"url": f"{audio_file_name}"},
                         },
+                        {
+                            "type": "text",
+                            "text": prompt,
+                        },
                     ],
                 }
             ]
test/srt/test_vlm_accuracy.py

@@ -3,6 +3,7 @@
 import unittest
 from io import BytesIO
+from typing import List

 import numpy as np
 import requests

@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.managers.mm_utils import embed_mm_inputs
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 from sglang.srt.server_args import ServerArgs

@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         # sglang
         model = self.get_sglang_model()
         input_ids = inputs["input_ids"].to(self.device).flatten()

+        pixel_values = inputs["pixel_values"]
+        tgt_sizes = inputs["tgt_sizes"]
+        pixel_values_flat: List[torch.Tensor] = []
+        tgt_sizes_flat: List[torch.Tensor] = []
+        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+            # per image
+            if len(pixel_b) != len(tgt_b):
+                raise ValueError(
+                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
+                )
+            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                pixel_values_flat += [pixel_n]
+                tgt_sizes_flat += [tgt_n]
+
         sglang_output = embed_mm_inputs(
-            mm_input=MultimodalInputs(
-                pixel_values=inputs["pixel_values"][0],
-                tgt_sizes=inputs["tgt_sizes"][0],
+            mm_inputs=MultimodalInputs(
+                mm_items=[
+                    MultimodalDataItem(
+                        pixel_values=pixel_values_flat,
+                        tgt_size=tgt_sizes_flat,
+                        modality=Modality.IMAGE,
+                        pad_value=self.processor.tokenizer.unk_token_id,
+                    )
+                ]
             ),
             input_ids=input_ids,
             input_embedding=model.get_input_embeddings(),
-            mm_data_embedding_func=model.get_image_features,
+            image_data_embedding_func=model.get_image_feature,
             placeholder_token_ids=[
                 self.processor.tokenizer.unk_token_id,
             ],