Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3faa8bee
Unverified
Commit
3faa8bee
authored
Dec 23, 2025
by
Patrick von Platen
Committed by
GitHub
Dec 23, 2025
Browse files
adapt voxtral (#31095)
Signed-off-by:
Patrick von Platen
<
patrick.v.platen@gmail.com
>
parent
b10d47e0
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
740 additions
and
99 deletions
+740
-99
tests/models/multimodal/generation/test_voxtral.py
tests/models/multimodal/generation/test_voxtral.py
+1
-0
tests/models/registry.py
tests/models/registry.py
+5
-0
vllm/config/model.py
vllm/config/model.py
+4
-0
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+10
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+4
-0
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+29
-11
vllm/model_executor/models/voxtral_streaming.py
vllm/model_executor/models/voxtral_streaming.py
+243
-0
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+91
-81
vllm/model_executor/models/whisper_utils.py
vllm/model_executor/models/whisper_utils.py
+299
-0
vllm/transformers_utils/configs/mistral.py
vllm/transformers_utils/configs/mistral.py
+31
-3
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+9
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+14
-4
No files found.
tests/models/multimodal/generation/test_voxtral.py
View file @
3faa8bee
...
...
@@ -111,4 +111,5 @@ async def test_online_serving(client, audio_assets: AudioTestAssets):
assert
len
(
chat_completion
.
choices
)
==
1
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
message
.
content
==
"In the first audio clip, you hear a brief"
assert
choice
.
finish_reason
==
"length"
tests/models/registry.py
View file @
3faa8bee
...
...
@@ -860,6 +860,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# disable this temporarily until we support HF format
is_available_online
=
False
,
),
"VoxtralStreamingGeneration"
:
_HfExamplesInfo
(
"<place-holder>"
,
# disable this temporarily until we support HF format
is_available_online
=
False
,
),
# [Encoder-decoder]
"WhisperForConditionalGeneration"
:
_HfExamplesInfo
(
"openai/whisper-large-v3-turbo"
,
...
...
vllm/config/model.py
View file @
3faa8bee
...
...
@@ -1542,6 +1542,10 @@ class ModelConfig:
def
is_multimodal_raw_input_only_model
(
self
)
->
bool
:
return
self
.
_model_info
.
supports_multimodal_raw_input_only
@property
def requires_raw_input_tokens(self) -> bool:
    """Whether the resolved model info says raw input token ids are required."""
    model_info = self._model_info
    return model_info.requires_raw_input_tokens
@
property
def
is_cross_encoder
(
self
)
->
bool
:
return
(
...
...
vllm/model_executor/models/interfaces.py
View file @
3faa8bee
...
...
@@ -94,6 +94,12 @@ class SupportsMultiModal(Protocol):
`multimodal_config.mm_encoder_tp_mode="data"`.
"""
requires_raw_input_tokens
:
ClassVar
[
bool
]
=
False
"""
A flag that indicates this model processes input id tokens
in their raw form and not input embeddings.
"""
merge_by_field_config
:
ClassVar
[
bool
|
None
]
=
None
"""
[DEPRECATED] A flag that indicates which implementation of
...
...
@@ -306,6 +312,10 @@ def supports_multimodal_raw_input_only(model: type[object] | object) -> bool:
return
getattr
(
model
,
"supports_multimodal_raw_input_only"
,
False
)
def
requires_raw_input_tokens
(
model
:
type
[
object
]
|
object
)
->
bool
:
return
getattr
(
model
,
"requires_raw_input_tokens"
,
False
)
def
supports_multimodal_encoder_tp_data
(
model
:
type
[
object
]
|
object
)
->
bool
:
return
getattr
(
model
,
"supports_encoder_tp_data"
,
False
)
...
...
vllm/model_executor/models/registry.py
View file @
3faa8bee
...
...
@@ -46,6 +46,7 @@ from .interfaces import (
has_noops
,
is_attention_free
,
is_hybrid
,
requires_raw_input_tokens
,
supports_cross_encoding
,
supports_mamba_prefix_caching
,
supports_multimodal
,
...
...
@@ -422,6 +423,7 @@ _MULTIMODAL_MODELS = {
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"VoxtralForConditionalGeneration"
:
(
"voxtral"
,
"VoxtralForConditionalGeneration"
),
# noqa: E501
"VoxtralStreamingGeneration"
:
(
"voxtral_streaming"
,
"VoxtralStreamingGeneration"
),
# noqa: E501
# [Encoder-decoder]
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
}
...
...
@@ -539,6 +541,7 @@ class _ModelInfo:
supports_cross_encoding
:
bool
supports_multimodal
:
bool
supports_multimodal_raw_input_only
:
bool
requires_raw_input_tokens
:
bool
supports_multimodal_encoder_tp_data
:
bool
supports_pp
:
bool
has_inner_state
:
bool
...
...
@@ -562,6 +565,7 @@ class _ModelInfo:
supports_multimodal_raw_input_only
=
supports_multimodal_raw_input_only
(
model
),
requires_raw_input_tokens
=
requires_raw_input_tokens
(
model
),
supports_multimodal_encoder_tp_data
=
supports_multimodal_encoder_tp_data
(
model
),
...
...
vllm/model_executor/models/voxtral.py
View file @
3faa8bee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
inspect
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
...
...
@@ -116,10 +117,7 @@ class VoxtralProcessorAdapter:
self
,
audio_length
:
int
,
)
->
int
:
pad_audio_length
=
self
.
_audio_processor
.
next_multiple_of_chunk_frames
(
audio_length
,
self
.
sampling_rate
)
return
ceil
(
pad_audio_length
/
(
self
.
sampling_rate
//
self
.
frame_rate
))
return
ceil
(
audio_length
/
(
self
.
sampling_rate
//
self
.
frame_rate
))
def
__call__
(
self
,
...
...
@@ -158,6 +156,13 @@ class VoxtralProcessorAdapter:
assert
audio
.
ndim
==
1
# pad if necessary
# TODO(Patrick) - remove once mistral-common is bumped
sig
=
inspect
.
signature
(
self
.
_audio_processor
.
pad
)
if
"is_online_streaming"
in
sig
.
parameters
:
audio
=
self
.
_audio_processor
.
pad
(
audio
,
self
.
sampling_rate
,
is_online_streaming
=
False
)
else
:
audio
=
self
.
_audio_processor
.
pad
(
audio
,
self
.
sampling_rate
)
audio_tokens
=
[
self
.
begin_audio_token_id
]
+
[
...
...
@@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration(
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
remapping_rules
=
[
(
r
"mm_streams_embeddings.embedding_module\.(.*)"
,
r
"\1"
),
(
r
"mm_whisper_embeddings\.(.*)"
,
r
"\1"
),
(
r
"audio_language_projection\.(.*)"
,
r
"audio_language_adapter.\1"
),
(
...
...
@@ -535,12 +541,15 @@ class VoxtralForConditionalGeneration(
def
llm_weights_generator
():
nonlocal
loaded_weights
for
name
,
w
in
weights
:
is_encoder
=
(
name
.
startswith
(
"mm_whisper_embeddings"
)
and
not
name
.
startswith
(
"mm_whisper_embeddings.tok_embeddings"
)
and
not
name
.
startswith
(
"mm_whisper_embeddings.audio_language_projection"
)
is_encoder
=
False
for
k
in
[
"mm_whisper_embeddings"
,
"mm_streams_embeddings.embedding_module"
,
]:
is_encoder
|=
(
name
.
startswith
(
k
)
and
not
name
.
startswith
(
f
"
{
k
}
.tok_embeddings"
)
and
not
name
.
startswith
(
f
"
{
k
}
.audio_language_projection"
)
)
for
pattern
,
repl
in
remapping_rules
:
...
...
@@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
mistral_remapping
=
[
(
r
"mm_streams_embeddings.embedding_module\.(.*)"
,
r
"\1"
),
(
r
"whisper_encoder\.conv_layers\.0\.(weight|bias)"
,
r
"whisper_encoder.conv1.\1"
,
...
...
@@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module):
r
"whisper_encoder\.conv_layers\.1\.(weight|bias)"
,
r
"whisper_encoder.conv2.\1"
,
),
(
r
"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)"
,
r
"whisper_encoder.conv1.\1"
,
),
# noqa: E501
(
r
"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)"
,
r
"whisper_encoder.conv2.\1"
,
),
# noqa: E501
(
r
"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)"
,
# noqa: E501
r
"whisper_encoder.layers.\1.self_attn.\2_proj.\3"
,
...
...
vllm/model_executor/models/voxtral_streaming.py
0 → 100644
View file @
3faa8bee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Mapping
import
torch
from
vllm.config.vllm
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.interfaces
import
MultiModalEmbeddings
from
vllm.model_executor.models.voxtral
import
(
VoxtralDummyInputsBuilder
,
VoxtralForConditionalGeneration
,
VoxtralMultiModalProcessor
,
VoxtralProcessingInfo
,
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.cache
import
_I
,
BaseMultiModalProcessorCache
from
vllm.multimodal.inputs
import
(
MultiModalKwargsOptionalItems
,
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
MultiModalPromptUpdates
,
PlaceholderFeaturesInfo
,
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
(
_flatten_embeddings
,
)
logger
=
init_logger
(
__name__
)
class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
    """Multimodal processor for streaming Voxtral requests.

    It never uses a processor cache and, instead of searching the prompt for
    placeholder tokens, reports a single audio placeholder starting at
    position 0.
    """

    def __init__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> None:
        # The ``cache`` argument is accepted for interface compatibility but
        # deliberately dropped: streaming cannot make use of a cache yet.
        super().__init__(info, dummy_inputs, cache=None)

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsOptionalItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Return the prompt unchanged plus a hand-built audio placeholder.

        Streaming prompts carry no audio placeholder tokens, so the
        placeholder range is derived from the audio length instead. Exactly
        one audio item is expected.
        """
        audio_items = mm_kwargs.get("audio", [])
        assert len(audio_items) == 1, (
            f"Expected only one audio input for streaming, got {mm_kwargs=}"
        )

        audio_config = self.info.get_tokenizer().instruct.audio_encoder.audio_config
        sample_count = audio_items[0]["audio_arrays"].data.shape[0]
        token_count = audio_config.num_audio_tokens(sample_count)

        # Only the number of tokens matters here, so dummy zeros are enough.
        placeholder = PlaceholderFeaturesInfo(
            modality="audio",
            item_idx=0,
            start_idx=0,
            tokens=token_count * [0],
            is_embed=None,
        )
        return prompt_ids, {"audio": [placeholder]}
class TimeEmbedding(torch.nn.Module):
    """Sinusoidal embedding used to encode a time value."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta

        half_dim = self.dim // 2
        steps = torch.arange(half_dim).float()
        # Geometric progression of frequencies from 1 down towards 1/theta.
        exponent = -math.log(self.theta) * steps / half_dim
        # Non-persistent: the buffer is not part of the state dict.
        self.register_buffer("inv_freq", torch.exp(exponent), persistent=False)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t``; the output has one extra trailing axis of cos/sin values."""
        # (B,) -> (B, 1) or (B, T) -> (B, T, 1)
        expanded = t.unsqueeze(-1)
        freqs = self.inv_freq.to(device=expanded.device, dtype=expanded.dtype)
        # Broadcast against (D/2,): (B, D/2) or (B, T, D/2)
        angles = expanded * freqs
        # Cosines first, then sines: (B, D) or (B, T, D)
        return torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)
@MULTIMODAL_REGISTRY.register_processor(
    VoxtralStreamingMultiModalProcessor,
    info=VoxtralProcessingInfo,
    dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
    """Streaming variant of the Voxtral conditional-generation model.

    ``embed_multimodal`` turns raw audio into pooled post-convolution
    embeddings. ``forward`` un-pools them, runs the audio encoder layers and
    the audio/language adapter, and adds the result to the text embeddings of
    ``input_ids`` (plus a time embedding) before calling the language model.
    """

    # ``forward`` embeds ``input_ids`` itself, so the raw token ids are needed
    # in addition to the input embeddings.
    requires_raw_input_tokens = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the base model and derive the transcription delay in tokens.

        Raises:
            AssertionError: if ``frame_rate * transcription_delay_ms / 1000``
                is not a whole number.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.time_embedding: TimeEmbedding = TimeEmbedding(
            dim=self.config.text_config.hidden_size
        )

        # Reuse the property instead of repeating the tokenizer lookup.
        audio_config = self.audio_config
        _n_delay_tokens = (
            audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
        )
        assert _n_delay_tokens.is_integer(), (
            f"n_delay_tokens must be integer, got {_n_delay_tokens}"
        )

        self.n_delay_tokens = int(_n_delay_tokens)

    @property
    def audio_config(self):
        """Audio configuration held by the tokenizer's audio encoder."""
        return self.tokenizer.instruct.audio_encoder.audio_config

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Pass post-conv embeddings directly as input.

        ``input_ids``, ``is_multimodal`` and ``handle_oov_mm_token`` are not
        used here; the token ids are embedded later, inside ``forward``.

        Raises:
            AssertionError: if no multimodal embeddings are supplied.
        """
        # For streaming we simply flatten the multimodal embeddings into a
        # single tensor.
        assert multimodal_embeddings is not None
        assert len(multimodal_embeddings) > 0, (
            "For streaming you must provide a multimodal_embedding at every step."
        )
        mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
        return mm_embeds_flat

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Fuse audio and text embeddings and run the language model.

        Args:
            input_ids: token ids, embedded with the language model's table.
            positions: forwarded unchanged to the language model.
            intermediate_tensors: forwarded unchanged to the language model.
            inputs_embeds: 2-D pooled audio embeddings (required) —
                presumably the tensor produced by ``embed_input_ids``.

        Raises:
            AssertionError: if ``inputs_embeds`` or ``input_ids`` is ``None``,
                or the encoder output length is not a multiple of
                ``self.downsample_factor``.
        """
        assert inputs_embeds is not None
        assert input_ids is not None

        # Undo the pooling applied in ``embed_multimodal``: each row is split
        # back into ``pool_size`` rows.
        pool_size = self.config.audio_config.block_pool_size
        inputs_embeds = inputs_embeds.view(
            inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
        )

        audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers(
            inputs_embeds
        )

        # Merge ``downsample_factor`` consecutive rows into one wider row
        # before projecting with the audio/language adapter.
        num_tokens, audio_hidden_size = audio_hidden_states.shape
        assert num_tokens % self.downsample_factor == 0
        audio_hidden_states = audio_hidden_states.reshape(
            num_tokens // self.downsample_factor,
            audio_hidden_size * self.downsample_factor,
        )
        audio_text_embeds = self.audio_language_adapter(audio_hidden_states)

        text_embeds = self.language_model.embed_input_ids(input_ids)

        # sum pool text and audio embeddings
        inputs_embeds = audio_text_embeds + text_embeds

        # The same delay embedding is broadcast onto every position.
        time_tensor = torch.tensor(
            [self.n_delay_tokens],
            device=inputs_embeds.device,
            dtype=inputs_embeds.dtype,
        )
        inputs_embeds = inputs_embeds + self.time_embedding(time_tensor)

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def embed_multimodal(
        self, **kwargs
    ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
        """Transform audio waveforms -> initial whisper post-conv embeddings.

        Returns one 2-D tensor per audio input, with ``block_pool_size``
        consecutive embedding rows folded into each output row.

        Raises:
            AssertionError: if no audio is supplied, a waveform length is not
                a multiple of ``raw_audio_length_per_tok``, or a per-sample
                embedding count is not a multiple of ``block_pool_size``.
        """
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)

        assert audio_inputs is not None, (
            "For streaming you must provide an audio input at every step."
        )

        multiple_of = self.audio_config.raw_audio_length_per_tok
        # ``all`` stops at the first offending waveform, so the walrus target
        # holds exactly the length reported in the message.
        assert all(
            (this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
        ), (
            f"Every input audio waveform has to be a multiple of {multiple_of}, but"
            f" one is {this_audio} with {(this_audio / multiple_of)=}."
        )

        mel_features = [
            self.whisper_encoder.compute_whisper_melspec(audio).to(
                self.whisper_encoder.dtype
            )
            for audio in audio_inputs
        ]
        seq_lens = [mel.shape[1] for mel in mel_features]
        # [total_num_20ms_frames, hidden_size]
        audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
            mel_features
        )[0]
        # Split the concatenated rows back per sample; the conv stack shortens
        # each sample by its total stride.
        conv_stride = self.whisper_encoder.whisper_encoder.total_stride
        audio_embeddings_per_sample = audio_embeddings.split(
            [s // conv_stride for s in seq_lens], dim=0
        )

        # Each sample's embedding count must be divisible by the pool size.
        pool_size = self.config.audio_config.block_pool_size
        assert all(
            (this_shape := sample.shape[0]) % pool_size == 0
            for sample in audio_embeddings_per_sample
        ), (
            f"Every audio embedding has to be a multiple of {pool_size}, "
            f"but one is {this_shape}."
        )

        audio_embeddings_per_sample = [
            e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
            for e in audio_embeddings_per_sample
        ]
        return audio_embeddings_per_sample
vllm/model_executor/models/whisper.py
View file @
3faa8bee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
contextlib
import
nullcontext
from
functools
import
partial
from
typing
import
Annotated
,
Literal
,
cast
import
numpy
as
np
...
...
@@ -16,7 +18,10 @@ from transformers import (
)
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
vllm.attention.layer
import
Attention
,
AttentionType
from
vllm.attention.backends.abstract
import
(
AttentionType
,
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layers.cross_attention
import
CrossAttention
from
vllm.attention.layers.mm_encoder_attention
import
MMEncoderAttention
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
...
...
@@ -34,6 +39,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.whisper_utils
import
(
ISO639_1_SUPPORTED_LANGS
,
WhisperAttentionWithBlockPooling
,
WhisperCausalConv1d
,
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
...
...
@@ -64,67 +74,11 @@ from .utils import (
logger
=
init_logger
(
__name__
)
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS
=
{
"af"
:
"Afrikaans"
,
"ar"
:
"Arabic"
,
"hy"
:
"Armenian"
,
"az"
:
"Azerbaijani"
,
"be"
:
"Belarusian"
,
"bs"
:
"Bosnian"
,
"bg"
:
"Bulgarian"
,
"ca"
:
"Catalan"
,
"zh"
:
"Chinese"
,
"hr"
:
"Croatian"
,
"cs"
:
"Czech"
,
"da"
:
"Danish"
,
"nl"
:
"Dutch"
,
"en"
:
"English"
,
"et"
:
"Estonian"
,
"fi"
:
"Finnish"
,
"fr"
:
"French"
,
"gl"
:
"Galician"
,
"de"
:
"German"
,
"el"
:
"Greek"
,
"he"
:
"Hebrew"
,
"hi"
:
"Hindi"
,
"hu"
:
"Hungarian"
,
"is"
:
"Icelandic"
,
"id"
:
"Indonesian"
,
"it"
:
"Italian"
,
"ja"
:
"Japanese"
,
"kn"
:
"Kannada"
,
"kk"
:
"Kazakh"
,
"ko"
:
"Korean"
,
"lv"
:
"Latvian"
,
"lt"
:
"Lithuanian"
,
"mk"
:
"Macedonian"
,
"ms"
:
"Malay"
,
"mr"
:
"Marathi"
,
"mi"
:
"Maori"
,
"ne"
:
"Nepali"
,
"no"
:
"Norwegian"
,
"fa"
:
"Persian"
,
"pl"
:
"Polish"
,
"pt"
:
"Portuguese"
,
"ro"
:
"Romanian"
,
"ru"
:
"Russian"
,
"sr"
:
"Serbian"
,
"sk"
:
"Slovak"
,
"sl"
:
"Slovenian"
,
"es"
:
"Spanish"
,
"sw"
:
"Swahili"
,
"sv"
:
"Swedish"
,
"tl"
:
"Tagalog"
,
"ta"
:
"Tamil"
,
"th"
:
"Thai"
,
"tr"
:
"Turkish"
,
"uk"
:
"Ukrainian"
,
"ur"
:
"Urdu"
,
"vi"
:
"Vietnamese"
,
"cy"
:
"Welsh"
,
}
class WhisperPosEmbedType(enum.Enum):
    """Position-embedding variants, keyed by their configuration string."""

    SINUSOIDAL = "sinusoidal"
    NOPE = "nope"
    LEARNED = "learned"
class
WhisperAudioInputs
(
TensorSchema
):
...
...
@@ -184,6 +138,8 @@ class WhisperAttention(nn.Module):
num_heads
:
int
,
bias
:
bool
=
True
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
per_layer_sliding_window
:
int
|
None
=
None
,
block_pool_size
:
int
=
1
,
cache_config
:
CacheConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
...
...
@@ -242,7 +198,14 @@ class WhisperAttention(nn.Module):
attn_type
=
self
.
attn_type
,
)
else
:
# AttentionType.DECODER (regular decoder self-attention)
self
.
attn
=
Attention
(
if
block_pool_size
>
1
:
attn_cls
=
partial
(
WhisperAttentionWithBlockPooling
,
block_pool_size
=
block_pool_size
)
else
:
attn_cls
=
Attention
self
.
attn
=
attn_cls
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
...
...
@@ -251,6 +214,7 @@ class WhisperAttention(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
self
.
attn_type
,
per_layer_sliding_window
=
per_layer_sliding_window
,
)
def
_init_qkv
(
...
...
@@ -386,6 +350,9 @@ class WhisperEncoderLayer(nn.Module):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
is_causal
=
getattr
(
config
,
"is_causal"
,
False
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
block_pool_size
=
getattr
(
config
,
"block_pool_size"
,
1
)
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
...
...
@@ -393,7 +360,9 @@ class WhisperEncoderLayer(nn.Module):
self
.
self_attn
=
WhisperAttention
(
embed_dim
=
self
.
embed_dim
,
num_heads
=
config
.
encoder_attention_heads
,
attn_type
=
AttentionType
.
ENCODER
,
attn_type
=
AttentionType
.
DECODER
if
is_causal
else
AttentionType
.
ENCODER
,
block_pool_size
=
block_pool_size
,
per_layer_sliding_window
=
sliding_window
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
...
...
@@ -492,12 +461,21 @@ class WhisperEncoder(nn.Module):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
embed_dim
=
config
.
d_model
self
.
pos_embed_type
=
WhisperPosEmbedType
(
getattr
(
config
,
"pos_embed"
,
"sinusoidal"
)
)
self
.
num_mel_bins
=
config
.
num_mel_bins
self
.
max_source_positions
=
config
.
max_source_positions
self
.
embed_scale
=
math
.
sqrt
(
embed_dim
)
if
config
.
scale_embedding
else
1.0
self
.
conv1
=
nn
.
Conv1d
(
self
.
num_mel_bins
,
embed_dim
,
kernel_size
=
3
,
padding
=
1
)
self
.
conv2
=
nn
.
Conv1d
(
embed_dim
,
embed_dim
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
is_causal
=
getattr
(
config
,
"is_causal"
,
False
)
Conv1d
=
WhisperCausalConv1d
if
is_causal
else
partial
(
nn
.
Conv1d
,
padding
=
1
)
self
.
conv1
=
Conv1d
(
self
.
num_mel_bins
,
embed_dim
,
kernel_size
=
3
)
self
.
conv2
=
Conv1d
(
embed_dim
,
embed_dim
,
stride
=
2
,
kernel_size
=
3
)
self
.
total_stride
=
self
.
conv1
.
stride
[
0
]
*
self
.
conv2
.
stride
[
0
]
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
encoder_layers
,
lambda
prefix
:
WhisperEncoderLayer
(
...
...
@@ -507,29 +485,54 @@ class WhisperEncoder(nn.Module):
)
self
.
layer_norm
=
nn
.
LayerNorm
(
config
.
d_model
)
if
is_causal
and
self
.
pos_embed_type
!=
WhisperPosEmbedType
.
NOPE
:
raise
ValueError
(
"Only NOPE position embeddings are supported "
f
"for causal models, but got
{
self
.
pos_embed_type
}
"
)
elif
self
.
pos_embed_type
in
(
WhisperPosEmbedType
.
SINUSOIDAL
,
WhisperPosEmbedType
.
LEARNED
,
):
maybe_fp32_init_ctx
=
(
set_default_torch_dtype
(
torch
.
float32
)
if
init_in_fp32
else
nullcontext
()
set_default_torch_dtype
(
torch
.
float32
)
if
init_in_fp32
else
nullcontext
()
)
with
(
torch
.
no_grad
(),
maybe_fp32_init_ctx
,
):
self
.
embed_positions
=
nn
.
Embedding
(
self
.
max_source_positions
,
embed_dim
)
self
.
embed_positions
=
nn
.
Embedding
(
self
.
max_source_positions
,
embed_dim
)
self
.
embed_positions
.
weight
.
copy_
(
sinusoids
(
*
self
.
embed_positions
.
weight
.
shape
)
)
def
forward
(
self
,
input_features
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]):
def
forward_conv
(
self
,
input_features
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
)
->
torch
.
Tensor
:
hidden_states
=
[]
input_is_batched
=
False
for
features
in
input_features
:
embeds
=
nn
.
functional
.
gelu
(
self
.
conv1
(
features
))
embeds
=
nn
.
functional
.
gelu
(
self
.
conv2
(
embeds
))
if
self
.
pos_embed_type
in
(
WhisperPosEmbedType
.
SINUSOIDAL
,
WhisperPosEmbedType
.
LEARNED
,
):
embeds
=
embeds
.
transpose
(
-
1
,
-
2
)
embeds
=
(
embeds
+
self
.
embed_positions
.
weight
[:
embeds
.
size
(
-
2
),
:]).
to
(
embeds
.
dtype
)
embeds
=
(
embeds
+
self
.
embed_positions
.
weight
[:
embeds
.
size
(
-
2
),
:]
).
to
(
embeds
.
dtype
)
elif
self
.
pos_embed_type
==
WhisperPosEmbedType
.
NOPE
:
embeds
=
embeds
.
transpose
(
-
1
,
-
2
).
to
(
embeds
.
dtype
)
else
:
raise
ValueError
(
f
"Unknown pos_embed_type:
{
self
.
pos_embed_type
}
"
)
hidden_states
.
append
(
embeds
)
input_is_batched
=
embeds
.
ndim
>
2
# Input to MHA must be B x T x D
...
...
@@ -539,12 +542,19 @@ class WhisperEncoder(nn.Module):
else
:
hidden_states
=
torch
.
stack
(
hidden_states
,
dim
=
0
)
return
hidden_states
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Apply every encoder layer in order, then the final layer norm."""
    states = hidden_states
    for layer in self.layers:
        states = layer(states)
    return self.layer_norm(states)
def
forward
(
self
,
input_features
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]):
hidden_states
=
self
.
forward_conv
(
input_features
)
return
self
.
forward_layers
(
hidden_states
)
class
WhisperDecoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/whisper_utils.py
0 → 100644
View file @
3faa8bee
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
functools
import
math
from
dataclasses
import
replace
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadata
,
AttentionType
,
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
subclass_attention_backend_with_overrides
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
# Language codes mapped to English language names.
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "hy": "Armenian",
    "az": "Azerbaijani",
    "be": "Belarusian",
    "bs": "Bosnian",
    "bg": "Bulgarian",
    "ca": "Catalan",
    "zh": "Chinese",
    "hr": "Croatian",
    "cs": "Czech",
    "da": "Danish",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "gl": "Galician",
    "de": "German",
    "el": "Greek",
    "he": "Hebrew",
    "hi": "Hindi",
    "hu": "Hungarian",
    "is": "Icelandic",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "kk": "Kazakh",
    "ko": "Korean",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "mk": "Macedonian",
    "ms": "Malay",
    "mr": "Marathi",
    "mi": "Maori",
    "ne": "Nepali",
    "no": "Norwegian",
    "fa": "Persian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sr": "Serbian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "es": "Spanish",
    "sw": "Swahili",
    "sv": "Swedish",
    "tl": "Tagalog",
    "ta": "Tamil",
    "th": "Thai",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "cy": "Welsh",
}
def
_pad1d
(
x
:
torch
.
Tensor
,
paddings
:
tuple
[
int
,
int
],
mode
:
str
=
"constant"
,
value
:
float
=
0.0
,
)
->
torch
.
Tensor
:
"""Tiny wrapper around F.pad, just to allow for
reflect padding on small input.
If this is the case, we insert extra 0 padding
to the right before the reflection happen.
"""
length
=
x
.
shape
[
-
1
]
padding_left
,
padding_right
=
paddings
assert
padding_left
>=
0
and
padding_right
>=
0
,
(
padding_left
,
padding_right
)
if
mode
==
"reflect"
:
max_pad
=
max
(
padding_left
,
padding_right
)
extra_pad
=
0
if
length
<=
max_pad
:
extra_pad
=
max_pad
-
length
+
1
x
=
F
.
pad
(
x
,
(
0
,
extra_pad
))
padded
=
F
.
pad
(
x
,
paddings
,
mode
,
value
)
end
=
padded
.
shape
[
-
1
]
-
extra_pad
return
padded
[...,
:
end
]
else
:
return
F
.
pad
(
x
,
paddings
,
mode
,
value
)
class WhisperCausalConv1d(nn.Conv1d):
    """1-D convolution that zero-pads the input itself before convolving.

    ``effective_kernel_size - stride`` zeros are added on the left, and the
    right side is padded only as far as needed to complete the last frame.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        dilation = self.dilation[0]
        self._stride = self.stride[0]
        # Span of the kernel on the input once dilation is accounted for.
        self._effective_kernel_size = (kernel_size - 1) * dilation + 1
        self._padding_total = self._effective_kernel_size - self._stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the last axis of ``x`` and apply the convolution."""
        input_len = x.shape[-1]
        kernel_span = self._effective_kernel_size
        left_pad = self._padding_total

        # Fractional frame count for the left-padded input; rounding it up
        # gives the length needed for the final frame to be complete.
        n_frames = (input_len - kernel_span + left_pad) / self._stride + 1
        target_length = (math.ceil(n_frames) - 1) * self._stride + (
            kernel_span - left_pad
        )
        right_pad = target_length - input_len

        padded = _pad1d(x, (left_pad, right_pad), mode="constant")
        return super().forward(padded)
@functools.lru_cache
def create_whisper_attention_backend_with_block_pooling(
    underlying_attn_backend: type[AttentionBackend], block_pool_size: int
) -> type[AttentionBackend]:
    """Derive an attention backend that accounts for block pooling.

    The returned backend subclasses ``underlying_attn_backend`` and overrides:

    * ``get_builder_cls`` — returns a metadata builder that scales all
      sequence bookkeeping by ``block_pool_size``;
    * ``get_kv_cache_shape`` — each block is ``block_pool_size`` times longer
      and holds ``block_pool_size`` times fewer KV heads.

    Results are memoized per ``(backend class, block_pool_size)`` pair.

    Args:
        underlying_attn_backend: backend *class* to wrap; it must be a
            subclass of ``FlashAttentionBackend``.
        block_pool_size: pooling factor; the KV-head count of the cache spec
            must be divisible by it.

    Raises:
        NotImplementedError: if the backend is not FlashAttention-based.
    """
    # Fail fast, before any subclass is built for an unsupported backend.
    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
        raise NotImplementedError(
            f"{underlying_attn_backend} is not yet supported. "
            "Contributions to support more backends are much "
            "appreciated."
        )

    prefix = "WhisperAttentionWithBlockPooling_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class WhisperAttentionWithBlockPoolingBuilder(underlying_builder):  # type: ignore
        """Metadata builder that rescales its inputs by ``block_pool_size``."""

        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            assert kv_cache_spec.num_kv_heads % block_pool_size == 0
            # Longer blocks with proportionally fewer heads.
            kv_cache_spec = replace(
                kv_cache_spec,
                block_size=kv_cache_spec.block_size * block_pool_size,
                num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
            )
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Work on a deep copy so the caller's metadata is left untouched
            # by the in-place scaling below.
            new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
            new_common_attn_metadata.query_start_loc *= block_pool_size
            new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
            new_common_attn_metadata.seq_lens *= block_pool_size
            new_common_attn_metadata._seq_lens_cpu *= block_pool_size
            new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
            new_common_attn_metadata.num_actual_tokens *= block_pool_size
            new_common_attn_metadata.max_query_len *= block_pool_size
            new_common_attn_metadata.max_seq_len *= block_pool_size
            original_slot_mapping = common_attn_metadata.slot_mapping
            common_prefix_len *= block_pool_size
            # Each original slot ``s`` expands to the consecutive slots
            # ``s * pool .. s * pool + pool - 1``. Negative slots would expand
            # to values below -1, so they are clamped back to -1.
            new_common_attn_metadata.slot_mapping = (
                (
                    original_slot_mapping.unsqueeze(1) * block_pool_size
                    + torch.arange(block_pool_size, device=original_slot_mapping.device)
                )
                .flatten()
                .clamp(min=-1)
            )
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    attn_backend = subclass_attention_backend_with_overrides(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        overrides={
            "get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder,
            # ``cache_dtype_str`` is unused; its default keeps the override
            # callable when the caller omits it.
            "get_kv_cache_shape": lambda num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str="auto": (
                2,
                num_blocks,
                # we stretch each block by `block_pool_size`
                block_size * block_pool_size,
                num_kv_heads // block_pool_size,
                head_size,
            ),  # TODO: generalize to other backends
        },
    )

    return attn_backend
class WhisperAttentionWithBlockPooling(Attention):
    """Attention layer with block pooling."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        block_pool_size: int = 1,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """Set up the layer with a block-pooling wrapper around the backend.

        The underlying backend is looked up with ``get_attn_backend`` and
        wrapped by ``create_whisper_attention_backend_with_block_pooling``;
        the wrapped backend is what gets passed to the parent constructor.

        NOTE: the ``attn_backend`` argument is never read; it is always
        replaced by the wrapped backend built here.
        """
        self.block_pool_size = block_pool_size

        # Without a cache config, fall back to "auto" dtype and block size 16.
        if cache_config is None:
            cache_dtype_name = "auto"
            tokens_per_block = 16
        else:
            cache_dtype_name = cache_config.cache_dtype
            tokens_per_block = cache_config.block_size

        base_backend = get_attn_backend(
            head_size,
            torch.get_default_dtype(),
            cache_dtype_name,
            tokens_per_block,
            attn_type=attn_type,
        )
        pooled_backend = create_whisper_attention_backend_with_block_pooling(
            base_backend, block_pool_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=prefix,
            attn_type=attn_type,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=pooled_backend,
            **extra_impl_args,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig):
        """Return the parent's KV cache spec with ``num_kv_heads`` scaled up.

        The parent spec must be an ``AttentionSpec``; its ``num_kv_heads`` is
        multiplied by ``self.block_pool_size`` and all other fields are kept.
        """
        base_spec = super().get_kv_cache_spec(vllm_config)
        assert isinstance(base_spec, AttentionSpec)
        pooled_kv_heads = self.block_pool_size * base_spec.num_kv_heads
        return replace(base_spec, num_kv_heads=pooled_kv_heads)
vllm/transformers_utils/configs/mistral.py
View file @
3faa8bee
...
...
@@ -184,18 +184,42 @@ def _remap_mistral_audio_args(config: dict) -> dict:
whisper_args
=
config
[
"multimodal"
].
pop
(
"whisper_model_args"
)
encoder_args
=
whisper_args
[
"encoder_args"
]
downsample_args
=
whisper_args
[
"downsample_args"
]
downsample_factor
=
downsample_args
[
"downsample_factor"
]
# make sure that k/v blocks can be allocated with
# unified k/v cache class and pool whisper k/v cache blocks
# with downsample_factor:1 ratio
if
encoder_args
.
get
(
"causal"
):
block_pool_size
=
downsample_factor
config
[
"projection_size"
]
=
downsample_factor
*
encoder_args
[
"dim"
]
else
:
block_pool_size
=
1
_maybe_sliding_window
=
encoder_args
.
get
(
"ragged_attention"
,
None
)
if
_maybe_sliding_window
is
None
:
sliding_window
=
None
elif
_maybe_sliding_window
.
isdigit
():
sliding_window
=
int
(
_maybe_sliding_window
)
else
:
raise
NotImplementedError
(
f
"Unsupported:
{
_maybe_sliding_window
=
}
"
)
architecture
=
(
"VoxtralStreamingGeneration"
if
encoder_args
.
get
(
"causal"
)
else
"VoxtralForConditionalGeneration"
)
quant_config
=
config
.
get
(
"quantization_config"
)
config
=
{
"model_type"
:
"
whi
xtral"
,
"architectures"
:
[
"VoxtralForConditionalGeneration"
],
"model_type"
:
"
vo
xtral"
,
"architectures"
:
[
architecture
],
"text_config"
:
PretrainedConfig
.
from_dict
(
config
),
"audio_config"
:
WhisperConfig
(
num_mel_bins
=
encoder_args
[
"audio_encoding_args"
][
"num_mel_bins"
],
window_size
=
encoder_args
[
"audio_encoding_args"
][
"window_size"
],
sampling_rate
=
encoder_args
[
"audio_encoding_args"
][
"sampling_rate"
],
hop_length
=
encoder_args
[
"audio_encoding_args"
][
"hop_length"
],
downsample_factor
=
downsample_
args
[
"downsample_
factor
"
]
,
downsample_factor
=
downsample_factor
,
d_model
=
encoder_args
[
"dim"
],
encoder_layers
=
encoder_args
[
"n_layers"
],
encoder_ffn_dim
=
encoder_args
[
"hidden_dim"
],
...
...
@@ -203,6 +227,10 @@ def _remap_mistral_audio_args(config: dict) -> dict:
vocab_size
=
encoder_args
[
"vocab_size"
],
max_source_positions
=
encoder_args
[
"max_source_positions"
],
is_encoder_decoder
=
False
,
# Override WhisperConfig default
is_causal
=
encoder_args
.
get
(
"causal"
,
False
),
sliding_window
=
sliding_window
,
block_pool_size
=
block_pool_size
,
pos_embed
=
encoder_args
.
get
(
"pos_embed"
,
"sinusoidal"
),
),
}
if
quant_config
:
...
...
vllm/v1/attention/backends/utils.py
View file @
3faa8bee
...
...
@@ -835,6 +835,15 @@ def subclass_attention_backend(
)
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """Create a dynamically named subclass of an attention backend.

    Args:
        name_prefix: Text prepended to the base class name to form the new
            class name.
        attention_backend_cls: Class to derive from.
        overrides: Class-namespace entries (attributes / callables) for the
            new subclass.

    Returns:
        A new class named ``name_prefix + attention_backend_cls.__name__``
        whose single base is ``attention_backend_cls``.
    """
    bases = (attention_backend_cls,)
    derived_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    return type(derived_name, bases, overrides)
def
split_decodes_prefills_and_extends
(
common_attn_metadata
:
CommonAttentionMetadata
,
decode_threshold
:
int
=
1
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3faa8bee
...
...
@@ -2457,6 +2457,17 @@ class GPUModelRunner(
return
round_up
(
num_scheduled_tokens
,
tp_size
)
return
num_scheduled_tokens
def
_prepare_mm_inputs
(
self
,
num_tokens
:
int
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
]:
if
self
.
model
.
requires_raw_input_tokens
:
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens
]
else
:
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens
]
return
input_ids
,
inputs_embeds
def
_preprocess
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -2499,8 +2510,7 @@ class GPUModelRunner(
# TODO(woosuk): Avoid the copy. Optimize.
self
.
inputs_embeds
.
gpu
[:
num_scheduled_tokens
].
copy_
(
inputs_embeds_scheduled
)
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_input_tokens
]
input_ids
,
inputs_embeds
=
self
.
_prepare_mm_inputs
(
num_input_tokens
)
model_kwargs
=
{
**
self
.
_init_model_kwargs
(
num_scheduled_tokens
),
**
self
.
_extract_mm_kwargs
(
scheduler_output
),
...
...
@@ -4220,8 +4230,8 @@ class GPUModelRunner(
assert
num_tokens_padded
<=
self
.
max_num_tokens
model_kwargs
=
self
.
_init_model_kwargs
(
num_tokens_padded
)
if
self
.
supports_mm_inputs
and
not
self
.
model_config
.
is_encoder_decoder
:
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
input_ids
,
inputs_embeds
=
self
.
_prepare_mm_inputs
(
num_tokens_padded
)
model_kwargs
=
{
**
model_kwargs
,
**
self
.
_dummy_mm_kwargs
(
num_reqs
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment