sglang · Commits · 681fdc26

Unverified commit 681fdc26, authored May 24, 2025 by Xinyuan Tong; committed by GitHub, May 24, 2025.

Refactor vlm embedding routine to use precomputed feature (#6543)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>

parent 0d477880

Showing 8 changed files with 285 additions and 203 deletions (+285, -203):
python/sglang/srt/managers/mm_utils.py                       +88   -38
python/sglang/srt/managers/multimodal_processors/qwen_vl.py  +8    -6
python/sglang/srt/models/gemma3_mm.py                        +0    -7
python/sglang/srt/models/qwen2_5_vl.py                       +0    -6
python/sglang/srt/models/qwen2_vl.py                         +0    -6
test/srt/run_suite.py                                        +1    -1
test/srt/test_vlm_accuracy.py                                +1    -139
test/srt/test_vlm_input_format.py (new)                      +187  -0
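In short, the refactor lets a caller run the vision encoder outside the engine and pass the resulting features in through image_data, instead of always shipping raw pixel values. A minimal sketch of the two call patterns, distilled from the new test file at the end of this diff (engine, processor_output, and visual are assumed to be set up as in those tests):

    import torch

    # Sketch only: `engine` is an sglang.Engine, `processor_output` is the HF
    # processor output for the prompt+image, and `visual` is the model's vision
    # encoder wrapped as a callable, all set up as in
    # test/srt/test_vlm_input_format.py below.
    async def generate_both_ways(engine, processor_output, visual):
        input_ids = processor_output["input_ids"][0].detach().cpu().tolist()

        # (a) Hand over raw pixel values; the engine runs its own vision encoder.
        out_pixels = await engine.async_generate(
            input_ids=input_ids,
            image_data=[
                dict(
                    modality="IMAGE",
                    pixel_values=processor_output["pixel_values"],
                    image_grid_thws=processor_output["image_grid_thw"],
                )
            ],
            sampling_params=dict(temperature=0.0),
        )

        # (b) Run the vision encoder yourself and pass the features in; the
        # engine then skips its encoder via _get_precomputed_embedding (below).
        with torch.inference_mode():
            features = visual(processor_output)
        out_precomputed = await engine.async_generate(
            input_ids=input_ids,
            image_data=[dict(modality="IMAGE", precomputed_features=features)],
            sampling_params=dict(temperature=0.0),
        )
        return out_pixels, out_precomputed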
python/sglang/srt/managers/mm_utils.py
@@ -252,40 +252,36 @@ def get_embedding_chunk(
     return embedding_chunk, start_index, end_index


-def get_embedding_and_mask(
+def _get_precomputed_embedding(
+    items: List[MultimodalDataItem],
+) -> Optional[torch.Tensor]:
+    """
+    If all items have precomputed_features, return their concatenation.
+    If some but not all have precomputed_features, raise NotImplementedError.
+    If none have precomputed_features, return None.
+    """
+    precomputed_features = [item.precomputed_features for item in items]
+    if any(feature is not None for feature in precomputed_features):
+        if not all(feature is not None for feature in precomputed_features):
+            raise NotImplementedError(
+                "MM inputs where only some items are precomputed."
+            )
+        result = torch.concat(precomputed_features)
+        # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
+        result = result.reshape(-1, result.shape[-1])
+        return result
+    return None
+
+
+def _get_chunked_prefill_embedding(
     data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
     embedding_items: List[MultimodalDataItem],
-    placeholder_tensor: torch.Tensor,
-    input_ids: torch.Tensor,
     items_size: List[int],
     prefix_length: List[int],
     extend_length: List[int],
     items_offset_list: List[List[Tuple[int, int]]],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
-    Args:
-        data_embedding_func: Function that generates embeddings for multimodal items
-        embedding_items: List of multimodal items to embed
-        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
-        input_ids: The input token IDs tensor
-        items_size: Cumulative sizes of multimodal items per request
-        prefix_length: Prefix lengths for each request
-        extend_length: Sequence lengths for each request
-        items_offset_list: List of offset ranges for multimodal items in each request
-    Returns:
-        A tuple containing:
-        - The generated embeddings tensor
-        - A boolean mask tensor indicating where these embeddings should be placed
-    Raises:
-        AssertionError: If the number of multimodal tokens in input_ids doesn't match
-            the number of tokens in the generated embeddings
-    """
-    # 1. Get the embedding
-    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
+) -> Optional[torch.Tensor]:
+    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
     for i in range(len(items_size) - 1):
         if items_size[i] == items_size[i + 1]:
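For illustration, a toy exercise of the three-way contract above; FakeItem is a hypothetical stand-in for MultimodalDataItem (the helper only reads its precomputed_features attribute), and the import path is the module shown in this diff:

    import torch

    from sglang.srt.managers.mm_utils import _get_precomputed_embedding

    class FakeItem:  # hypothetical stand-in for MultimodalDataItem
        def __init__(self, feats):
            self.precomputed_features = feats

    a = FakeItem(torch.randn(1, 4, 8))  # 3-dim features: (batch, tokens, hidden)
    b = FakeItem(torch.randn(1, 4, 8))

    # All items precomputed: concatenated, then reshaped to 2-dim -> (8, 8)
    print(_get_precomputed_embedding([a, b]).shape)

    # No items precomputed: returns None, so get_embedding_and_mask falls back
    # to _get_chunked_prefill_embedding.
    print(_get_precomputed_embedding([FakeItem(None), FakeItem(None)]))

    # Mixed: raises NotImplementedError("MM inputs where only some items are precomputed.")
    # _get_precomputed_embedding([a, FakeItem(None)])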
@@ -321,21 +317,28 @@ def get_embedding_and_mask(
             embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
-        return None, None
-    embedding = torch.concat(embedding_list, dim=0)
-    # 2. Check the embedding
-    num_mm_tokens_in_embedding = embedding.shape[0]
-    special_multimodal_mask = torch.isin(
-        input_ids,
-        placeholder_tensor,
-    ).unsqueeze(-1)
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
+        return None
+    return torch.concat(embedding_list, dim=0)
+
+
+def _get_multimodal_mask(
+    input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
+) -> torch.Tensor:
+    return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
+
+
+def _adjust_embedding_length(
+    embedding: torch.Tensor,
+    mask: torch.Tensor,
+    logger,
+) -> torch.Tensor:
+    num_mm_tokens_in_embedding = embedding.shape[0]
+    num_mm_tokens_in_input_ids = mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text. "
             f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
-            "tokens from multimodal embeddings."
+            f"tokens from multimodal embeddings."
         )
         if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
             chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
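As a point of reference for _get_multimodal_mask: torch.isin flags every placeholder position in input_ids, and unsqueeze(-1) adds a trailing length-1 axis so the mask can broadcast against the hidden dimension. A tiny self-contained example with made-up token IDs:

    import torch

    input_ids = torch.tensor([101, 9, 9, 9, 2023, 102])  # 9 = imaginary image placeholder id
    placeholder_tensor = torch.tensor([9])

    mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
    print(mask.squeeze(-1))  # tensor([False,  True,  True,  True, False, False])
    print(mask.shape)        # torch.Size([6, 1])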
@@ -353,7 +356,54 @@ def get_embedding_and_mask(
             raise RuntimeError(
                 f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )
     return embedding
+
+
+def get_embedding_and_mask(
+    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
+    embedding_items: List[MultimodalDataItem],
+    placeholder_tensor: torch.Tensor,
+    input_ids: torch.Tensor,
+    items_size: List[int],
+    prefix_length: List[int],
+    extend_length: List[int],
+    items_offset_list: List[List[Tuple[int, int]]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
+    Args:
+        data_embedding_func: Function that generates embeddings for multimodal items
+        embedding_items: List of multimodal items to embed
+        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
+        input_ids: The input token IDs tensor
+        items_size: Cumulative sizes of multimodal items per request
+        prefix_length: Prefix lengths for each request
+        extend_length: Sequence lengths for each request
+        items_offset_list: List of offset ranges for multimodal items in each request
+    Returns:
+        A tuple containing:
+        - The generated embeddings tensor
+        - A boolean mask tensor indicating where these embeddings should be placed
+    """
+    # 1. Get embedding
+    embedding = _get_precomputed_embedding(embedding_items)
+    if embedding is None:
+        embedding = _get_chunked_prefill_embedding(
+            data_embedding_func,
+            embedding_items,
+            items_size,
+            prefix_length,
+            extend_length,
+            items_offset_list,
+        )
+        if embedding is None:
+            return None, None
+    # 2. Get mask
+    special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
+    # 3. Adjust embedding length if needed
+    embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
+    return embedding, special_multimodal_mask
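Downstream consumers (outside this diff) typically use the returned pair by scattering the multimodal rows over the placeholder positions of the text embeddings. A self-contained sketch of that general pattern with toy shapes; none of these names come from sglang:

    import torch

    hidden = 8
    inputs_embeds = torch.zeros(6, hidden)  # embeddings of 6 input tokens
    mm_embedding = torch.randn(3, hidden)   # one row per placeholder token
    mask = torch.tensor([0, 1, 1, 1, 0, 0], dtype=torch.bool).unsqueeze(-1)

    # The (6, 1) mask broadcasts over the hidden dim; masked_scatter fills the
    # True positions row by row with the multimodal rows, leaving text rows intact.
    merged = inputs_embeds.masked_scatter(mask, mm_embedding)
    assert merged[1:4].allclose(mm_embedding)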
python/sglang/srt/managers/multimodal_processors/qwen_vl.py
@@ -144,12 +144,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         if base_output.images:
             if images_are_preprocessed:
-                image_grid_thw = torch.concat(
-                    [torch.as_tensor(item.image_grid_thws) for item in base_output.images]
-                )
+                all_image_grid_thws = [
+                    item.image_grid_thws
+                    for item in base_output.images
+                    if item.image_grid_thws is not None
+                ]
                 all_pixel_values = [
                     item.pixel_values
                     for item in base_output.images
@@ -160,6 +159,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                     for item in base_output.images
                     if item.precomputed_features is not None
                 ]
+                image_grid_thw = (
+                    torch.concat(all_image_grid_thws) if all_image_grid_thws else None
+                )
+                pixel_values = (
+                    torch.concat(all_pixel_values) if all_pixel_values else None
+                )
python/sglang/srt/models/gemma3_mm.py
@@ -282,13 +282,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
-
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
         all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
         vision_outputs_list = []
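The precomputed-features short-circuit deleted here (and, identically, from the two Qwen models below) is not lost: get_embedding_and_mask in mm_utils.py now checks _get_precomputed_embedding first and only reaches a model's get_image_feature through data_embedding_func when no features were precomputed.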
python/sglang/srt/models/qwen2_5_vl.py
@@ -499,12 +499,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
python/sglang/srt/models/qwen2_vl.py
@@ -486,12 +486,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
test/srt/run_suite.py
@@ -81,7 +81,7 @@ suites = {
         TestFile("test_update_weights_from_tensor.py", 48),
         TestFile("test_vertex_endpoint.py", 31),
         TestFile("test_vision_chunked_prefill.py", 175),
-        TestFile("test_vlm_accuracy.py", 60),
+        TestFile("test_vlm_input_format.py", 300),
         TestFile("test_vision_openai_server_a.py", 700),
         TestFile("test_vision_openai_server_b.py", 700),
         TestFile("test_w8a8_quantization.py", 46),
test/srt/test_vlm_accuracy.py
@@ -10,15 +10,8 @@ import requests
 import torch
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (
-    AutoModel,
-    AutoProcessor,
-    AutoTokenizer,
-    Gemma3ForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-)
+from transformers import AutoModel, AutoProcessor, AutoTokenizer
 from sglang import Engine
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
@@ -41,9 +34,6 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
     def setUpClass(cls):
         cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
         cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         cls.model_path = ""
         cls.chat_template = ""
         cls.processor = ""
         response = requests.get(cls.image_url)
         cls.main_image = Image.open(BytesIO(response.content))
@@ -274,131 +264,3 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         )
         self.compare_outputs(sglang_output, hf_output)


-class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
-        cls.chat_template = "qwen2-vl"
-        cls.processor = AutoProcessor.from_pretrained(
-            cls.model_path, trust_remote_code=True, use_fast=True
-        )
-        cls.visual = (
-            Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                cls.model_path, torch_dtype=torch.bfloat16
-            )
-            .eval()
-            .visual.to(cls.device)
-        )
-
-    def setUp(self):
-        self.engine = Engine(
-            model_path=self.model_path,
-            chat_template=self.chat_template,
-            device=self.device.type,
-            mem_fraction_static=0.8,
-        )
-
-    def tearDown(self):
-        self.engine.shutdown()
-
-    async def test_qwen_vl_understands_image(self):
-        req = self.get_completion_request()
-        conv = generate_chat_conv(req, template_name=self.chat_template)
-        text = conv.get_prompt()
-        output = await self.engine.async_generate(
-            prompt=text,
-            image_data=[self.main_image],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-    async def test_qwen_vl_understands_precomputed_features(self):
-        req = self.get_completion_request()
-        processor_output = self.get_processor_output(req=req)
-        with torch.inference_mode():
-            precomputed_features = self.visual(
-                processor_output["pixel_values"], processor_output["image_grid_thw"]
-            )
-        output = await self.engine.async_generate(
-            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
-            image_data=[
-                dict(
-                    modality="IMAGE",
-                    image_grid_thws=processor_output["image_grid_thw"],
-                    precomputed_features=precomputed_features,
-                )
-            ],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-
-class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.model_path = "google/gemma-3-4b-it"
-        cls.chat_template = "gemma-it"
-        cls.processor = AutoProcessor.from_pretrained(
-            cls.model_path, trust_remote_code=True, use_fast=True
-        )
-        model = Gemma3ForConditionalGeneration.from_pretrained(
-            cls.model_path, torch_dtype=torch.bfloat16
-        )
-        cls.vision_tower = model.vision_tower.eval().to(cls.device)
-        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
-
-    @classmethod
-    def visual(cls, pixel_values):
-        vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
-        image_features = cls.mm_projector(vision_outputs)
-        return image_features
-
-    def setUp(self):
-        self.engine = Engine(
-            model_path=self.model_path,
-            chat_template=self.chat_template,
-            device=self.device.type,
-            mem_fraction_static=0.5,
-            enable_multimodal=True,
-        )
-
-    def tearDown(self):
-        self.engine.shutdown()
-
-    async def test_gemma_understands_image(self):
-        req = self.get_completion_request()
-        conv = generate_chat_conv(req, template_name=self.chat_template)
-        text = conv.get_prompt()
-        output = await self.engine.async_generate(
-            prompt=text,
-            image_data=[self.main_image],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-    async def test_gemma_understands_precomputed_features(self):
-        req = self.get_completion_request()
-        processor_output = self.get_processor_output(req=req)
-        with torch.inference_mode():
-            precomputed_features = self.visual(processor_output["pixel_values"])
-        output = await self.engine.async_generate(
-            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
-            image_data=[
-                dict(
-                    modality="IMAGE",
-                    precomputed_features=precomputed_features,
-                )
-            ],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-
 if __name__ == "__main__":
     unittest.main()
test/srt/test_vlm_input_format.py (new file, 0 → 100644)
import json
import unittest
from io import BytesIO
from typing import Optional

import requests
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

from sglang import Engine
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.openai_api.protocol import ChatCompletionRequest

TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"


class VLMInputTestBase:
    model_path = None
    chat_template = None
    processor = None
    visual = None  # Should be a callable for precomputed features

    @classmethod
    def setUpClass(cls):
        assert cls.model_path is not None, "Set model_path in subclass"
        assert cls.chat_template is not None, "Set chat_template in subclass"
        cls.image_url = TEST_IMAGE_URL
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        response = requests.get(cls.image_url)
        cls.main_image = Image.open(BytesIO(response.content))
        cls.processor = AutoProcessor.from_pretrained(
            cls.model_path, trust_remote_code=True, use_fast=True
        )
        cls._init_visual()

    @classmethod
    def _init_visual(cls):
        """Override in subclass to set up cls.visual as a callable for precomputed features."""
        raise NotImplementedError

    def setUp(self):
        self.engine = Engine(
            model_path=self.model_path,
            chat_template=self.chat_template,
            device=self.device.type,
            mem_fraction_static=0.8,
            enable_multimodal=True,
            disable_cuda_graph=True,
        )

    def tearDown(self):
        self.engine.shutdown()

    def get_completion_request(self) -> ChatCompletionRequest:
        json_structure = {
            "model": self.model_path,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": self.image_url}},
                        {"type": "text", "text": "What's in this picture?"},
                    ],
                }
            ],
        }
        json_str = json.dumps(json_structure)
        return ChatCompletionRequest.model_validate_json(json_str)

    def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
        if req is None:
            req = self.get_completion_request()
        conv = generate_chat_conv(req, template_name=self.chat_template)
        text = conv.get_prompt()
        # Process inputs using processor
        inputs = self.processor(
            text=[text],
            images=[self.main_image],
            return_tensors="pt",
        ).to(self.device)
        return inputs

    async def test_understands_image(self):
        req = self.get_completion_request()
        conv = generate_chat_conv(req, template_name=self.chat_template)
        text = conv.get_prompt()
        output = await self.engine.async_generate(
            prompt=text,
            image_data=[self.main_image],
            sampling_params=dict(temperature=0.0),
        )
        self.assertIn("taxi", output["text"].lower())

    async def test_understands_precomputed_features(self):
        req = self.get_completion_request()
        processor_output = self.get_processor_output(req=req)
        with torch.inference_mode():
            precomputed_features = self.__class__.visual(processor_output)
        output = await self.engine.async_generate(
            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
            image_data=[
                self._precomputed_image_data(processor_output, precomputed_features)
            ],
            sampling_params=dict(temperature=0.0),
        )
        self.assertIn("taxi", output["text"].lower())

    async def test_understands_pixel_values(self):
        req = self.get_completion_request()
        processor_output = self.get_processor_output(req=req)
        output = await self.engine.async_generate(
            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
            image_data=[self._pixel_values_image_data(processor_output)],
            sampling_params=dict(temperature=0.0),
        )
        self.assertIn("taxi", output["text"].lower())

    def _precomputed_image_data(self, processor_output, precomputed_features):
        """This should not be overridden."""
        return dict(
            modality="IMAGE",
            precomputed_features=precomputed_features,
        )

    def _pixel_values_image_data(self, processor_output):
        """Override in subclass to pass the correct set of arguments."""
        raise NotImplementedError


class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
    model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
    chat_template = "qwen2-vl"

    @classmethod
    def _init_visual(cls):
        cls.visual_model = (
            Qwen2_5_VLForConditionalGeneration.from_pretrained(
                cls.model_path, torch_dtype=torch.bfloat16
            )
            .eval()
            .visual.to(cls.device)
        )
        cls.visual = lambda processor_output: cls.visual_model(
            processor_output["pixel_values"], processor_output["image_grid_thw"]
        )

    def _pixel_values_image_data(self, processor_output):
        return dict(
            modality="IMAGE",
            image_grid_thws=processor_output["image_grid_thw"],
            pixel_values=processor_output["pixel_values"],
        )


class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
    model_path = "google/gemma-3-4b-it"
    chat_template = "gemma-it"

    @classmethod
    def _init_visual(cls):
        model = Gemma3ForConditionalGeneration.from_pretrained(
            cls.model_path, torch_dtype=torch.bfloat16
        )
        cls.vision_tower = model.vision_tower.eval().to(cls.device)
        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
        cls.visual = lambda processor_output: cls.mm_projector(
            cls.vision_tower(
                pixel_values=processor_output["pixel_values"]
            ).last_hidden_state
        )

    def _pixel_values_image_data(self, processor_output):
        return dict(
            modality="IMAGE",
            pixel_values=processor_output["pixel_values"][0],
        )


if __name__ == "__main__":
    unittest.main()
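A note on the structure of this new file: VLMInputTestBase does not inherit from unittest.TestCase, so its shared async tests are only collected and run in the subclasses that mix in unittest.IsolatedAsyncioTestCase together with a concrete model_path and chat_template; the base class itself is never executed against its None placeholders.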