Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
952a0749
Unverified
Commit
952a0749
authored
Mar 08, 2025
by
Jee Jee Li
Committed by
GitHub
Mar 07, 2025
Browse files
[Misc] Add Phi4-MM example (#14343)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
d0feea31
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
131 additions
and
7 deletions
+131
-7
examples/offline_inference/audio_language.py
examples/offline_inference/audio_language.py
+38
-0
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+38
-0
examples/offline_inference/vision_language_multi_image.py
examples/offline_inference/vision_language_multi_image.py
+44
-0
vllm/model_executor/models/phi4mm.py
vllm/model_executor/models/phi4mm.py
+11
-7
No files found.
examples/offline_inference/audio_language.py
View file @
952a0749
...
@@ -6,10 +6,14 @@ with the correct prompt format on audio language models.
...
@@ -6,10 +6,14 @@ with the correct prompt format on audio language models.
For most models, the prompt format should follow corresponding examples
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
on HuggingFace model repository.
"""
"""
import
os
from
huggingface_hub
import
snapshot_download
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.audio
import
AudioAsset
from
vllm.assets.audio
import
AudioAsset
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
audio_assets
=
[
AudioAsset
(
"mary_had_lamb"
),
AudioAsset
(
"winning_call"
)]
audio_assets
=
[
AudioAsset
(
"mary_had_lamb"
),
AudioAsset
(
"winning_call"
)]
...
@@ -51,6 +55,39 @@ def run_minicpmo(question: str, audio_count: int):
...
@@ -51,6 +55,39 @@ def run_minicpmo(question: str, audio_count: int):
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
# Phi-4-multimodal-instruct
def
run_phi4mm
(
questions
:
str
,
audio_count
:
int
):
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process audio inputs.
"""
model_path
=
snapshot_download
(
"microsoft/Phi-4-multimodal-instruct"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
speech_lora_path
=
os
.
path
.
join
(
model_path
,
"speech-lora"
)
placeholders
=
""
.
join
([
f
"<|audio_
{
i
+
1
}
|>"
for
i
in
range
(
audio_count
)])
prompts
=
f
"<|user|>
{
placeholders
}{
questions
}
<|end|><|assistant|>"
llm
=
LLM
(
model
=
model_path
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
enable_lora
=
True
,
max_lora_rank
=
320
,
lora_extra_vocab_size
=
0
,
)
lora_request
=
LoRARequest
(
"speech"
,
1
,
speech_lora_path
)
# To maintain code compatibility in this script, we add LoRA here.
llm
.
llm_engine
.
add_lora
(
lora_request
=
lora_request
)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
stop_token_ids
=
None
return
llm
,
prompts
,
stop_token_ids
# Qwen2-Audio
# Qwen2-Audio
def
run_qwen2_audio
(
question
:
str
,
audio_count
:
int
):
def
run_qwen2_audio
(
question
:
str
,
audio_count
:
int
):
model_name
=
"Qwen/Qwen2-Audio-7B-Instruct"
model_name
=
"Qwen/Qwen2-Audio-7B-Instruct"
...
@@ -113,6 +150,7 @@ def run_whisper(question: str, audio_count: int):
...
@@ -113,6 +150,7 @@ def run_whisper(question: str, audio_count: int):
model_example_map
=
{
model_example_map
=
{
"minicpmo"
:
run_minicpmo
,
"minicpmo"
:
run_minicpmo
,
"phi4_mm"
:
run_phi4mm
,
"qwen2_audio"
:
run_qwen2_audio
,
"qwen2_audio"
:
run_qwen2_audio
,
"ultravox"
:
run_ultravox
,
"ultravox"
:
run_ultravox
,
"whisper"
:
run_whisper
,
"whisper"
:
run_whisper
,
...
...
examples/offline_inference/vision_language.py
View file @
952a0749
...
@@ -6,13 +6,16 @@ the correct prompt format on vision language models for text generation.
...
@@ -6,13 +6,16 @@ the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow corresponding examples
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
on HuggingFace model repository.
"""
"""
import
os
import
random
import
random
from
huggingface_hub
import
snapshot_download
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.video
import
VideoAsset
from
vllm.assets.video
import
VideoAsset
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
...
@@ -519,6 +522,40 @@ def run_phi3v(questions: list[str], modality: str):
...
@@ -519,6 +522,40 @@ def run_phi3v(questions: list[str], modality: str):
return
llm
,
prompts
,
stop_token_ids
return
llm
,
prompts
,
stop_token_ids
# Phi-4-multimodal-instruct
def
run_phi4mm
(
questions
:
list
[
str
],
modality
:
str
):
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process image inputs.
"""
assert
modality
==
"image"
model_path
=
snapshot_download
(
"microsoft/Phi-4-multimodal-instruct"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path
=
os
.
path
.
join
(
model_path
,
"vision-lora"
)
prompts
=
[
f
"<|user|><|image_1|>
{
question
}
<|end|><|assistant|>"
for
question
in
questions
]
llm
=
LLM
(
model
=
model_path
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
enable_lora
=
True
,
max_lora_rank
=
320
,
lora_extra_vocab_size
=
0
,
)
lora_request
=
LoRARequest
(
"vision"
,
1
,
vision_lora_path
)
# To maintain code compatibility in this script, we add LoRA here.
llm
.
llm_engine
.
add_lora
(
lora_request
=
lora_request
)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
stop_token_ids
=
None
return
llm
,
prompts
,
stop_token_ids
# Pixtral HF-format
# Pixtral HF-format
def
run_pixtral_hf
(
questions
:
list
[
str
],
modality
:
str
):
def
run_pixtral_hf
(
questions
:
list
[
str
],
modality
:
str
):
assert
modality
==
"image"
assert
modality
==
"image"
...
@@ -644,6 +681,7 @@ model_example_map = {
...
@@ -644,6 +681,7 @@ model_example_map = {
"paligemma"
:
run_paligemma
,
"paligemma"
:
run_paligemma
,
"paligemma2"
:
run_paligemma2
,
"paligemma2"
:
run_paligemma2
,
"phi3_v"
:
run_phi3v
,
"phi3_v"
:
run_phi3v
,
"phi4_mm"
:
run_phi4mm
,
"pixtral_hf"
:
run_pixtral_hf
,
"pixtral_hf"
:
run_pixtral_hf
,
"qwen_vl"
:
run_qwen_vl
,
"qwen_vl"
:
run_qwen_vl
,
"qwen2_vl"
:
run_qwen2_vl
,
"qwen2_vl"
:
run_qwen2_vl
,
...
...
examples/offline_inference/vision_language_multi_image.py
View file @
952a0749
...
@@ -4,13 +4,16 @@ This example shows how to use vLLM for running offline inference with
...
@@ -4,13 +4,16 @@ This example shows how to use vLLM for running offline inference with
multi-image input on vision language models for text generation,
multi-image input on vision language models for text generation,
using the chat template defined by the model.
using the chat template defined by the model.
"""
"""
import
os
from
argparse
import
Namespace
from
argparse
import
Namespace
from
typing
import
NamedTuple
,
Optional
from
typing
import
NamedTuple
,
Optional
from
huggingface_hub
import
snapshot_download
from
PIL.Image
import
Image
from
PIL.Image
import
Image
from
transformers
import
AutoProcessor
,
AutoTokenizer
from
transformers
import
AutoProcessor
,
AutoTokenizer
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.utils
import
fetch_image
from
vllm.multimodal.utils
import
fetch_image
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -294,6 +297,46 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
...
@@ -294,6 +297,46 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
)
)
def
load_phi4mm
(
question
:
str
,
image_urls
:
list
[
str
])
->
ModelRequestData
:
"""
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
show how to process multi images inputs.
"""
model_path
=
snapshot_download
(
"microsoft/Phi-4-multimodal-instruct"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path
=
os
.
path
.
join
(
model_path
,
"vision-lora"
)
llm
=
LLM
(
model
=
model_path
,
trust_remote_code
=
True
,
max_model_len
=
10000
,
max_num_seqs
=
2
,
limit_mm_per_prompt
=
{
"image"
:
len
(
image_urls
)},
enable_lora
=
True
,
max_lora_rank
=
320
,
lora_extra_vocab_size
=
0
,
)
lora_request
=
LoRARequest
(
"vision"
,
1
,
vision_lora_path
)
# To maintain code compatibility in this script, we add LoRA here.
llm
.
llm_engine
.
add_lora
(
lora_request
=
lora_request
)
# You can also add LoRA using:
# llm.generate(prompts, lora_request=lora_request,...)
placeholders
=
""
.
join
(
f
"<|image_
{
i
}
|>"
for
i
,
_
in
enumerate
(
image_urls
,
start
=
1
))
prompt
=
f
"<|user|>
{
placeholders
}{
question
}
<|end|><|assistant|>"
stop_token_ids
=
None
return
ModelRequestData
(
llm
=
llm
,
prompt
=
prompt
,
stop_token_ids
=
stop_token_ids
,
image_data
=
[
fetch_image
(
url
)
for
url
in
image_urls
],
chat_template
=
None
,
)
def
load_qwen_vl_chat
(
question
:
str
,
def
load_qwen_vl_chat
(
question
:
str
,
image_urls
:
list
[
str
])
->
ModelRequestData
:
image_urls
:
list
[
str
])
->
ModelRequestData
:
model_name
=
"Qwen/Qwen-VL-Chat"
model_name
=
"Qwen/Qwen-VL-Chat"
...
@@ -459,6 +502,7 @@ model_example_map = {
...
@@ -459,6 +502,7 @@ model_example_map = {
"mllama"
:
load_mllama
,
"mllama"
:
load_mllama
,
"NVLM_D"
:
load_nvlm_d
,
"NVLM_D"
:
load_nvlm_d
,
"phi3_v"
:
load_phi3v
,
"phi3_v"
:
load_phi3v
,
"phi4_mm"
:
load_phi4mm
,
"pixtral_hf"
:
load_pixtral_hf
,
"pixtral_hf"
:
load_pixtral_hf
,
"qwen_vl_chat"
:
load_qwen_vl_chat
,
"qwen_vl_chat"
:
load_qwen_vl_chat
,
"qwen2_vl"
:
load_qwen2_vl
,
"qwen2_vl"
:
load_qwen2_vl
,
...
...
vllm/model_executor/models/phi4mm.py
View file @
952a0749
...
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
)
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalInputs
,
NestedTensors
from
vllm.multimodal.inputs
import
MultiModalInputs
,
NestedTensors
...
@@ -1421,7 +1422,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1421,7 +1422,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"""
"""
Implements the Phi-4-multimodal-instruct model in VLLM.
Implements the Phi-4-multimodal-instruct model in VLLM.
"""
"""
# LoRA specific attributes
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"qkv_proj"
,
"qkv_proj"
,
...
@@ -1430,12 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1430,12 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
"gate_up_proj"
,
"gate_up_proj"
,
],
],
}
}
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
]
# Phi4MMForCausalLM does not apply LoRA to the embedding layer.
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -1801,3 +1795,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1801,3 +1795,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"model."
,
connector
=
[
"audio_projection_for_vision"
,
"audio_projection"
],
tower_model
=
[
"vision_encoder"
,
"embed_tokens_extend"
],
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment