Unverified Commit 3f3f8952 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files
parent 5da4c7d7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
import pytest
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import (
StreamingMode,
TranscriptionRequest,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
def _get_engine(path: str) -> LLM:
    """Create the offline ``LLM`` used by the tests in this file.

    The settings are collected in an ``EngineArgs`` dataclass and then
    forwarded to ``LLM`` as keyword arguments.

    Args:
        path: Model identifier or checkpoint path passed as ``model``.

    Returns:
        An ``LLM`` configured for Mistral-format config, weights and
        tokenizer, with one audio item per prompt and eager execution.
    """
    llm_kwargs = asdict(
        EngineArgs(
            model=path,
            max_model_len=8192,
            max_num_seqs=1,
            limit_mm_per_prompt={"audio": 1},
            config_format="mistral",
            load_format="mistral",
            tokenizer_mode="mistral",
            enforce_eager=True,
            gpu_memory_utilization=0.4,
        )
    )
    return LLM(**llm_kwargs)
@pytest.mark.skip(reason="Voxtral streaming is not yet public")
def test_voxtral_streaming_forward():
    """Greedy-decode two bundled audio clips and compare with reference text.

    Each clip is tokenized as an offline transcription request, generated
    with ``temperature=0.0`` and a per-clip token budget, and the decoded
    text of the first completion must equal the expected transcript exactly.
    """
    audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
    model_name = "mistralai/Voxtral-Mini-3B-Realtime-2602"

    tokenizer = MistralTokenizer.from_hf_hub(model_name)
    audio_config = tokenizer.instruct_tokenizer.tokenizer.audio

    def encode_clip(clip_path: str):
        """Return ``(prompt token ids, audio array)`` for one audio file."""
        request = TranscriptionRequest(
            audio=RawAudio.from_audio(Audio.from_file(clip_path, strict=False)),
            streaming=StreamingMode.OFFLINE,
            language=None,
        )
        encoded = tokenizer.instruct_tokenizer.encode_transcription(request)
        return (encoded.tokens, encoded.audios[0].audio_array)

    def token_budget(audio_array) -> int:
        # Audio-token count of the clip, minus the configured delay tokens
        # and one further token.
        total = audio_config.num_audio_tokens(audio_array.shape[0])
        return total - audio_config.num_delay_tokens - 1

    encoded_clips = [
        encode_clip(asset.get_local_path()) for asset in audio_assets
    ]

    # One SamplingParams and one prompt dict per clip, in the same order.
    sampling_params = [
        SamplingParams(temperature=0.0, max_tokens=token_budget(audio_array))
        for _, audio_array in encoded_clips
    ]
    inputs = [
        {
            "multi_modal_data": {"audio": [(audio_array, None)]},
            "prompt_token_ids": tokens,
        }
        for tokens, audio_array in encoded_clips
    ]

    llm = _get_engine(model_name)
    outputs = llm.generate(inputs, sampling_params=sampling_params)

    expected = [
        (
            " First words I spoke in the original phonograph. "
            "A little piece of practical poetry. Mary had a little lamb,"
            " it sleeps with quite a snow, and everywhere that Mary went, "
            "the lamb was sure to go."
        ),
        (
            " And the 0-1 pitch on the way to Edgar Martinez. Swung on"
            " the line. Down the left field line for OBS. Here comes Joy. "
            "Here is Junior to third base. They're going to wave him in. "
            "The throw to the plate will be late. The Mariners are going"
            " to play. For the American League Championship, "
            "I don't believe it. It just continues. My oh, my."
        ),
    ]
    texts = [out.outputs[0].text for out in outputs]
    assert texts == expected
...@@ -404,6 +404,7 @@ class LlamaModel(nn.Module): ...@@ -404,6 +404,7 @@ class LlamaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
**extra_layer_kwargs,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
...@@ -422,7 +423,9 @@ class LlamaModel(nn.Module): ...@@ -422,7 +423,9 @@ class LlamaModel(nn.Module):
): ):
if idx in self.aux_hidden_state_layers: if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual) aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual) hidden_states, residual = layer(
positions, hidden_states, residual, **extra_layer_kwargs
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
......
...@@ -10,6 +10,12 @@ from transformers import LlamaConfig ...@@ -10,6 +10,12 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import ( from vllm.model_executor.models.llama import (
LlamaAttention, LlamaAttention,
...@@ -17,11 +23,57 @@ from vllm.model_executor.models.llama import ( ...@@ -17,11 +23,57 @@ from vllm.model_executor.models.llama import (
LlamaForCausalLM, LlamaForCausalLM,
LlamaModel, LlamaModel,
) )
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from .utils import AutoWeightsLoader from .utils import AutoWeightsLoader
class MistralMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul, down projection.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of each of the two merged gate/up outputs
            and input size of the down projection.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config passed to both projections.
        bias: Whether the down projection has a bias; also the default for
            the gate/up projection bias.
        gate_up_proj_bias: Explicit bias switch for the gate/up projection;
            ``None`` means "use ``bias``".
        prefix: Module path prefix; the projections are named
            ``f"{prefix}.gate_up_proj"`` and ``f"{prefix}.down_proj"``.
        reduce_results: Forwarded to the down projection.
        disable_tp: Forwarded to both projections.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        gate_up_proj_bias: bool | None = None,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        # Validate first so an unsupported configuration fails before any
        # projection weights are allocated.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        if gate_up_proj_bias is None:
            gate_up_proj_bias = bias

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=gate_up_proj_bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the gated activation, then the down projection."""
        # Both projections return a 2-tuple; only the first element is used.
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x
class MistralAttention(LlamaAttention): class MistralAttention(LlamaAttention):
def __init__( def __init__(
self, self,
...@@ -114,6 +166,50 @@ class MistralDecoderLayer(LlamaDecoderLayer): ...@@ -114,6 +166,50 @@ class MistralDecoderLayer(LlamaDecoderLayer):
self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj
if getattr(config, "ada_rms_norm_t_cond", False):
self.ada_rms_norm_t_cond = nn.Sequential(
ColumnParallelLinear(
input_size=config.hidden_size,
output_size=config.ada_rms_norm_t_cond_dim,
bias=False,
return_bias=False,
),
nn.GELU(),
RowParallelLinear(
input_size=config.ada_rms_norm_t_cond_dim,
output_size=config.hidden_size,
bias=False,
return_bias=False,
),
)
else:
self.ada_rms_norm_t_cond = None
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
t_cond: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if self.ada_rms_norm_t_cond is not None:
assert t_cond is not None
hidden_states = hidden_states * (1 + self.ada_rms_norm_t_cond(t_cond))
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile @support_torch_compile
class MistralModel(LlamaModel): class MistralModel(LlamaModel):
...@@ -126,6 +222,18 @@ class MistralModel(LlamaModel): ...@@ -126,6 +222,18 @@ class MistralModel(LlamaModel):
): ):
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None,
    inputs_embeds: torch.Tensor | None = None,
    t_cond: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
    """Delegate to the parent forward, passing ``t_cond`` as a keyword argument.

    The first four arguments are forwarded positionally and unchanged;
    ``t_cond`` is handed on by name.
    """
    result = super().forward(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds,
        t_cond=t_cond,
    )
    return result
class MistralForCausalLM(LlamaForCausalLM): class MistralForCausalLM(LlamaForCausalLM):
# Mistral: We don't support LoRA on the embedding layers. # Mistral: We don't support LoRA on the embedding layers.
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import inspect import inspect
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property, partial
from math import ceil from math import ceil
from typing import Literal, cast from typing import Literal, cast
...@@ -33,7 +33,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -33,7 +33,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import WhisperEncoder from vllm.model_executor.models.whisper import (
WhisperEncoder,
_create_fake_bias_for_k_proj,
)
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
...@@ -543,6 +547,7 @@ class VoxtralForConditionalGeneration( ...@@ -543,6 +547,7 @@ class VoxtralForConditionalGeneration(
} }
).named_parameters() ).named_parameters()
) )
weights = _create_fake_bias_for_k_proj(weights, ".wk.weight")
loaded_weights = set() loaded_weights = set()
...@@ -730,6 +735,10 @@ class VoxtralEncoderModel(nn.Module): ...@@ -730,6 +735,10 @@ class VoxtralEncoderModel(nn.Module):
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501 r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.mlp.fc2.\2", r"whisper_encoder.layers.\1.mlp.fc2.\2",
), ),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(weight|bias)",
r"whisper_encoder.layers.\1.mlp.fc3.\2",
), # noqa: E501
( (
r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
r"whisper_encoder.layers.\1.final_layer_norm.\2", r"whisper_encoder.layers.\1.final_layer_norm.\2",
...@@ -749,10 +758,15 @@ class VoxtralEncoderModel(nn.Module): ...@@ -749,10 +758,15 @@ class VoxtralEncoderModel(nn.Module):
super().__init__() super().__init__()
self.config = cast(WhisperConfig, vllm_config.model_config.hf_config) self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
self.dtype: torch.dtype = vllm_config.model_config.dtype self.dtype: torch.dtype = vllm_config.model_config.dtype
self.whisper_encoder = WhisperEncoder( self.is_causal = getattr(self.config, "is_causal", False)
if self.is_causal:
WhisperEncoderCls = WhisperCausalEncoder
else:
WhisperEncoderCls = partial(WhisperEncoder, init_in_fp32=True)
self.whisper_encoder = WhisperEncoderCls(
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "whisper_encoder"), prefix=maybe_prefix(prefix, "whisper_encoder"),
init_in_fp32=True,
) )
mel_filters = mel_filter_bank( mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.window_size // 2, num_frequency_bins=1 + self.config.window_size // 2,
...@@ -843,6 +857,22 @@ class VoxtralEncoderModel(nn.Module): ...@@ -843,6 +857,22 @@ class VoxtralEncoderModel(nn.Module):
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_mapping = []
if self.is_causal:
# For `WhisperCausalEncoder` we need
# some more renaming
stacked_params_mapping.extend(
[
(".mlp.gate_up_proj", ".mlp.fc1", 0),
(".mlp.gate_up_proj", ".mlp.fc3", 1),
]
)
params_mapping.extend(
[
(".mlp.down_proj", ".mlp.fc2"),
]
)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
name, loaded_weight = weight name, loaded_weight = weight
...@@ -860,6 +890,11 @@ class VoxtralEncoderModel(nn.Module): ...@@ -860,6 +890,11 @@ class VoxtralEncoderModel(nn.Module):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
for param_name, weight_name in params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
......
...@@ -112,6 +112,18 @@ class TimeEmbedding(torch.nn.Module): ...@@ -112,6 +112,18 @@ class TimeEmbedding(torch.nn.Module):
return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D) return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D)
def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
# 1. Multiply by the scaling factor (e.g. 4)
base = input_tensor * scaling
# 2. Create the offsets, e.g. [0, 1, 2, 3]
offsets = torch.arange(scaling, device=input_tensor.device)
# 3. Use broadcasting, e.g. (N, 1) + (4,) results in (N, 4)
# Then flatten back to 1D
return (base.unsqueeze(1) + offsets).view(-1)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
VoxtralStreamingMultiModalProcessor, VoxtralStreamingMultiModalProcessor,
info=VoxtralProcessingInfo, info=VoxtralProcessingInfo,
...@@ -175,8 +187,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): ...@@ -175,8 +187,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
) )
audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers( whisper_positions = _expand_tensor(positions, pool_size)
inputs_embeds audio_hidden_states = self.whisper_encoder.whisper_encoder(
inputs_embeds, whisper_positions
) )
num_tokens, audio_hidden_size = audio_hidden_states.shape num_tokens, audio_hidden_size = audio_hidden_states.shape
...@@ -197,10 +210,14 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): ...@@ -197,10 +210,14 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
device=inputs_embeds.device, device=inputs_embeds.device,
dtype=inputs_embeds.dtype, dtype=inputs_embeds.dtype,
) )
inputs_embeds = inputs_embeds + self.time_embedding(time_tensor) t_cond = self.time_embedding(time_tensor)
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
t_cond=t_cond,
) )
return hidden_states return hidden_states
......
...@@ -5,7 +5,6 @@ import enum ...@@ -5,7 +5,6 @@ import enum
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext from contextlib import nullcontext
from functools import partial
from typing import Annotated, Literal, cast from typing import Annotated, Literal, cast
import numpy as np import numpy as np
...@@ -39,8 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead ...@@ -39,8 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.whisper_utils import ( from vllm.model_executor.models.whisper_utils import (
ISO639_1_SUPPORTED_LANGS, ISO639_1_SUPPORTED_LANGS,
WhisperAttentionWithBlockPooling,
WhisperCausalConv1d,
) )
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
...@@ -78,7 +75,7 @@ logger = init_logger(__name__) ...@@ -78,7 +75,7 @@ logger = init_logger(__name__)
class WhisperPosEmbedType(enum.Enum): class WhisperPosEmbedType(enum.Enum):
SINUSOIDAL = "sinusoidal" SINUSOIDAL = "sinusoidal"
NOPE = "nope" ROPE = "rope"
LEARNED = "learned" LEARNED = "learned"
...@@ -140,7 +137,6 @@ class WhisperAttention(nn.Module): ...@@ -140,7 +137,6 @@ class WhisperAttention(nn.Module):
bias: bool = True, bias: bool = True,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
per_layer_sliding_window: int | None = None, per_layer_sliding_window: int | None = None,
block_pool_size: int = 1,
cache_config: CacheConfig | None = None, cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
...@@ -199,14 +195,7 @@ class WhisperAttention(nn.Module): ...@@ -199,14 +195,7 @@ class WhisperAttention(nn.Module):
attn_type=self.attn_type, attn_type=self.attn_type,
) )
else: # AttentionType.DECODER (regular decoder self-attention) else: # AttentionType.DECODER (regular decoder self-attention)
if block_pool_size > 1: self.attn = Attention(
attn_cls = partial(
WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
)
else:
attn_cls = Attention
self.attn = attn_cls(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
self.scaling, self.scaling,
...@@ -351,9 +340,7 @@ class WhisperEncoderLayer(nn.Module): ...@@ -351,9 +340,7 @@ class WhisperEncoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
is_causal = getattr(config, "is_causal", False)
sliding_window = getattr(config, "sliding_window", None) sliding_window = getattr(config, "sliding_window", None)
block_pool_size = getattr(config, "block_pool_size", 1)
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -361,8 +348,7 @@ class WhisperEncoderLayer(nn.Module): ...@@ -361,8 +348,7 @@ class WhisperEncoderLayer(nn.Module):
self.self_attn = WhisperAttention( self.self_attn = WhisperAttention(
embed_dim=self.embed_dim, embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads, num_heads=config.encoder_attention_heads,
attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER, attn_type=AttentionType.ENCODER,
block_pool_size=block_pool_size,
per_layer_sliding_window=sliding_window, per_layer_sliding_window=sliding_window,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
...@@ -470,13 +456,8 @@ class WhisperEncoder(nn.Module): ...@@ -470,13 +456,8 @@ class WhisperEncoder(nn.Module):
self.max_source_positions = config.max_source_positions self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.is_causal = getattr(config, "is_causal", False) self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
Conv1d = ( self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)
WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
)
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
self.total_stride = self.conv1.stride[0] * self.conv2.stride[0] self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
...@@ -488,33 +469,29 @@ class WhisperEncoder(nn.Module): ...@@ -488,33 +469,29 @@ class WhisperEncoder(nn.Module):
) )
self.layer_norm = nn.LayerNorm(config.d_model) self.layer_norm = nn.LayerNorm(config.d_model)
if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE: if self.pos_embed_type not in (
raise ValueError(
"Only NOPE position embeddings are supported "
f"for causal models, but got {self.pos_embed_type}"
)
elif self.pos_embed_type in (
WhisperPosEmbedType.SINUSOIDAL, WhisperPosEmbedType.SINUSOIDAL,
WhisperPosEmbedType.LEARNED, WhisperPosEmbedType.LEARNED,
): ):
maybe_fp32_init_ctx = ( raise ValueError(
set_default_torch_dtype(torch.float32) "Only sinusoidal or learned position embeddings are supported "
if init_in_fp32 f"for non-causal models, but got {self.pos_embed_type}"
else nullcontext() )
maybe_fp32_init_ctx = (
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
)
with (
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)
) )
with ( def forward(
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(
self.max_source_positions, embed_dim
)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)
)
def forward_conv(
self, input_features: torch.Tensor | list[torch.Tensor] self, input_features: torch.Tensor | list[torch.Tensor]
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = [] hidden_states = []
...@@ -523,44 +500,26 @@ class WhisperEncoder(nn.Module): ...@@ -523,44 +500,26 @@ class WhisperEncoder(nn.Module):
embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds)) embeds = nn.functional.gelu(self.conv2(embeds))
if self.pos_embed_type in ( embeds = embeds.transpose(-1, -2)
WhisperPosEmbedType.SINUSOIDAL, embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
WhisperPosEmbedType.LEARNED, embeds.dtype
): )
embeds = embeds.transpose(-1, -2)
embeds = (
embeds + self.embed_positions.weight[: embeds.size(-2), :]
).to(embeds.dtype)
elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
embeds = embeds.transpose(-1, -2).to(embeds.dtype)
else:
raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
hidden_states.append(embeds) hidden_states.append(embeds)
input_is_batched = embeds.ndim > 2 input_is_batched = embeds.ndim > 2
# Input to MHA must be B x T x D # Input to MHA must be B x T x D
if input_is_batched or self.is_causal: if input_is_batched:
# Models using WhisperEncoder may handle batching internally. # Models using WhisperEncoder may handle batching internally.
# If WhisperEncoder is causal, sequences
# are not padded to have identical seq length (T)
# => concat over feature dim
hidden_states = torch.cat(hidden_states) hidden_states = torch.cat(hidden_states)
else: else:
hidden_states = torch.stack(hidden_states, dim=0) hidden_states = torch.stack(hidden_states, dim=0)
return hidden_states
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
for encoder_layer in self.layers: for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states) hidden_states = encoder_layer(hidden_states)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
hidden_states = self.forward_conv(input_features)
return self.forward_layers(hidden_states)
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1}) @support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class WhisperDecoder(nn.Module): class WhisperDecoder(nn.Module):
...@@ -978,19 +937,19 @@ class WhisperForConditionalGeneration( ...@@ -978,19 +937,19 @@ class WhisperForConditionalGeneration(
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
# add fake zeros bias for k_proj to state_dict # add fake zeros bias for k_proj to state_dict
weights = _create_fake_bias_for_k_proj(weights) weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _create_fake_bias_for_k_proj( def _create_fake_bias_for_k_proj(
weights: Iterable[tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
) -> Iterable[tuple[str, torch.Tensor]]: ) -> Iterable[tuple[str, torch.Tensor]]:
""" """
Create full zeros bias for k_proj weight in self-attn and x-attn layers. Create full zeros bias for k_proj weight in self-attn and x-attn layers.
So that the bias for k_proj in qkv_proj can be initialized with zeros. So that the bias for k_proj in qkv_proj can be initialized with zeros.
""" """
for name, weight in weights: for name, weight in weights:
if name.endswith(".k_proj.weight"): if name.endswith(fake_bias_key_name):
bias = torch.zeros(weight.size(0)) bias = torch.zeros(weight.size(0))
bias_name = name.replace("weight", "bias") bias_name = name.replace("weight", "bias")
yield from [(name, weight), (bias_name, bias)] yield from [(name, weight), (bias_name, bias)]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import functools
import math
from dataclasses import replace
from functools import partial
import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.mistral import MistralMLP
from vllm.model_executor.models.whisper import WhisperPosEmbedType
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import AttentionSpec
from .utils import make_layers
# RMSNorm constructor with eps pre-bound to 1e-5; callers supply the
# remaining arguments.
CausalRMSNorm = partial(RMSNorm, eps=1e-5)
def _pad1d(
x: torch.Tensor,
paddings: tuple[int, int],
mode: str = "constant",
value: float = 0.0,
) -> torch.Tensor:
"""Tiny wrapper around F.pad, just to allow for
reflect padding on small input.
If this is the case, we insert extra 0 padding
to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
class WhisperCausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` that pads its input asymmetrically before convolving.

    ``effective_kernel_size - stride`` zeros are added on the left. The
    right side receives only the extra zeros needed for the final stride
    window to be complete.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        dilation = self.dilation[0]
        self._stride = self.stride[0]
        # Span of the kernel on the input once dilation is accounted for.
        self._effective_kernel_size = (kernel_size - 1) * dilation + 1
        # Amount of left padding applied in ``forward``.
        self._padding_total = self._effective_kernel_size - self._stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the last dimension of ``x`` and apply the convolution."""
        length = x.shape[-1]
        kernel = self._effective_kernel_size
        stride = self._stride
        left_pad = self._padding_total

        # Possibly fractional frame count; rounding it up gives the length
        # the input must reach for the last window to be complete.
        n_frames = (length - kernel + left_pad) / stride + 1
        target_length = (math.ceil(n_frames) - 1) * stride + (kernel - left_pad)
        right_pad = target_length - length

        padded = _pad1d(x, (left_pad, right_pad), mode="constant")
        return super().forward(padded)
@functools.lru_cache
def create_whisper_attention_backend_with_block_pooling(
    underlying_attn_backend: AttentionBackend, block_pool_size: int
) -> type[AttentionBackend]:
    """Derive an attention backend whose metadata and KV cache are scaled by a pool size.

    The returned backend subclasses ``underlying_attn_backend`` with two
    overrides: a metadata builder that multiplies the sequence bookkeeping
    by ``block_pool_size``, and a KV-cache shape with ``block_pool_size``
    times more slots per block and ``block_pool_size`` times fewer KV heads.
    Results are memoised per ``(backend, block_pool_size)`` by ``lru_cache``.

    Args:
        underlying_attn_backend: Backend class to wrap; must be a subclass
            of ``FlashAttentionBackend``.
        block_pool_size: Scaling factor applied by the builder and cache shape.

    Returns:
        The derived backend class.

    Raises:
        NotImplementedError: If the backend is not a ``FlashAttentionBackend``
            subclass.
    """
    # Reject unsupported backends before creating any subclass.
    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
        raise NotImplementedError(
            f"{underlying_attn_backend} is not yet supported. "
            "Contributions to support more backends are much "
            "appreciated."
        )

    prefix = "WhisperCausalAttentionWithBlockPooling_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder):  # type: ignore
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            assert kv_cache_spec.num_kv_heads % block_pool_size == 0
            # Give the parent builder a spec with larger blocks and
            # proportionally fewer KV heads.
            kv_cache_spec = replace(
                kv_cache_spec,
                block_size=kv_cache_spec.block_size * block_pool_size,
                num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
            )
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Scale a deep copy so the caller's metadata object is not
            # modified by the in-place multiplications below.
            new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
            new_common_attn_metadata.query_start_loc *= block_pool_size
            new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
            new_common_attn_metadata.seq_lens *= block_pool_size
            new_common_attn_metadata._seq_lens_cpu *= block_pool_size
            new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
            new_common_attn_metadata.num_actual_tokens *= block_pool_size
            new_common_attn_metadata.max_query_len *= block_pool_size
            new_common_attn_metadata.max_seq_len *= block_pool_size
            original_slot_mapping = common_attn_metadata.slot_mapping
            common_prefix_len *= block_pool_size
            # Each slot s becomes s*pool, s*pool+1, ..., s*pool+pool-1.
            # Values below -1 collapse to -1 (presumably -1 marks unused
            # slots - TODO confirm).
            new_common_attn_metadata.slot_mapping = (
                (
                    original_slot_mapping.unsqueeze(1) * block_pool_size
                    + torch.arange(block_pool_size, device=original_slot_mapping.device)
                )
                .flatten()
                .clamp(min=-1)
            )
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    attn_backend = subclass_attention_backend_with_overrides(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        overrides={
            "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
            "get_kv_cache_shape": lambda num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str: (
                2,
                num_blocks,
                # we stretch each block by `block_pool_size`
                block_size * block_pool_size,
                num_kv_heads // block_pool_size,
                head_size,
            ),  # TODO: generalize to other backends
        },
    )

    return attn_backend
class WhisperCausalAttentionWithBlockPooling(Attention):
    """`Attention` layer wired to a block-pooling attention backend.

    The constructor selects an underlying backend with `get_attn_backend`,
    wraps it with `create_whisper_attention_backend_with_block_pooling` and
    hands the wrapped backend to the parent `Attention` constructor.

    Note that the `attn_backend` argument is accepted for signature
    compatibility only: whatever is passed is replaced by the wrapped
    backend built here.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        block_pool_size: int = 1,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        # Stored before the parent constructor runs; get_kv_cache_spec reads it.
        self.block_pool_size = block_pool_size

        # Without a cache config, fall back to the "auto" dtype and block size 16.
        kv_cache_dtype, block_size = (
            ("auto", 16)
            if cache_config is None
            else (cache_config.cache_dtype, cache_config.block_size)
        )
        base_backend = get_attn_backend(
            head_size,
            torch.get_default_dtype(),
            kv_cache_dtype,
            block_size,
            attn_type=attn_type,
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=prefix,
            attn_type=attn_type,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            # Positional call on purpose: the factory is looked up by its
            # positional arguments.
            attn_backend=create_whisper_attention_backend_with_block_pooling(
                base_backend, block_pool_size
            ),
            **extra_impl_args,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig):
        """Return the parent's spec with `num_kv_heads` scaled by `block_pool_size`."""
        base_spec = super().get_kv_cache_spec(vllm_config)
        assert isinstance(base_spec, AttentionSpec)
        return replace(
            base_spec,
            num_kv_heads=self.block_pool_size * base_spec.num_kv_heads,
        )
class WhisperCausalAttention(nn.Module):
    """Self-attention block using rotary embeddings and block-pooled attention.

    The module projects the input to fused Q/K/V, applies rotary position
    embeddings to Q and K, runs `WhisperCausalAttentionWithBlockPooling`
    and projects the result back to `embed_dim`.

    Args:
        embed_dim: Width of the input/output hidden states.
        num_heads: Total number of attention heads across all tensor-parallel
            ranks; must be divisible by the tensor-parallel world size.
        head_dim: Size of each attention head.
        max_position_embeddings: Maximum position passed to the rotary
            embedding.
        bias: Whether the Q/K/V and output projections carry a bias.
        attn_type: Stored on `self.attn_type`; the inner attention layer is
            always created with `AttentionType.DECODER`.
        per_layer_sliding_window: Sliding-window size; must not be `None`.
        block_pool_size: Block pooling factor; must be greater than 1.
        cache_config: Forwarded to the inner attention layer.
        quant_config: Forwarded to the projections and the attention layer.
        prefix: Name prefix used for the sub-layers.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        block_pool_size: int = 1,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # Check the configuration before any sub-layer is constructed so an
        # invalid setup fails without leaving half-built layers behind.
        assert block_pool_size > 1, (
            f"Causal attention only supports block_pool_size>1, not {block_pool_size}."
        )
        assert per_layer_sliding_window is not None, (
            "rope can only be used in combination with a sliding window"
        )

        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Heads are partitioned evenly across tensor-parallel ranks. Because
        # divisibility is required, the "fewer heads than ranks" replication
        # case can never be reached and is therefore not handled here.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        self.attn = WhisperCausalAttentionWithBlockPooling(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.DECODER,
            per_layer_sliding_window=per_layer_sliding_window,
            block_pool_size=block_pool_size,
        )
        self._init_rotary_emb(max_position_embeddings)

    def _init_rotary_emb(self, max_position_embeddings: int) -> None:
        """Create the rotary embedding (non-NeoX style, `rope_theta` = 1e6)."""
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            is_neox_style=False,
            rope_parameters={"rope_theta": 1e6},
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused Q/K/V projection with as many KV heads as Q heads."""
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor | None = None,
    ):
        """Apply attention to `hidden_states`.

        `positions` is optional in the signature but required in practice:
        the rotary embedding needs it, so `None` raises `AssertionError`.
        """
        assert positions is not None
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output
class WhisperCausalEncoderLayer(nn.Module):
    """One pre-norm encoder layer: causal self-attention followed by an MLP.

    Both sub-blocks use the residual pattern `x + f(norm(x))`, with
    `CausalRMSNorm` as the normalisation.

    Configuration is read from `vllm_config.model_config.hf_config`:
    `d_model`, `encoder_attention_heads`, `encoder_head_dim`,
    `max_position_embeddings`, `encoder_ffn_dim`, `block_pool_size` (must be
    greater than 1) and, optionally, `sliding_window`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        sliding_window = getattr(config, "sliding_window", None)
        block_pool_size = config.block_pool_size
        assert block_pool_size > 1

        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.d_model
        # Report the same head size that the attention layer is built with.
        # Deriving it as d_model // num_heads disagrees with the attention
        # layer whenever the config sets an explicit encoder_head_dim.
        self.head_dim = config.encoder_head_dim
        self.self_attn = WhisperCausalAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            head_dim=self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            block_pool_size=block_pool_size,
            per_layer_sliding_window=sliding_window,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = CausalRMSNorm(self.embed_dim)

        self.mlp = MistralMLP(
            hidden_size=config.d_model,
            intermediate_size=config.encoder_ffn_dim,
            hidden_act="silu",
            quant_config=quant_config,
            bias=True,
            gate_up_proj_bias=False,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = CausalRMSNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor | None = None,
    ):
        """Run the attention and MLP sub-blocks, each with a residual add."""
        # Attention sub-block.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states, positions=positions)
        hidden_states = residual + hidden_states

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
class WhisperCausalEncoder(nn.Module):
    """Causal Whisper-style encoder: two causal convolutions plus a layer stack.

    The constructor requires a config with `pos_embed` equal to ROPE and
    `is_causal` set. `forward_conv` runs the convolutional front-end and
    `forward` runs the encoder layers followed by a final `CausalRMSNorm`;
    the two stages are invoked separately by the caller.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model

        # Only the causal, rotary-embedding configuration is supported.
        assert WhisperPosEmbedType(config.pos_embed) == WhisperPosEmbedType.ROPE
        assert config.is_causal

        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = WhisperCausalConv1d(self.num_mel_bins, embed_dim, kernel_size=3)
        self.conv2 = WhisperCausalConv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
        # Combined temporal stride of the two convolutions.
        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]

        layer_stack = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperCausalEncoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.start_layer, self.end_layer, self.layers = layer_stack
        self.layer_norm = CausalRMSNorm(config.d_model)

    def forward_conv(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        """Run the convolutional front-end over each item of `input_features`.

        Each item goes through `conv1 -> GELU -> conv2 -> GELU`, has its last
        two dimensions swapped, and the per-item results are concatenated
        with `torch.cat` along dimension 0.
        """
        gelu = nn.functional.gelu
        per_item = []
        for features in input_features:
            conv_out = gelu(self.conv2(gelu(self.conv1(features))))
            per_item.append(conv_out.transpose(-1, -2).to(conv_out.dtype))
        return torch.cat(per_item)

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Apply every encoder layer in order, then the final layer norm."""
        out = hidden_states
        for layer in self.layers:
            out = layer(out, positions)
        return self.layer_norm(out)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import functools
import math
from dataclasses import replace
import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import AttentionSpec
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages # From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = { ISO639_1_SUPPORTED_LANGS = {
...@@ -83,215 +62,3 @@ ISO639_1_SUPPORTED_LANGS = { ...@@ -83,215 +62,3 @@ ISO639_1_SUPPORTED_LANGS = {
"vi": "Vietnamese", "vi": "Vietnamese",
"cy": "Welsh", "cy": "Welsh",
} }
def _pad1d(
x: torch.Tensor,
paddings: tuple[int, int],
mode: str = "constant",
value: float = 0.0,
) -> torch.Tensor:
"""Tiny wrapper around F.pad, just to allow for
reflect padding on small input.
If this is the case, we insert extra 0 padding
to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == "reflect":
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
class WhisperCausalConv1d(nn.Conv1d):
    """`nn.Conv1d` whose input is padded inside `forward` instead of by `padding`.

    `forward` adds `effective_kernel_size - stride` constant (zero) samples
    on the left and just enough on the right for the padded length to cover
    a whole number of strides, then runs the regular convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        dilation = self.dilation[0]
        self._stride = self.stride[0]
        # Span of input samples covered by one kernel application.
        self._effective_kernel_size = (kernel_size - 1) * dilation + 1
        # Amount of padding placed on the left in `forward`.
        self._padding_total = self._effective_kernel_size - self._stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the last dimension of `x` as described above and convolve."""
        length = x.shape[-1]
        kernel = self._effective_kernel_size
        stride = self._stride
        left_pad = self._padding_total

        # Fractional frame count; rounding it up gives the length the input
        # has to be extended to on the right.
        n_frames = (length - kernel + left_pad) / stride + 1
        target_length = (math.ceil(n_frames) - 1) * stride + (kernel - left_pad)
        right_pad = target_length - length

        padded = _pad1d(x, (left_pad, right_pad), mode="constant")
        return super().forward(padded)
@functools.lru_cache
def create_whisper_attention_backend_with_block_pooling(
    underlying_attn_backend: type[AttentionBackend], block_pool_size: int
) -> type[AttentionBackend]:
    """Derive an attention backend that treats each token as `block_pool_size` tokens.

    The returned backend subclasses `underlying_attn_backend` and overrides:

    * `get_builder_cls`: returns a metadata builder that multiplies the
      block size by `block_pool_size`, divides the number of KV heads by it,
      and scales all token counts and slot indices in `build`.
    * `get_kv_cache_shape`: returns
      `(2, num_blocks, block_size * block_pool_size,
      num_kv_heads // block_pool_size, head_size)`.

    Results are memoised per `(underlying_attn_backend, block_pool_size)`.

    Args:
        underlying_attn_backend: Backend class to wrap; must be a subclass of
            `FlashAttentionBackend`.
        block_pool_size: Expansion factor applied to tokens and block size.

    Returns:
        The derived backend class.

    Raises:
        NotImplementedError: If `underlying_attn_backend` is not a
            `FlashAttentionBackend` subclass.
    """
    # Reject unsupported backends before touching their builder class.
    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
        raise NotImplementedError(
            f"{underlying_attn_backend} is not yet supported. "
            "Contributions to support more backends are much "
            "appreciated."
        )

    prefix = "WhisperAttentionWithBlockPooling_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class WhisperAttentionWithBlockPoolingBuilder(underlying_builder):  # type: ignore
        """Metadata builder that rescales the spec and metadata by `block_pool_size`."""

        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            assert kv_cache_spec.num_kv_heads % block_pool_size == 0
            # Trade KV heads for block length; the product stays unchanged.
            kv_cache_spec = replace(
                kv_cache_spec,
                block_size=kv_cache_spec.block_size * block_pool_size,
                num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
            )
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # Scale a deep copy so the caller's metadata is left untouched.
            new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
            new_common_attn_metadata.query_start_loc *= block_pool_size
            new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
            new_common_attn_metadata.seq_lens *= block_pool_size
            new_common_attn_metadata._seq_lens_cpu *= block_pool_size
            new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
            new_common_attn_metadata.num_actual_tokens *= block_pool_size
            new_common_attn_metadata.max_query_len *= block_pool_size
            new_common_attn_metadata.max_seq_len *= block_pool_size

            original_slot_mapping = common_attn_metadata.slot_mapping
            common_prefix_len *= block_pool_size
            # Slot s becomes s * n + [0, ..., n - 1]. Negative inputs expand
            # to values <= -1 and are folded back to -1 by the clamp.
            new_common_attn_metadata.slot_mapping = (
                (
                    original_slot_mapping.unsqueeze(1) * block_pool_size
                    + torch.arange(block_pool_size, device=original_slot_mapping.device)
                )
                .flatten()
                .clamp(min=-1)
            )
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    attn_backend = subclass_attention_backend_with_overrides(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        overrides={
            "get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder,
            "get_kv_cache_shape": lambda num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str: (
                2,
                num_blocks,
                # we stretch each block by `block_pool_size`
                block_size * block_pool_size,
                num_kv_heads // block_pool_size,
                head_size,
            ),  # TODO: generalize to other backends
        },
    )

    return attn_backend
class WhisperAttentionWithBlockPooling(Attention):
    """Attention layer that runs on a block-pooling attention backend.

    An underlying backend is chosen via `get_attn_backend` and wrapped by
    `create_whisper_attention_backend_with_block_pooling`; the wrapped
    backend is what the parent `Attention` constructor receives. A value
    passed as `attn_backend` is not used: it is always replaced by the
    wrapped backend.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        block_pool_size: int = 1,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        # Kept on the instance for get_kv_cache_spec.
        self.block_pool_size = block_pool_size

        default_dtype = torch.get_default_dtype()
        if cache_config is None:
            # No cache configuration: "auto" dtype and a block size of 16.
            kv_cache_dtype, block_size = "auto", 16
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size

        plain_backend = get_attn_backend(
            head_size,
            default_dtype,
            kv_cache_dtype,
            block_size,
            attn_type=attn_type,
        )
        pooled_backend = create_whisper_attention_backend_with_block_pooling(
            plain_backend, block_pool_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=prefix,
            attn_type=attn_type,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=pooled_backend,
            **extra_impl_args,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig):
        """Return the parent's KV-cache spec with `block_pool_size` times the KV heads."""
        parent_spec = super().get_kv_cache_spec(vllm_config)
        assert isinstance(parent_spec, AttentionSpec)
        pooled_heads = self.block_pool_size * parent_spec.num_kv_heads
        return replace(parent_spec, num_kv_heads=pooled_heads)
...@@ -224,6 +224,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: ...@@ -224,6 +224,7 @@ def _remap_mistral_audio_args(config: dict) -> dict:
encoder_layers=encoder_args["n_layers"], encoder_layers=encoder_args["n_layers"],
encoder_ffn_dim=encoder_args["hidden_dim"], encoder_ffn_dim=encoder_args["hidden_dim"],
encoder_attention_heads=encoder_args["n_heads"], encoder_attention_heads=encoder_args["n_heads"],
encoder_head_dim=encoder_args["head_dim"],
vocab_size=encoder_args["vocab_size"], vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"], max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default is_encoder_decoder=False, # Override WhisperConfig default
...@@ -231,6 +232,8 @@ def _remap_mistral_audio_args(config: dict) -> dict: ...@@ -231,6 +232,8 @@ def _remap_mistral_audio_args(config: dict) -> dict:
sliding_window=sliding_window, sliding_window=sliding_window,
block_pool_size=block_pool_size, block_pool_size=block_pool_size,
pos_embed=encoder_args.get("pos_embed", "sinusoidal"), pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
# only needed for RoPE
max_position_embeddings=block_pool_size * config["max_position_embeddings"],
), ),
} }
if quant_config: if quant_config:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment