Unverified commit 11577ced, authored Mar 23, 2025 by Mick, committed by GitHub on Mar 22, 2025
refactor: bug fixes and refactor for vlm (#4661)
parent ca75741e

Changes: 30
Showing 10 changed files with 231 additions and 481 deletions (+231, -481)
python/sglang/srt/models/deepseek_janus_pro.py     +27   -83
python/sglang/srt/models/deepseek_vl2.py           +17   -50
python/sglang/srt/models/gemma3_causal.py           +5    -8
python/sglang/srt/models/gemma3_mm.py              +13   -18
python/sglang/srt/models/minicpmv.py               +68  -131
python/sglang/srt/models/qwen2.py                   +3    -0
python/sglang/srt/models/qwen2_5_vl.py             +30   -95
python/sglang/srt/models/qwen2_vl.py               +21   -58
test/srt/test_vision_openai_server.py              +29   -24
test/srt/test_vlm_accuracy.py                      +18   -14
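The thread running through the model files below is that each VLM now exposes a get_input_embeddings() / get_image_feature(image_input) pair and delegates embedding construction to helpers in sglang.srt.managers.mm_utils (general_mm_embed_routine or embed_image_inputs). Those helpers' bodies are not part of this commit, so the toy below is only a sketch of the job they have to do — embed the text ids, then overwrite the placeholder positions with vision features — not sglang's implementation:

import torch
import torch.nn as nn

# Toy illustration (hypothetical, not sglang's mm_utils): embed text ids, then
# replace the positions holding image-placeholder ids with vision embeddings.
def toy_mm_embed(input_ids, placeholder_ids, embed_tokens, image_embeds):
    is_image = torch.isin(input_ids, placeholder_ids)              # [T] bool mask
    safe_ids = input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1)
    embeds = embed_tokens(safe_ids)                                 # [T, D]
    embeds[is_image] = image_embeds.to(embeds.dtype)                # [num_image_tokens, D]
    return embeds

embed_tokens = nn.Embedding(100, 8)
input_ids = torch.tensor([5, 7, 90, 90, 90, 9])                     # 90 marks image tokens
image_embeds = torch.randn(3, 8)
out = toy_mm_embed(input_ids, torch.tensor([90]), embed_tokens, image_embeds)
print(out.shape)  # torch.Size([6, 8])

The per-model diffs that follow mostly delete hand-rolled versions of exactly this splicing logic.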
python/sglang/srt/models/deepseek_janus_pro.py (view file @ 11577ced)

@@ -47,8 +47,9 @@ from sglang.srt.configs.janus_pro import *
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def prepare_images_seq_mask(
-        self, input_ids: torch.Tensor, image_inputs: ImageInputs
-    ) -> Optional[torch.LongTensor]:
-        images_seq_mask = torch.isin(
-            input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
-        )
-        if images_seq_mask.sum() == 0:
-            # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
-            return None
-        else:
-            return images_seq_mask
+    def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
+        )
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+        return images_embeds
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.model.embed_tokens

     @torch.no_grad()
     def forward(

@@ -1978,86 +1986,22 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        inputs_embeds = None
-        if (
-            forward_batch.image_inputs is not None
-            and len(forward_batch.image_inputs) != 0
-            and forward_batch.image_inputs[0] is not None
-        ):
-            image_inputs = forward_batch.image_inputs[0]
-            images_seq_mask = self.prepare_images_seq_mask(
-                input_ids=input_ids, image_inputs=image_inputs
-            )
-            if images_seq_mask is not None:
-                input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-                inputs_embeds = self.prepare_inputs_embeds(
-                    input_ids=input_ids,
-                    pixel_values=image_inputs.pixel_values,
-                    images_seq_mask=images_seq_mask,
-                    images_emb_mask=image_inputs.images_emb_mask,
-                )
-                input_ids = None
-
-        if input_ids is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         return self.language_model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
             get_embedding=False,
         )

-    def prepare_inputs_embeds(
-        self,
-        input_ids: torch.LongTensor,
-        pixel_values: torch.FloatTensor,
-        images_seq_mask: torch.LongTensor,
-        images_emb_mask: torch.BoolTensor,
-        **_kwargs,
-    ):
-        """
-        Args:
-            input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
-            images_seq_mask (torch.BoolTensor): [b, T]
-            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
-            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
-        Returns:
-            input_embeds (torch.Tensor): [b, T, D]
-        """
-        bs, n = pixel_values.shape[0:2]
-        pixel_values = pixel_values.to(
-            device=self.vision_model.device, dtype=self.vision_model.dtype
-        )
-        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-        # [b, n, T2] -> [b, n x T2]
-        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
-        # [b, T, D]
-        # ignore the image embeddings
-        input_ids[input_ids < 0] = 0
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-        # replace with the image embeddings
-        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
-        return inputs_embeds
-
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
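The reshapes in get_image_feature above lean on einops.rearrange; a quick standalone check of the two patterns it uses makes the shape bookkeeping explicit (tile and token counts here are made up):

import torch
from einops import rearrange

bs, n = 2, 3                                   # 2 requests, 3 image tiles each
pixel_values = torch.randn(bs, n, 3, 14, 14)

images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
print(images.shape)                            # torch.Size([6, 3, 14, 14]) -> one vision-tower call per tile

tokens_per_tile, dim = 5, 8
tile_embeds = torch.randn(bs * n, tokens_per_tile, dim)   # stand-in for the aligner output
merged = rearrange(tile_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
print(merged.shape)                            # torch.Size([2, 15, 8]) -> one token sequence per request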
python/sglang/srt/models/deepseek_vl2.py (view file @ 11577ced)

@@ (imports) @@
-import itertools
-import math
-import warnings
-from enum import Enum
-from functools import partial
-from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
+import collections
+from typing import Iterable, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from torch import nn

-from sglang.srt.configs import DeepseekVL2Config
 from sglang.srt.configs.deepseekvl2 import (
     DeepseekVL2Config,
     DeepseekVL2MlpProjectorConfig,
 )
-from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
-    LinearBase,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
         input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
-            None
-        ]:
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.contains_image_inputs()
+        ):
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
             extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
             for idx, image in enumerate(forward_batch.image_inputs):

@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
                     continue
                 start_idx = extend_start_loc_cpu[idx]
                 end_idx = start_idx + extend_seq_lens_cpu[idx]
-                pixel_values = image.pixel_values.to(
-                    device="cuda", dtype=torch.bfloat16
-                )
-                image_seq_mask = image.image_seq_mask.to(device="cuda")
-                image_spatial_crop = image.image_spatial_crop
-                input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
-                    pixel_values,
-                    image_seq_mask,
-                    image_spatial_crop,
-                    input_embeds[start_idx:end_idx],
-                )
+                images_emb_mask = image.images_emb_mask.to(device="cuda")
+                image_features = self.get_image_feature(image)
+                input_embeds[start_idx:end_idx] = input_embeds[
+                    start_idx:end_idx
+                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

         outputs = self.language_model.forward(
             input_ids=input_ids,

@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
         return input_ids

-    def prepare_inputs_embeds(
-        self,
-        pixel_values,
-        images_seq_mask,
-        images_spatial_crop,
-        input_embeds,
-    ):
+    def get_image_feature(self, image_input: ImageInputs):
+        pixel_values = image_input.pixel_values.type(
+            next(self.vision.parameters()).dtype
+        ).to(device=next(self.vision.parameters()).device)
         image_feature = self.vision.forward_features(pixel_values)
         images_embeds = self.projector(image_feature)
         _, hw, n_dim = images_embeds.shape
         h = w = int(hw**0.5)
         tile_index = 0
         images_in_this_batch = []
+        images_spatial_crop = image_input.image_spatial_crop
         for jdx in range(images_spatial_crop.shape[1]):
             num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
             if num_width_tiles == 0 or num_height_tiles == 0:

@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
                 images_in_this_batch.append(global_local_features)

-        if len(images_in_this_batch) > 0:
-            images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
-            input_embeds.masked_scatter_(
-                images_seq_mask.unsqueeze(-1), images_in_this_batch
-            )
-
-        return input_embeds
+        return torch.cat(images_in_this_batch, dim=0)

 EntryClass = DeepseekVL2ForCausalLM
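The rewritten forward() above splices image features into the text embeddings with Tensor.masked_scatter instead of the old prepare_inputs_embeds path; a small self-contained example of that pattern (shapes chosen arbitrarily):

import torch

seq_len, hidden = 6, 4
input_embeds = torch.zeros(seq_len, hidden)
images_emb_mask = torch.tensor([False, True, True, False, True, False])   # 3 image slots
image_features = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)

# unsqueeze(-1) broadcasts the per-token mask over the hidden dimension, so
# masked_scatter fills each masked row with the next row of image_features.
out = input_embeds.masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
print(out[1])   # first image feature row: tensor([0., 1., 2., 3.])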
python/sglang/srt/models/gemma3_causal.py (view file @ 11577ced)

@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
         else:
             hidden_states = input_embeds

-        if len(positions.shape) == 1:
+        if positions.dim() == 1:
             positions = einops.rearrange(positions, "s -> 1 s")

         position_embeddings_global = self.rotary_emb(hidden_states, positions)

@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
         )
         self.post_init()

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

     def dtype(self) -> torch.dtype:
-        return self.model.layers[0].mlp.gate_up_proj.weight.dtype
+        return next(self.parameters()).dtype

     @torch.no_grad()
     def forward(
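The dtype change above, from reading one specific layer's weight to next(self.parameters()).dtype, makes the method independent of how the submodules happen to be named or nested; a minimal illustration (not Gemma3 code):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def dtype(self) -> torch.dtype:
        # works no matter how the submodules are laid out
        return next(self.parameters()).dtype

m = Tiny().to(torch.bfloat16)
print(m.dtype)   # torch.bfloat16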
python/sglang/srt/models/gemma3_mm.py (view file @ 11577ced)

@@ -34,8 +34,9 @@ from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -264,10 +265,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             kwargs["local_attn_masks"] = local_attn_masks
         return kwargs

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()

-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_feature(self, image_input: ImageInputs):
         """
         Projects the last hidden state from the vision model into language model space.

@@ -277,6 +278,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
+        pixel_values = image_input.pixel_values
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

@@ -305,7 +307,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             return inputs_embeds
         else:
             # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_features(image_input.pixel_values)
+            image_features = self.get_image_feature(image_input.pixel_values)
             # print(f"image tokens from image embeddings: {image_features.numel()}")
             num_image_tokens_in_embedding = (

@@ -397,20 +399,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids

-        merged_image_input = forward_batch.get_merged_image_inputs()
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and merged_image_input is not None
-        ):
-            inputs_embeds = self.embed_image_inputs(
-                input_ids=llm_input_ids,
-                forward_batch=forward_batch,
-                image_input=merged_image_input,
-            )
-        else:
-            llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
-            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=llm_input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         outputs = self.language_model(
             input_ids=None,
python/sglang/srt/models/minicpmv.py (view file @ 11577ced)

@@ -50,8 +50,9 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    embed_image_inputs,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings

     def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:

@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
         valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
         return valid_pairs_tensor

-    def get_embedding(
-        self,
-        input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
-
-        if image_inputs is None:  # No image
-            vision_hidden_states = torch.tensor([], device=input_ids.device)
-        else:
-            if image_inputs["type"] == "image_embeds":
-                vision_hidden_states = (
-                    image_inputs["data"]
-                    .type(vlm_embedding.dtype)
-                    .to(vlm_embedding.device)
-                )
-            else:
-                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
-
-            # See NOTE in _parse_and_validate_inputs
-            image_bounds = image_inputs["image_bounds"]
-            if len(image_bounds) > 0:
-                image_indices = torch.stack(
-                    [
-                        torch.arange(start, end, dtype=torch.long)
-                        for start, end in image_bounds.tolist()
-                    ]
-                ).to(vlm_embedding.device)
-                vlm_embedding.scatter_(
-                    0,
-                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
-                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
-                )
-
-        return vlm_embedding, vision_hidden_states
-
     def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,

@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
                 type="image_embeds",
             )

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
-            )
-
-        if not isinstance(tgt_sizes, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
-            )
-
-        if len(pixel_values) != len(tgt_sizes):
-            raise ValueError(
-                "Inconsistent batch lengths, found: "
-                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
-            )
-
-        pixel_values_flat: List[torch.Tensor] = []
-        tgt_sizes_flat: List[torch.Tensor] = []
-        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
-            if len(pixel_b) != len(tgt_b):
-                raise ValueError(
-                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
-                )
-            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
-                pixel_values_flat += pixel_n
-                tgt_sizes_flat += tgt_n
-
-        # NOTE: Input IDs does not contain image tokens during memory profiling,
-        # so we allow it to be empty
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError(
-                "Inconsistent flattened lengths, found: "
-                f"{len(pixel_values_flat)} vs. "
-                f"{len(tgt_sizes_flat)}"
-            )
-
-        if len(pixel_values_flat) == 0:
-            return None
-
         image_bounds = self._get_image_bounds(
             input_ids=input_ids,
             pad_values=pad_values,

@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
         )

         return MiniCPMVImagePixelInputs(
             image_bounds=image_bounds.to(device=input_ids.device),
-            data=pixel_values_flat,
-            tgt_sizes=torch.stack(tgt_sizes_flat),
+            data=pixel_values,
+            tgt_sizes=tgt_sizes,
             type="pixel_values",
         )

+    def get_embedding(
+        self,
+        input_ids: torch.Tensor,
+        image_inputs: Optional[MiniCPMVImageInputs],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+
+        if image_inputs is None:  # No image
+            vision_hidden_states = torch.tensor([], device=input_ids.device)
+        else:
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (
+                    image_inputs["data"]
+                    .type(vlm_embedding.dtype)
+                    .to(vlm_embedding.device)
+                )
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+
+            # See NOTE in _parse_and_validate_inputs
+            image_bounds = image_inputs["image_bounds"]
+            if len(image_bounds) > 0:
+                image_indices = torch.stack(
+                    [
+                        torch.arange(start, end, dtype=torch.long)
+                        for start, end in image_bounds.tolist()
+                    ]
+                ).to(vlm_embedding.device)
+                vlm_embedding.scatter_(
+                    0,
+                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
+                )
+
+        return vlm_embedding, vision_hidden_states
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.llm.get_input_embedding()
+
     def forward(
         self,
         input_ids: torch.Tensor,

@@ -899,58 +863,29 @@ class MiniCPMVBaseModel(nn.Module):
         **kwargs: Any,
     ) -> torch.Tensor:
         if (
-            forward_batch.image_inputs is not None
-            and len(forward_batch.image_inputs) > 0
-            and forward_batch.image_inputs[0] is not None
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
         ):
-            kwargs.update(
-                {
-                    "pixel_values": (
-                        None
-                        if forward_batch.image_inputs is None
-                        else [
-                            i.pixel_values
-                            for i in forward_batch.image_inputs
-                            if i is not None
-                        ]
-                    ),
-                    "tgt_sizes": (
-                        None
-                        if forward_batch.image_inputs is None
-                        else [
-                            i.tgt_sizes
-                            for i in forward_batch.image_inputs
-                            if i is not None
-                        ]
-                    ),
-                    "im_start_id": forward_batch.image_inputs[0].im_start_id,
-                    "im_end_id": forward_batch.image_inputs[0].im_end_id,
-                    "slice_start_id": forward_batch.image_inputs[0].slice_start_id,
-                    "slice_end_id": forward_batch.image_inputs[0].slice_end_id,
-                    "pad_values": forward_batch.image_inputs[0].pad_values,
-                }
-            )
-
-        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
-
-        # Clamp input ids. This is because the input_ids for the image tokens are
-        # filled with the hash values of the image for the prefix matching in the radix attention.
-        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
-
-        # always pass the input via `inputs_embeds`
-        # to make sure the computation graph is consistent
-        # for `torch.compile` integration
-        input_ids = None
+            # TODO: bath
+            inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+        else:
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            image_inputs = forward_batch.merge_image_inputs()
+            inputs_embeds = embed_image_inputs(
+                image_input=image_inputs,
+                input_ids=input_ids,
+                input_embedding=self.get_input_embeddings(),
+                image_embedding_func=self.get_image_features,
+                placeholder_token_ids=[image_inputs.im_token_id]
+                + image_inputs.pad_values,
+            )

         hidden_states = self.llm.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=vlm_embeddings,
+            input_embeds=inputs_embeds,
         )

         return self.logits_processor(

@@ -990,7 +925,7 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
         raise NotImplementedError

@@ -1100,12 +1035,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding

-    def get_vision_hidden_states(
+    def get_image_features(
         self,
-        data: MiniCPMVImageInputs,
+        image_inputs: ImageInputs,
     ) -> torch.Tensor:
-        pixel_values = data["data"]
-        tgt_sizes = data["tgt_sizes"]
+        # list of tensors
+        pixel_values = image_inputs.pixel_values
+        tgt_sizes = image_inputs.tgt_sizes

         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
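get_embedding above writes the vision states into the token-embedding matrix with scatter_ along dim 0, building the row indices from image_bounds; a reduced standalone example of that indexing trick:

import torch

vlm_embedding = torch.zeros(8, 4)                    # [seq_len, hidden]
vision_hidden_states = torch.ones(3, 4)              # features for 3 image tokens
image_bounds = torch.tensor([[2, 5]])                # image tokens occupy rows 2..4

image_indices = torch.stack(
    [torch.arange(start, end, dtype=torch.long) for start, end in image_bounds.tolist()]
)
vlm_embedding.scatter_(
    0,
    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
print(vlm_embedding[2:5].sum())   # tensor(12.) -> rows 2..4 now hold the vision features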
python/sglang/srt/models/qwen2.py (view file @ 11577ced)

@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

+    def get_input_embedding(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
python/sglang/srt/models/qwen2_5_vl.py (view file @ 11577ced)

@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)

@@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )

     def get_window_index(self, grid_thw):
-        window_index: list = []
         cu_window_seqlens: list = [0]
         window_index_id = 0
         vit_merger_window_size = (
             self.window_size // self.spatial_merge_size // self.patch_size
         )
+        window_index: list = []

         for grid_t, grid_h, grid_w in grid_thw:
             llm_grid_h, llm_grid_w = (
                 grid_h // self.spatial_merge_size,

@@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
             window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()

         window_index = torch.cat(window_index, dim=0)
-
         return window_index, cu_window_seqlens

     @property

@@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for t, h, w in grid_thw:
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
                 self.spatial_merge_size,

@@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             )
             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
             wpos_ids = wpos_ids.flatten()
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()

@@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         position_embeddings = (emb.cos(), emb.sin())

         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
-        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = torch.cat(
+            [
+                torch.tensor([0], device=grid_thw.device),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+            ]
+        )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

         # transformers

@@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
         # Get all special token IDs
         im_start_id: int = image_inputs.im_start_id

@@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, image_inputs)

-    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+    def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:

@@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         )
         return video_embeds

+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,

@@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions

-        image_inputs = None
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
-            forward_batch.forward_mode.is_decode()
-            or image_inputs is None
-            or len(image_inputs) == 0
-        ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )

-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # [B, s, hidden_size]
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-                pixel_values = image.pixel_values.to(device="cuda")
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-                    left_idx = start_idx + (image_offset - prefix_len)
-                    right_idx = left_idx + num_image_tokens
-                    tp_size = get_tensor_model_parallel_world_size()
-                    hidden_size = image_embeds.shape[-1]
-                    if hidden_size % tp_size != 0:
-                        padding_size = tp_size - (hidden_size % tp_size)
-                        image_embeds = F.pad(image_embeds, (0, padding_size))
-                        inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
-                    hidden_chunk_size = image_embeds.shape[-1] // tp_size
-                    rank = get_tensor_model_parallel_rank()
-                    start_dim = rank * hidden_chunk_size
-                    end_dim = (rank + 1) * hidden_chunk_size
-                    inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
-                        image_embeds[
-                            image_embeds_offset : image_embeds_offset + num_image_tokens,
-                            ...,
-                            start_dim:end_dim,
-                        ]
-                    )
-                    image_embeds_offset += num_image_tokens
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         hidden_states = self.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
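The cu_seqlens change in the Qwen2.5 vision transformer is easiest to see with a concrete grid_thw. This is just the arithmetic of the two expressions shown in the hunk above, run on made-up grid sizes:

import torch

grid_thw = torch.tensor([[2, 4, 6],     # image 1: t=2, h=4, w=6 -> 48 patches
                         [1, 2, 2]])    # image 2: t=1, h=2, w=2 ->  4 patches

# old form: one boundary per temporal frame
old = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
print(old)   # tensor([24, 48, 52], dtype=torch.int32)

# new form: one boundary per image, with 0 prepended
new = torch.cat(
    [
        torch.tensor([0]),
        (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
    ]
)
print(new)   # tensor([ 0, 48, 52])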
python/sglang/srt/models/qwen2_vl.py (view file @ 11577ced)

@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type, TypedDict

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -42,8 +41,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.fc2.weight.dtype
+        return next(self.parameters()).dtype

     @property
     def device(self) -> torch.device:

@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for t, h, w in grid_thw:
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
             hpos_ids = (

@@ -480,9 +481,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
         return pattern.pad_input_tokens(input_ids, image_inputs)

-    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+    def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:

@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         return video_embeds

+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,

@@ -514,67 +518,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions

-        image_inputs = None
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
-            forward_batch.forward_mode.is_decode()
-            or image_inputs is None
-            or len(image_inputs) == 0
-        ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )

-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-                pixel_values = image.pixel_values.clone()
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-                    left_idx = start_idx + (image_offset - prefix_len + 1)
-                    right_idx = left_idx + num_image_tokens
-                    inputs_embeds[left_idx:right_idx] = image_embeds[
-                        image_embeds_offset : image_embeds_offset + num_image_tokens
-                    ]
-                    image_embeds_offset += num_image_tokens
-
-            input_ids = None
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         hidden_states = self.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
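Both Qwen2 vision transformers switch rot_pos_emb from unpacking tensor rows directly to indexing and calling .tolist(); the visible difference is that t, h, w become plain Python ints rather than 0-dim tensors (the commit does not state the motivation, so take this as an observation about the two styles rather than the authors' rationale):

import torch

grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])

# old style: unpacking rows yields 0-dim tensors
for t, h, w in grid_thw:
    print(type(h))                 # <class 'torch.Tensor'>

# new style: .tolist() yields plain Python ints
for i in range(grid_thw.size(0)):
    t, h, w = grid_thw[i].tolist()
    print(type(h))                 # <class 'int'>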
test/srt/test_vision_openai_server.py (view file @ 11577ced)

@@ -23,6 +23,17 @@ from sglang.test.test_utils import (
     popen_launch_server,
 )

+# image
+IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
+IMAGE_SGL_LOGO_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/sgl_logo.png"
+
+# video
+VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4"
+
+# audio
+AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/Trump_WEF_2018_10s.mp3"
+AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
+

 class TestOpenAIVisionServer(unittest.TestCase):
     @classmethod

@@ -58,9 +69,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                        },
+                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
                     },
                     {
                         "type": "text",

@@ -96,9 +105,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                        },
+                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
                     },
                     {
                         "type": "text",

@@ -153,9 +160,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                     },
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
-                        },
+                        "image_url": {"url": IMAGE_SGL_LOGO_URL},
                         "modalities": "multi-images",
                     },
                     {

@@ -242,10 +247,12 @@ class TestOpenAIVisionServer(unittest.TestCase):
         ]
         return messages

-    def test_video_chat_completion(self):
-        url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
+    def get_or_download_file(self, url: str) -> str:
         cache_dir = os.path.expanduser("~/.cache")
-        file_path = os.path.join(cache_dir, "jobs.mp4")
+        if url is None:
+            raise ValueError()
+        file_name = url.split("/")[-1]
+        file_path = os.path.join(cache_dir, file_name)
         os.makedirs(cache_dir, exist_ok=True)
         if not os.path.exists(file_path):

@@ -254,6 +261,11 @@ class TestOpenAIVisionServer(unittest.TestCase):
             with open(file_path, "wb") as f:
                 f.write(response.content)
+        return file_path
+
+    def test_video_chat_completion(self):
+        url = VIDEO_JOBS_URL
+        file_path = self.get_or_download_file(url)

         client = openai.Client(api_key=self.api_key, base_url=self.base_url)

@@ -289,6 +301,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
             "present" in video_response
             or "examine" in video_response
             or "display" in video_response
+            or "hold" in video_response
         )
         assert "black" in video_response or "dark" in video_response
         self.assertIsNotNone(video_response)

@@ -312,9 +325,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                        },
+                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
                     },
                     {
                         "type": "text",

@@ -344,18 +355,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
                 content.append(
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                        },
+                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
                     }
                 )
             elif image_id == 1:
                 content.append(
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
-                        },
+                        "image_url": {"url": IMAGE_SGL_LOGO_URL},
                     }
                 )
             else:

@@ -465,9 +472,7 @@ class TestVLMContextLengthIssue(unittest.TestCase):
                 "content": [
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                        },
+                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
                     },
                     {
                         "type": "text",
test/srt/test_vlm_accuracy.py (view file @ 11577ced)

@@ -13,6 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.managers.mm_utils import embed_image_inputs
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 from sglang.srt.server_args import ServerArgs

@@ -168,10 +170,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         ).eval()
         cls.model.to(cls.device)

-    async def test_encode_output(self):
+    async def test_vlm_embedding_output(self):
+        """
+        Compares the embedding output of vlm
+        """
         inputs = self.get_processor_output()

         with torch.no_grad():
+            # hf
             model_inputs = {
                 "input_ids": inputs.input_ids,
                 "image_bound": inputs.image_bound,

@@ -183,22 +189,20 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
             )
             hf_output = hf_output.squeeze(0)

+        with torch.no_grad():
+            # sglang
             model = self.get_sglang_model()
             input_ids = inputs["input_ids"].to(self.device).flatten()
-            image_inputs = model._parse_and_validate_inputs(
+            sglang_output = embed_image_inputs(
+                image_input=ImageInputs(
+                    pixel_values=inputs["pixel_values"][0],
+                    tgt_sizes=inputs["tgt_sizes"][0],
+                ),
                 input_ids=input_ids,
-                **{
-                    "pixel_values": [inputs["pixel_values"]],
-                    "tgt_sizes": [inputs["tgt_sizes"]],
-                    "im_start_id": self.tokenizer.im_start_id,
-                    "im_end_id": self.tokenizer.im_end_id,
-                    "slice_start_id": self.tokenizer.slice_start_id,
-                    "slice_end_id": self.tokenizer.slice_end_id,
-                },
-            )
-            (sglang_output, _) = model.get_embedding(
-                input_ids=input_ids, image_inputs=image_inputs
+                input_embedding=model.get_input_embeddings(),
+                image_embedding_func=model.get_image_features,
+                placeholder_token_ids=[
+                    self.processor.tokenizer.unk_token_id,
+                ],
             )

         self.compare_outputs(sglang_output, hf_output)