Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ba5106e5
Unverified
Commit
ba5106e5
authored
Feb 23, 2025
by
Isotr0py
Committed by
GitHub
Feb 23, 2025
Browse files
[LMM] Implement merged multimodal processor for whisper (#13278)
parent
d5ca2110
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
150 additions
and
83 deletions
+150
-83
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+6
-5
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+132
-74
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+4
-1
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+8
-3
No files found.
tests/models/multimodal/processing/test_common.py
View file @
ba5106e5
...
@@ -83,11 +83,11 @@ def _test_processing_correctness(
...
@@ -83,11 +83,11 @@ def _test_processing_correctness(
}
}
tokenizer_encode_kwargs
=
{}
tokenizer_encode_kwargs
=
{}
if
model_config
.
hf_config
.
model_type
==
"mllama"
:
if
model_config
.
hf_config
.
model_type
in
(
"mllama"
,
"whisper"
)
:
# For
Mllama
, tokenizer will always add bos_token
at the beginning of
# For
some encoder-decoder models
, tokenizer will always add bos_token
# prompt by default, causing hf_processor outputs
incorrect token ids.
#
at the beginning of
prompt by default, causing hf_processor outputs
# So we need use `add_special_tokens=False` here
to leave bos_token
#
incorrect token ids.
So we need use `add_special_tokens=False` here
# to be added by the processor.
# to
leave bos_token to
be added by the processor.
tokenizer_encode_kwargs
=
{
"add_special_tokens"
:
False
}
tokenizer_encode_kwargs
=
{
"add_special_tokens"
:
False
}
for
batch_idx
in
range
(
num_batches
):
for
batch_idx
in
range
(
num_batches
):
...
@@ -173,6 +173,7 @@ def _test_processing_correctness(
...
@@ -173,6 +173,7 @@ def _test_processing_correctness(
"Qwen/Qwen2.5-VL-3B-Instruct"
,
"Qwen/Qwen2.5-VL-3B-Instruct"
,
"Qwen/Qwen2-Audio-7B-Instruct"
,
"Qwen/Qwen2-Audio-7B-Instruct"
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
"openai/whisper-large-v3"
,
])
])
@
pytest
.
mark
.
parametrize
(
"hit_rate"
,
[
0.3
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"hit_rate"
,
[
0.3
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"num_batches"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_batches"
,
[
32
])
...
...
vllm/model_executor/models/whisper.py
View file @
ba5106e5
...
@@ -4,15 +4,15 @@ import math
...
@@ -4,15 +4,15 @@ import math
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
Union
)
import
numpy
as
np
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
(
BatchFeature
,
WhisperConfig
,
WhisperFeatureExtractor
,
WhisperProcessor
)
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
DummyData
,
InputContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...
@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
NestedTensors
NestedTensors
)
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.audio
import
resample_audio
from
vllm.multimodal.parse
import
(
MultiModalDataDict
,
MultiModalDataItems
,
from
vllm.sequence
import
SequenceData
MultiModalDataParser
)
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.interfaces
import
SupportsMultiModal
,
SupportsTranscription
from
.interfaces
import
SupportsMultiModal
,
SupportsTranscription
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
make_layers
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
make_layers
...
@@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
...
@@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
return
loaded_params
return
loaded_params
def
get_max_whisper_audio_tokens
(
ctx
:
InputContext
)
->
int
:
class
WhisperProcessingInfo
(
BaseProcessingInfo
):
return
ctx
.
model_config
.
hf_config
.
max_source_positions
def
get_hf_config
(
self
)
->
WhisperConfig
:
return
self
.
ctx
.
get_hf_config
(
WhisperConfig
)
def
dummy_encoder_data_for_whisper
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
def
get_hf_processor
(
self
,
assert
mm_counts
[
"audio"
]
==
1
sampling_rate
:
Optional
[
int
]
=
None
num_tokens
=
get_max_whisper_audio_tokens
(
ctx
)
)
->
WhisperProcessor
:
processor
=
cached_processor_from_config
(
ctx
.
model_config
)
return
self
.
ctx
.
get_hf_processor
(
WhisperProcessor
)
chunk_length
=
processor
.
feature_extractor
.
chunk_length
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
num_samples
=
chunk_length
*
sampling_rate
return
{
"audio"
:
1
}
return
DummyData
(
SequenceData
.
from_prompt_token_counts
((
0
,
num_tokens
)),
def
get_feature_extractor
(
self
)
->
WhisperFeatureExtractor
:
{
"audio"
:
[(
np
.
zeros
(
num_samples
),
sampling_rate
)]},
hf_processor
=
self
.
get_hf_processor
()
)
feature_extractor
=
hf_processor
.
feature_extractor
# type: ignore
assert
isinstance
(
feature_extractor
,
WhisperFeatureExtractor
)
return
feature_extractor
def
input_processor_for_whisper
(
ctx
:
InputContext
,
inputs
):
multi_modal_data
=
inputs
[
"encoder"
][
"multi_modal_data"
]
def
get_max_audio_tokens
(
self
)
->
int
:
if
isinstance
(
multi_modal_data
[
"audio"
],
list
):
return
self
.
get_hf_config
().
max_source_positions
assert
len
(
multi_modal_data
[
"audio"
])
==
1
multi_modal_data
[
"audio"
]
=
multi_modal_data
[
"audio"
][
0
]
def
get_mm_max_tokens_per_item
(
# Resample and process audio
self
,
audio
,
orig_sr
=
multi_modal_data
[
"audio"
]
seq_len
:
int
,
processor
=
cached_processor_from_config
(
ctx
.
model_config
)
mm_counts
:
Mapping
[
str
,
int
],
target_sr
=
processor
.
feature_extractor
.
sampling_rate
)
->
Mapping
[
str
,
int
]:
audio
=
resample_audio
(
audio
,
orig_sr
=
orig_sr
,
target_sr
=
target_sr
)
return
{
"audio"
:
self
.
get_max_audio_tokens
()}
multi_modal_data
[
"audio"
]
=
(
audio
,
target_sr
)
# Pre-allocate placeholder tokens in encoder sequence
num_tokens
=
get_max_whisper_audio_tokens
(
ctx
)
class
WhisperDummyInputsBuilder
(
BaseDummyInputsBuilder
[
WhisperProcessingInfo
]):
inputs
[
"encoder"
][
"prompt_token_ids"
]
=
[
0
]
*
num_tokens
return
inputs
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
def
input_mapper_for_whisper
(
mm_counts
:
Mapping
[
str
,
int
],
ctx
:
InputContext
,
)
->
ProcessorInputs
:
multi_modal_data
:
Union
[
np
.
ndarray
,
List
[
np
.
ndarray
]],
feature_extractor
=
self
.
info
.
get_feature_extractor
()
)
->
MultiModalKwargs
:
if
not
isinstance
(
multi_modal_data
,
list
):
sampling_rate
=
feature_extractor
.
sampling_rate
multi_modal_data
=
[
multi_modal_data
]
audio_len
=
feature_extractor
.
chunk_length
*
sampling_rate
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
assert
len
(
multi_modal_data
)
==
1
mm_data
=
{
if
len
(
multi_modal_data
)
==
0
:
"audio"
:
return
MultiModalKwargs
()
self
.
_get_dummy_audios
(
length
=
audio_len
,
num_audios
=
num_audios
)
}
processor
=
cached_processor_from_config
(
ctx
.
model_config
)
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
return
ProcessorInputs
(
prompt_text
=
"<|startoftranscript|>"
*
num_audios
,
audios
=
[
audio
for
audio
,
_
in
multi_modal_data
]
mm_data
=
mm_data
,
)
kwargs
=
processor
(
audios
,
sampling_rate
=
sampling_rate
,
return_tensors
=
"pt"
)
class
WhisperMultiModalProcessor
(
kwargs
[
"input_features"
]
=
kwargs
[
"input_features"
].
squeeze
(
0
).
to
(
EncDecMultiModalProcessor
[
WhisperProcessingInfo
]):
ctx
.
model_config
.
dtype
)
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
return
MultiModalKwargs
(
kwargs
)
feature_extractor
=
self
.
info
.
get_feature_extractor
()
return
MultiModalDataParser
(
target_sr
=
feature_extractor
.
sampling_rate
)
@
INPUT_REGISTRY
.
register_dummy_encoder_data
(
dummy_encoder_data_for_whisper
)
def
create_encoder_prompt
(
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_whisper
)
self
,
@
MULTIMODAL_REGISTRY
.
register_input_mapper
(
"audio"
,
input_mapper_for_whisper
)
prompt
:
Union
[
str
,
list
[
int
]],
@
MULTIMODAL_REGISTRY
.
register_max_multimodal_tokens
(
mm_data
:
MultiModalDataDict
,
"audio"
,
get_max_whisper_audio_tokens
)
)
->
Union
[
str
,
list
[
int
]]:
# Strictly speaking, whisper encoder only accept audio features.
# We create a dummy encoder prompt here which will be padded to
# num_audio_tokens. So that we can create dummy data from this
# for encoder profiling.
return
[
0
]
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
if
mm_data
:
feature_extractor
=
self
.
info
.
get_feature_extractor
(
**
mm_kwargs
)
mm_data
=
dict
(
audio
=
mm_data
.
pop
(
"audios"
))
mm_kwargs
=
dict
(
**
mm_kwargs
,
sampling_rate
=
feature_extractor
.
sampling_rate
,
)
processed_outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
)
if
"labels"
in
processed_outputs
:
processed_outputs
[
"input_ids"
]
=
processed_outputs
.
pop
(
"labels"
)
return
processed_outputs
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
))
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
num_tokens
=
self
.
info
.
get_max_audio_tokens
()
return
[
PromptReplacement
(
modality
=
"audio"
,
target
=
[
0
],
replacement
=
[
0
]
*
num_tokens
,
)
]
@
MULTIMODAL_REGISTRY
.
register_processor
(
WhisperMultiModalProcessor
,
info
=
WhisperProcessingInfo
,
dummy_inputs
=
WhisperDummyInputsBuilder
)
class
WhisperForConditionalGeneration
(
nn
.
Module
,
SupportsTranscription
,
class
WhisperForConditionalGeneration
(
nn
.
Module
,
SupportsTranscription
,
SupportsMultiModal
):
SupportsMultiModal
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
...
@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
if
not
isinstance
(
input_features
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
input_features
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio features. "
raise
ValueError
(
"Incorrect type of audio features. "
f
"Got type:
{
type
(
input_features
)
}
"
)
f
"Got type:
{
type
(
input_features
)
}
"
)
input_features
=
[
feat
.
to
(
self
.
dtype
)
for
feat
in
input_features
]
input_features
=
torch
.
cat
(
[
feat
.
to
(
self
.
dtype
)
for
feat
in
input_features
])
return
WhisperAudioInputs
(
input_features
=
input_features
)
return
WhisperAudioInputs
(
input_features
=
input_features
)
...
...
vllm/multimodal/processing.py
View file @
ba5106e5
...
@@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
)
->
Union
[
str
,
list
[
int
]]:
"""Create input prompt for the encoder."""
"""
Create input prompt for the encoder. HF processor will be applied on
this prompt during profiling and generation.
"""
raise
NotImplementedError
raise
NotImplementedError
def
apply
(
def
apply
(
...
...
vllm/multimodal/profiling.py
View file @
ba5106e5
...
@@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
f
"(
{
set
(
mm_max_tokens_per_item
.
keys
())
}
)"
)
f
"(
{
set
(
mm_max_tokens_per_item
.
keys
())
}
)"
)
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
)
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
)
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
prompt_token_ids
=
(
mm_inputs
[
"prompt_token_ids"
]
if
not
is_encoder_data
else
mm_inputs
[
"encoder_prompt_token_ids"
])
# type: ignore
total_placeholders_by_modality
=
{
total_placeholders_by_modality
=
{
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
...
@@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
# V0 does not support chunked prefill.
# V0 does not support chunked prefill.
if
(
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
)
or
is_encoder_data
:
if
(
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
)
or
is_encoder_data
:
if
total_len
>
seq_len
:
if
total_len
>
seq_len
and
not
is_encoder_data
:
logger
.
warning
(
logger
.
warning
(
"The context length (%d) of the model is too short "
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"to hold the multi-modal embeddings in the worst case "
...
@@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
total_placeholders_by_modality
)
total_placeholders_by_modality
)
return
DummyData
(
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
seq_data
=
SequenceData
.
from_prompt_token_counts
(
(
0
,
max
(
seq_len
,
total_len
))),
multi_modal_data
=
None
,
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_placeholders
=
None
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment