Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eed2f463
Unverified
Commit
eed2f463
authored
Jul 27, 2025
by
Isotr0py
Committed by
GitHub
Jul 26, 2025
Browse files
[VLM] Support HF format Phi-4-MM model (#17121)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
20950b29
Changes
10
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1847 additions
and
5 deletions
+1847
-5
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
examples/offline_inference/audio_language.py
examples/offline_inference/audio_language.py
+32
-0
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+36
-0
examples/offline_inference/vision_language_multi_image.py
examples/offline_inference/vision_language_multi_image.py
+35
-0
tests/models/multimodal/generation/test_phi4_multimodal.py
tests/models/multimodal/generation/test_phi4_multimodal.py
+252
-0
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+31
-3
tests/models/registry.py
tests/models/registry.py
+2
-0
vllm/model_executor/models/phi4_multimodal.py
vllm/model_executor/models/phi4_multimodal.py
+1455
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-1
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+1
-1
No files found.
docs/models/supported_models.md
View file @
eed2f463
...
...
@@ -614,6 +614,7 @@ Specified using `--task generate`.
|
`PaliGemmaForConditionalGeneration`
| PaliGemma, PaliGemma 2 | T + I
<sup>
E
</sup>
|
`google/paligemma-3b-pt-224`
,
`google/paligemma-3b-mix-224`
,
`google/paligemma2-3b-ft-docci-448`
, etc. | | ✅︎ | ⚠️ |
|
`Phi3VForCausalLM`
| Phi-3-Vision, Phi-3.5-Vision | T + I
<sup>
E+
</sup>
|
`microsoft/Phi-3-vision-128k-instruct`
,
`microsoft/Phi-3.5-vision-instruct`
, etc. | | ✅︎ | ✅︎ |
|
`Phi4MMForCausalLM`
| Phi-4-multimodal | T + I
<sup>
+
</sup>
/ T + A
<sup>
+
</sup>
/ I
<sup>
+
</sup>
+ A
<sup>
+
</sup>
|
`microsoft/Phi-4-multimodal-instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Phi4MultimodalForCausalLM`
| Phi-4-multimodal (HF Transformers) | T + I
<sup>
+
</sup>
/ T + A
<sup>
+
</sup>
/ I
<sup>
+
</sup>
+ A
<sup>
+
</sup>
|
`microsoft/Phi-4-multimodal-instruct`
(with revision
`refs/pr/70`
), etc. | ✅︎ | ✅︎ | ✅︎ |
|
`PixtralForConditionalGeneration`
| Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I
<sup>
+
</sup>
|
`mistralai/Mistral-Small-3.1-24B-Instruct-2503`
,
`mistralai/Pixtral-12B-2409`
, etc. | | ✅︎ | ✅︎ |
|
`QwenVLForConditionalGeneration`
<sup>
^
</sup>
| Qwen-VL | T + I
<sup>
E+
</sup>
|
`Qwen/Qwen-VL`
,
`Qwen/Qwen-VL-Chat`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen2AudioForConditionalGeneration`
| Qwen2-Audio | T + A
<sup>
+
</sup>
|
`Qwen/Qwen2-Audio-7B-Instruct`
| | ✅︎ | ✅︎ |
...
...
examples/offline_inference/audio_language.py
View file @
eed2f463
...
...
@@ -190,6 +190,37 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
)
def
run_phi4_multimodal
(
question
:
str
,
audio_count
:
int
)
->
ModelRequestData
:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process audio inputs.
"""
model_path
=
snapshot_download
(
"microsoft/Phi-4-multimodal-instruct"
,
revision
=
"refs/pr/70"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
speech_lora_path
=
os
.
path
.
join
(
model_path
,
"speech-lora"
)
placeholders
=
"<|audio|>"
*
audio_count
prompts
=
f
"<|user|>
{
placeholders
}{
question
}
<|end|><|assistant|>"
engine_args
=
EngineArgs
(
model
=
model_path
,
max_model_len
=
12800
,
max_num_seqs
=
2
,
enable_lora
=
True
,
max_lora_rank
=
320
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
},
)
return
ModelRequestData
(
engine_args
=
engine_args
,
prompt
=
prompts
,
lora_requests
=
[
LoRARequest
(
"speech"
,
1
,
speech_lora_path
)],
)
# Qwen2-Audio
def
run_qwen2_audio
(
question
:
str
,
audio_count
:
int
)
->
ModelRequestData
:
model_name
=
"Qwen/Qwen2-Audio-7B-Instruct"
...
...
@@ -303,6 +334,7 @@ model_example_map = {
"granite_speech"
:
run_granite_speech
,
"minicpmo"
:
run_minicpmo
,
"phi4_mm"
:
run_phi4mm
,
"phi4_multimodal"
:
run_phi4_multimodal
,
"qwen2_audio"
:
run_qwen2_audio
,
"qwen2_5_omni"
:
run_qwen2_5_omni
,
"ultravox"
:
run_ultravox
,
...
...
examples/offline_inference/vision_language.py
View file @
eed2f463
...
...
@@ -1097,6 +1097,41 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
)
# HF format Phi-4-multimodal-instruct
def
run_phi4_multimodal
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process image inputs.
"""
assert
modality
==
"image"
model_path
=
snapshot_download
(
"microsoft/Phi-4-multimodal-instruct"
,
revision
=
"refs/pr/70"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path
=
os
.
path
.
join
(
model_path
,
"vision-lora"
)
prompts
=
[
f
"<|user|><|image|>
{
question
}
<|end|><|assistant|>"
for
question
in
questions
]
engine_args
=
EngineArgs
(
model
=
model_path
,
max_model_len
=
5120
,
max_num_seqs
=
2
,
max_num_batched_tokens
=
12800
,
enable_lora
=
True
,
max_lora_rank
=
320
,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs
=
{
"dynamic_hd"
:
16
},
limit_mm_per_prompt
=
{
"image"
:
1
},
)
return
ModelRequestData
(
engine_args
=
engine_args
,
prompts
=
prompts
,
lora_requests
=
[
LoRARequest
(
"vision"
,
1
,
vision_lora_path
)],
)
# Pixtral HF-format
def
run_pixtral_hf
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
assert
modality
==
"image"
...
...
@@ -1356,6 +1391,7 @@ model_example_map = {
"paligemma2"
:
run_paligemma2
,
"phi3_v"
:
run_phi3v
,
"phi4_mm"
:
run_phi4mm
,
"phi4_multimodal"
:
run_phi4_multimodal
,
"pixtral_hf"
:
run_pixtral_hf
,
"qwen_vl"
:
run_qwen_vl
,
"qwen2_vl"
:
run_qwen2_vl
,
...
...
examples/offline_inference/vision_language_multi_image.py
View file @
eed2f463
...
...
@@ -760,6 +760,40 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
)
def
load_phi4_multimodal
(
question
:
str
,
image_urls
:
list
[
str
])
->
ModelRequestData
:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process multi images inputs.
"""
model_path
=
snapshot_download
(
"microsoft/Phi-4-multimodal-instruct"
,
revision
=
"refs/pr/70"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path
=
os
.
path
.
join
(
model_path
,
"vision-lora"
)
engine_args
=
EngineArgs
(
model
=
model_path
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
enable_lora
=
True
,
max_lora_rank
=
320
,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs
=
{
"dynamic_hd"
:
4
},
)
placeholders
=
"<|image|>"
*
len
(
image_urls
)
prompt
=
f
"<|user|>
{
placeholders
}{
question
}
<|end|><|assistant|>"
return
ModelRequestData
(
engine_args
=
engine_args
,
prompt
=
prompt
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
lora_requests
=
[
LoRARequest
(
"vision"
,
1
,
vision_lora_path
)],
)
def
load_qwen_vl_chat
(
question
:
str
,
image_urls
:
list
[
str
])
->
ModelRequestData
:
model_name
=
"Qwen/Qwen-VL-Chat"
engine_args
=
EngineArgs
(
...
...
@@ -988,6 +1022,7 @@ model_example_map = {
"ovis"
:
load_ovis
,
"phi3_v"
:
load_phi3v
,
"phi4_mm"
:
load_phi4mm
,
"phi4_multimodal"
:
load_phi4_multimodal
,
"pixtral_hf"
:
load_pixtral_hf
,
"qwen_vl_chat"
:
load_qwen_vl_chat
,
"qwen2_vl"
:
load_qwen2_vl
,
...
...
tests/models/multimodal/generation/test_phi4_multimodal.py
0 → 100644
View file @
eed2f463
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
collections.abc
import
Sequence
from
typing
import
Optional
import
librosa
import
pytest
from
huggingface_hub
import
snapshot_download
from
vllm.assets.image
import
ImageAsset
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.platforms
import
current_platform
from
....conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
PromptAudioInput
,
PromptImageInput
,
VllmRunner
)
from
....utils
import
large_gpu_test
from
...utils
import
check_logprobs_close
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"<|user|>
\n
<|image|>
\n
What's the content of the image?<|end|>
\n
<|assistant|>
\n
"
,
# noqa: E501
"cherry_blossom"
:
"<|user|>
\n
<|image|>
\n
Please infer the season with reason in details.<|end|>
\n
<|assistant|>
\n
"
,
# noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT
=
"<|user|>
\n
<|image|>
\n
<|image|>
\n
Describe these images.<|end|>
\n
<|assistant|>
\n
"
# noqa: E501
model_path
=
snapshot_download
(
"microsoft/Phi-4-multimodal-instruct"
,
revision
=
"refs/pr/70"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path
=
os
.
path
.
join
(
model_path
,
"vision-lora"
)
speech_question
=
os
.
path
.
join
(
model_path
,
"examples"
,
"what_is_shown_in_this_image.wav"
)
models
=
[
model_path
]
target_dtype
=
"half"
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if
current_platform
.
is_rocm
():
os
.
environ
[
"VLLM_USE_TRITON_FLASH_ATTN"
]
=
"0"
def
run_test
(
hf_runner
:
type
[
HfRunner
],
vllm_runner
:
type
[
VllmRunner
],
inputs
:
Sequence
[
tuple
[
list
[
str
],
PromptImageInput
,
Optional
[
PromptAudioInput
]]],
model
:
str
,
*
,
max_model_len
:
int
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
mm_limit
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are from IMAGE_ASSETS.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
task
=
"generate"
,
max_model_len
=
max_model_len
,
max_num_seqs
=
2
,
dtype
=
dtype
,
limit_mm_per_prompt
=
{
"image"
:
mm_limit
},
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enable_lora
=
True
,
max_lora_rank
=
320
,
gpu_memory_utilization
=
0.8
,
# set to 0.8 to avoid OOM in CI
enforce_eager
=
True
,
trust_remote_code
=
False
,
)
as
vllm_model
:
lora_request
=
LoRARequest
(
"vision"
,
1
,
vision_lora_path
)
vllm_outputs_per_case
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
,
audios
=
audios
,
lora_request
=
lora_request
)
for
prompts
,
images
,
audios
in
inputs
]
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_model
.
model
.
load_adapter
(
vision_lora_path
,
adapter_name
=
"vision"
,
)
hf_processor
=
hf_model
.
processor
eos_token_id
=
hf_processor
.
tokenizer
.
eos_token_id
hf_outputs_per_case
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
,
audios
=
audios
,
eos_token_id
=
eos_token_id
)
for
prompts
,
images
,
audios
in
inputs
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_case
,
vllm_outputs_per_case
):
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# No image
[],
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"max_model_len"
,
[
12800
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_model_len
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
None
,
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
run_test
(
hf_runner
,
vllm_runner
,
inputs_per_image
,
model
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
mm_limit
=
1
,
tensor_parallel_size
=
1
,
)
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# No image
# [],
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"max_model_len"
,
[
25600
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_multi_images_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_model_len
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_case
=
[
(
[
HF_MULTIIMAGE_IMAGE_PROMPT
for
_
in
size_factors
],
[[
rescale_image_size
(
image
,
factor
)
for
image
in
images
]
for
factor
in
size_factors
],
None
,
),
]
run_test
(
hf_runner
,
vllm_runner
,
inputs_per_case
,
model
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
mm_limit
=
2
,
tensor_parallel_size
=
1
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"max_model_len"
,
[
12800
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_vision_speech_models
(
hf_runner
,
vllm_runner
,
model
,
dtype
:
str
,
max_model_len
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
# use the example speech question so that the model outputs are reasonable
audio
=
librosa
.
load
(
speech_question
,
sr
=
16000
)
image
=
ImageAsset
(
"cherry_blossom"
).
pil_image
.
convert
(
"RGB"
)
inputs_vision_speech
=
[
(
[
"<|user|><|image|><|audio|><|end|><|assistant|>"
],
[
image
],
[
audio
],
),
]
run_test
(
hf_runner
,
vllm_runner
,
inputs_vision_speech
,
model
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
mm_limit
=
1
,
tensor_parallel_size
=
1
,
)
tests/models/multimodal/processing/test_common.py
View file @
eed2f463
...
...
@@ -41,12 +41,18 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
def
_test_processing_correctness
(
model_id
:
str
,
model_id
_or_arch
:
str
,
hit_rate
:
float
,
num_batches
:
int
,
simplify_rate
:
float
,
):
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id
)
if
model_id_or_arch
in
HF_EXAMPLE_MODELS
.
get_supported_archs
():
# Use model architecture to get the default model id
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_id_or_arch
)
model_id
=
model_info
.
default
else
:
model_info
=
HF_EXAMPLE_MODELS
.
find_hf_info
(
model_id_or_arch
)
model_id
=
model_id_or_arch
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
)
...
...
@@ -58,7 +64,7 @@ def _test_processing_correctness(
trust_remote_code
=
model_info
.
trust_remote_code
,
seed
=
0
,
dtype
=
"auto"
,
revision
=
None
,
revision
=
model_info
.
revision
,
hf_overrides
=
model_info
.
hf_overrides
,
)
...
...
@@ -331,6 +337,28 @@ def test_processing_correctness(
)
# Phi4MultimodalForCausalLM share same model repo with original format
# Phi4MMForCausalLM, so we add it as a separate test case
# Remove this test after conversion PR merged:
# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
[
"Phi4MultimodalForCausalLM"
])
@
pytest
.
mark
.
parametrize
(
"hit_rate"
,
[
0.3
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"num_batches"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"simplify_rate"
,
[
1.0
])
def
test_processing_correctness_phi4_multimodal
(
model_arch
:
str
,
hit_rate
:
float
,
num_batches
:
int
,
simplify_rate
:
float
,
):
_test_processing_correctness
(
model_arch
,
hit_rate
=
hit_rate
,
num_batches
=
num_batches
,
simplify_rate
=
simplify_rate
,
)
def
_assert_inputs_equal
(
a
:
MultiModalInputs
,
b
:
MultiModalInputs
,
...
...
tests/models/registry.py
View file @
eed2f463
...
...
@@ -433,6 +433,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"1.6-gemma"
:
"AIDC-AI/Ovis1.6-Gemma2-9B"
}),
# noqa: E501
"Phi4MMForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-4-multimodal-instruct"
,
trust_remote_code
=
True
),
"Phi4MultimodalForCausalLM"
:
_HfExamplesInfo
(
"microsoft/Phi-4-multimodal-instruct"
,
# noqa: E501
revision
=
"refs/pr/70"
),
"PixtralForConditionalGeneration"
:
_HfExamplesInfo
(
"mistralai/Pixtral-12B-2409"
,
# noqa: E501
tokenizer_mode
=
"mistral"
),
"QwenVLForConditionalGeneration"
:
_HfExamplesInfo
(
"Qwen/Qwen-VL"
,
...
...
vllm/model_executor/models/phi4_multimodal.py
0 → 100644
View file @
eed2f463
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
eed2f463
...
...
@@ -223,6 +223,8 @@ _MULTIMODAL_MODELS = {
"Ovis"
:
(
"ovis"
,
"Ovis"
),
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
),
# noqa: E501
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"Phi4MMForCausalLM"
:
(
"phi4mm"
,
"Phi4MMForCausalLM"
),
"Phi4MultimodalForCausalLM"
:
(
"phi4_multimodal"
,
"Phi4MultimodalForCausalLM"
),
# noqa: E501
"PixtralForConditionalGeneration"
:
(
"pixtral"
,
"PixtralForConditionalGeneration"
),
# noqa: E501
"QwenVLForConditionalGeneration"
:
(
"qwen_vl"
,
"QwenVLForConditionalGeneration"
),
# noqa: E501
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
# noqa: E501
...
...
@@ -231,7 +233,6 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"Phi4MMForCausalLM"
:
(
"phi4mm"
,
"Phi4MMForCausalLM"
),
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
"Tarsier2ForConditionalGeneration"
:
(
"qwen2_vl"
,
"Tarsier2ForConditionalGeneration"
),
# noqa: E501
"VoxtralForConditionalGeneration"
:
(
"voxtral"
,
"VoxtralForConditionalGeneration"
),
# noqa: E501
...
...
vllm/transformers_utils/tokenizer.py
View file @
eed2f463
...
...
@@ -295,7 +295,7 @@ def cached_tokenizer_from_config(
return
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
,
tokenizer_
revision
=
model_config
.
tokenizer_revision
,
revision
=
model_config
.
tokenizer_revision
,
trust_remote_code
=
model_config
.
trust_remote_code
,
**
kwargs
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment