Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ba5106e5
Unverified
Commit
ba5106e5
authored
Feb 23, 2025
by
Isotr0py
Committed by
GitHub
Feb 23, 2025
Browse files
[LMM] Implement merged multimodal processor for whisper (#13278)
parent
d5ca2110
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
150 additions
and
83 deletions
+150
-83
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+6
-5
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+132
-74
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+4
-1
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+8
-3
No files found.
tests/models/multimodal/processing/test_common.py
View file @
ba5106e5
...
...
@@ -83,11 +83,11 @@ def _test_processing_correctness(
}
tokenizer_encode_kwargs
=
{}
if
model_config
.
hf_config
.
model_type
==
"mllama"
:
# For
Mllama
, tokenizer will always add bos_token
at the beginning of
# prompt by default, causing hf_processor outputs
incorrect token ids.
# So we need use `add_special_tokens=False` here
to leave bos_token
# to be added by the processor.
if
model_config
.
hf_config
.
model_type
in
(
"mllama"
,
"whisper"
)
:
# For
some encoder-decoder models
, tokenizer will always add bos_token
#
at the beginning of
prompt by default, causing hf_processor outputs
#
incorrect token ids.
So we need use `add_special_tokens=False` here
# to
leave bos_token to
be added by the processor.
tokenizer_encode_kwargs
=
{
"add_special_tokens"
:
False
}
for
batch_idx
in
range
(
num_batches
):
...
...
@@ -173,6 +173,7 @@ def _test_processing_correctness(
"Qwen/Qwen2.5-VL-3B-Instruct"
,
"Qwen/Qwen2-Audio-7B-Instruct"
,
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
"openai/whisper-large-v3"
,
])
@
pytest
.
mark
.
parametrize
(
"hit_rate"
,
[
0.3
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"num_batches"
,
[
32
])
...
...
vllm/model_executor/models/whisper.py
View file @
ba5106e5
...
...
@@ -4,15 +4,15 @@ import math
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
torch
from
torch
import
nn
from
transformers
import
(
BatchFeature
,
WhisperConfig
,
WhisperFeatureExtractor
,
WhisperProcessor
)
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
DummyData
,
InputContext
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.audio
import
resample_audio
from
vllm.sequence
import
SequenceData
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
NestedTensors
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.parse
import
(
MultiModalDataDict
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.interfaces
import
SupportsMultiModal
,
SupportsTranscription
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
make_layers
...
...
@@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
return
loaded_params
def
get_max_whisper_audio_tokens
(
ctx
:
InputContext
)
->
int
:
return
ctx
.
model_config
.
hf_config
.
max_source_positions
class
WhisperProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
)
->
WhisperConfig
:
return
self
.
ctx
.
get_hf_config
(
WhisperConfig
)
def
dummy_encoder_data_for_whisper
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
assert
mm_counts
[
"audio"
]
==
1
num_tokens
=
get_max_whisper_audio_tokens
(
ctx
)
processor
=
cached_processor_from_config
(
ctx
.
model_config
)
chunk_length
=
processor
.
feature_extractor
.
chunk_length
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
num_samples
=
chunk_length
*
sampling_rate
return
DummyData
(
SequenceData
.
from_prompt_token_counts
((
0
,
num_tokens
)),
{
"audio"
:
[(
np
.
zeros
(
num_samples
),
sampling_rate
)]},
)
def
get_hf_processor
(
self
,
sampling_rate
:
Optional
[
int
]
=
None
)
->
WhisperProcessor
:
return
self
.
ctx
.
get_hf_processor
(
WhisperProcessor
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"audio"
:
1
}
def
input_processor_for_whisper
(
ctx
:
InputContext
,
inputs
):
multi_modal_data
=
inputs
[
"encoder"
][
"multi_modal_data"
]
if
isinstance
(
multi_modal_data
[
"audio"
],
list
):
assert
len
(
multi_modal_data
[
"audio"
])
==
1
multi_modal_data
[
"audio"
]
=
multi_modal_data
[
"audio"
][
0
]
# Resample and process audio
audio
,
orig_sr
=
multi_modal_data
[
"audio"
]
processor
=
cached_processor_from_config
(
ctx
.
model_config
)
target_sr
=
processor
.
feature_extractor
.
sampling_rate
audio
=
resample_audio
(
audio
,
orig_sr
=
orig_sr
,
target_sr
=
target_sr
)
multi_modal_data
[
"audio"
]
=
(
audio
,
target_sr
)
# Pre-allocate placeholder tokens in encoder sequence
num_tokens
=
get_max_whisper_audio_tokens
(
ctx
)
inputs
[
"encoder"
][
"prompt_token_ids"
]
=
[
0
]
*
num_tokens
return
inputs
def
get_feature_extractor
(
self
)
->
WhisperFeatureExtractor
:
hf_processor
=
self
.
get_hf_processor
()
feature_extractor
=
hf_processor
.
feature_extractor
# type: ignore
assert
isinstance
(
feature_extractor
,
WhisperFeatureExtractor
)
return
feature_extractor
def
get_max_audio_tokens
(
self
)
->
int
:
return
self
.
get_hf_config
().
max_source_positions
def
input_mapper_for_whisper
(
ctx
:
InputContext
,
multi_modal_data
:
Union
[
np
.
ndarray
,
List
[
np
.
ndarray
]]
,
)
->
MultiModalKwargs
:
if
not
isinstance
(
multi_modal_data
,
list
)
:
multi_modal_data
=
[
multi_modal_data
]
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]
:
return
{
"audio"
:
self
.
get_max_audio_tokens
()}
assert
len
(
multi_modal_data
)
==
1
if
len
(
multi_modal_data
)
==
0
:
return
MultiModalKwargs
()
class
WhisperDummyInputsBuilder
(
BaseDummyInputsBuilder
[
WhisperProcessingInfo
]):
processor
=
cached_processor_from_config
(
ctx
.
model_config
)
sampling_rate
=
processor
.
feature_extractor
.
sampling_rate
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
feature_extractor
=
self
.
info
.
get_feature_extractor
()
sampling_rate
=
feature_extractor
.
sampling_rate
audio_len
=
feature_extractor
.
chunk_length
*
sampling_rate
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
mm_data
=
{
"audio"
:
self
.
_get_dummy_audios
(
length
=
audio_len
,
num_audios
=
num_audios
)
}
audios
=
[
audio
for
audio
,
_
in
multi_modal_data
]
return
ProcessorInputs
(
prompt_text
=
"<|startoftranscript|>"
*
num_audios
,
mm_data
=
mm_data
,
)
kwargs
=
processor
(
audios
,
sampling_rate
=
sampling_rate
,
return_tensors
=
"pt"
)
kwargs
[
"input_features"
]
=
kwargs
[
"input_features"
].
squeeze
(
0
).
to
(
ctx
.
model_config
.
dtype
)
return
MultiModalKwargs
(
kwargs
)
class
WhisperMultiModalProcessor
(
EncDecMultiModalProcessor
[
WhisperProcessingInfo
]):
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
feature_extractor
=
self
.
info
.
get_feature_extractor
()
return
MultiModalDataParser
(
target_sr
=
feature_extractor
.
sampling_rate
)
def
create_encoder_prompt
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
# Strictly speaking, whisper encoder only accept audio features.
# We create a dummy encoder prompt here which will be padded to
# num_audio_tokens. So that we can create dummy data from this
# for encoder profiling.
return
[
0
]
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
if
mm_data
:
feature_extractor
=
self
.
info
.
get_feature_extractor
(
**
mm_kwargs
)
mm_data
=
dict
(
audio
=
mm_data
.
pop
(
"audios"
))
mm_kwargs
=
dict
(
**
mm_kwargs
,
sampling_rate
=
feature_extractor
.
sampling_rate
,
)
processed_outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
)
if
"labels"
in
processed_outputs
:
processed_outputs
[
"input_ids"
]
=
processed_outputs
.
pop
(
"labels"
)
return
processed_outputs
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
))
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
num_tokens
=
self
.
info
.
get_max_audio_tokens
()
return
[
PromptReplacement
(
modality
=
"audio"
,
target
=
[
0
],
replacement
=
[
0
]
*
num_tokens
,
)
]
@
INPUT_REGISTRY
.
register_dummy_encoder_data
(
dummy_encoder_data_for_whisper
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_whisper
)
@
MULTIMODAL_REGISTRY
.
register_input_mapper
(
"audio"
,
input_mapper_for_whisper
)
@
MULTIMODAL_REGISTRY
.
register_max_multimodal_tokens
(
"audio"
,
get_max_whisper_audio_tokens
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
WhisperMultiModalProcessor
,
info
=
WhisperProcessingInfo
,
dummy_inputs
=
WhisperDummyInputsBuilder
)
class
WhisperForConditionalGeneration
(
nn
.
Module
,
SupportsTranscription
,
SupportsMultiModal
):
packed_modules_mapping
=
{
...
...
@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
if
not
isinstance
(
input_features
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio features. "
f
"Got type:
{
type
(
input_features
)
}
"
)
input_features
=
[
feat
.
to
(
self
.
dtype
)
for
feat
in
input_features
]
input_features
=
torch
.
cat
(
[
feat
.
to
(
self
.
dtype
)
for
feat
in
input_features
])
return
WhisperAudioInputs
(
input_features
=
input_features
)
...
...
vllm/multimodal/processing.py
View file @
ba5106e5
...
...
@@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
"""Create input prompt for the encoder."""
"""
Create input prompt for the encoder. HF processor will be applied on
this prompt during profiling and generation.
"""
raise
NotImplementedError
def
apply
(
...
...
vllm/multimodal/profiling.py
View file @
ba5106e5
...
...
@@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
f
"(
{
set
(
mm_max_tokens_per_item
.
keys
())
}
)"
)
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
)
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
prompt_token_ids
=
(
mm_inputs
[
"prompt_token_ids"
]
if
not
is_encoder_data
else
mm_inputs
[
"encoder_prompt_token_ids"
])
# type: ignore
total_placeholders_by_modality
=
{
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
...
...
@@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
# V0 does not support chunked prefill.
if
(
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
)
or
is_encoder_data
:
if
total_len
>
seq_len
:
if
total_len
>
seq_len
and
not
is_encoder_data
:
logger
.
warning
(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
...
...
@@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
total_placeholders_by_modality
)
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
seq_data
=
SequenceData
.
from_prompt_token_counts
(
(
0
,
max
(
seq_len
,
total_len
))),
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment