# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""GLM-ASR model definition: multimodal processing, audio tower and projector.

The multimodal processing classes extend the AudioFlamingo3 implementations
imported from ``.audioflamingo3`` and override the GLM-ASR specific pieces.
"""

from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, cast

import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature
from transformers.models.glmasr import GlmAsrConfig, GlmAsrEncoder, GlmAsrProcessor
from transformers.models.whisper import WhisperFeatureExtractor

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .audioflamingo3 import (
    AudioFlamingo3MultiModalDataParser,
    AudioFlamingo3MultiModalProcessor,
    AudioFlamingo3ProcessingInfo,
)
from .audioflamingo3 import (
    _audioflamingo3_field_config as _glmasr_field_config,
)
from .glmasr_utils import (
    DEFAULT_CONV_PARAMS,
    DEFAULT_MAX_AUDIO_LEN_S,
    DEFAULT_MERGE_FACTOR,
    _flatten_audio_features_by_length,
    _get_audio_output_lengths_for_tower,
    _get_num_features_for_item,
    _group_audio_embeddings,
    _normalize_chunk_counts,
)
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsTranscription,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .whisper import ISO639_1_SUPPORTED_LANGS


class GlmAsrFeatureInputs(TensorSchema):
    """
    Dimensions:
        - num_chunks: Number of audio chunks (flattened)
        - nmb: Number of mel bins
        - num_audios: Number of original audio files
    """

    # Default added so this schema matches GlmAsrEmbeddingInputs below.
    type: Literal["audio_features"] = "audio_features"
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_chunks", "nmb", "chunk_length", dynamic_dims={"chunk_length"}),
    ]
    feature_attention_mask: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_chunks", "chunk_length", dynamic_dims={"chunk_length"}),
    ]
    chunk_counts: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_audios"),
    ]


class GlmAsrEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"
    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
    ]


# Either raw audio features (to be encoded) or precomputed audio embeddings.
GlmAsrInputs: TypeAlias = GlmAsrFeatureInputs | GlmAsrEmbeddingInputs


class GlmAsrMultiModalProjector(nn.Module):
    """Two-layer projector from audio-tower features to the text hidden size.

    ``linear_1`` maps ``audio_config.intermediate_size`` to twice the text
    hidden size, the configured activation is applied, and ``linear_2`` maps
    the result down to ``text_config.hidden_size``.
    """

    def __init__(
        self,
        config: GlmAsrConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.linear_1 = ColumnParallelLinear(
            input_size=config.audio_config.intermediate_size,
            output_size=config.text_config.hidden_size * 2,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            input_size=config.text_config.hidden_size * 2,
            output_size=config.text_config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project ``audio_features`` through linear_1 -> act -> linear_2."""
        # Both linear layers return a 2-tuple; only the first element is used.
        hidden_states, _ = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


class GlmAsrProcessingInfo(AudioFlamingo3ProcessingInfo):
    """Processing info that resolves the GLM-ASR config and processor types."""

    def get_hf_config(self) -> GlmAsrConfig:
        """Return the HF config, requested as a ``GlmAsrConfig``."""
        return self.ctx.get_hf_config(GlmAsrConfig)

    def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor:
        """Return the HF processor, requested as a ``GlmAsrProcessor``."""
        return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the parent's feature extractor, narrowed to Whisper's type."""
        # Reuse parent implementation, but add type annotation and assertion
        feature_extractor = super().get_feature_extractor(**kwargs)
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor


class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]):
    """Builds dummy text and audio inputs for GLM-ASR."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's audio token repeated once per audio item."""
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        return hf_processor.audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audios sized to the processor's maximum audio length.

        The length in samples is ``max_audio_len * sampling_rate``, where
        ``max_audio_len`` falls back to ``DEFAULT_MAX_AUDIO_LEN_S`` when the
        processor does not define it. ``seq_len`` is not used here.
        """
        feature_extractor = self.info.get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        num_audios = mm_counts.get("audio", 0)
        audio_overrides = mm_options.get("audio") if mm_options else None

        max_audio_len = getattr(
            self.info.get_hf_processor(), "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S
        )
        audio_len = int(max_audio_len * sampling_rate)

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }


class GlmAsrMultiModalDataParser(AudioFlamingo3MultiModalDataParser):
    """Audio data parser that also accepts dict inputs with ``audio_embeds``."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Wrap dict inputs as embedding items; defer the rest to the parent."""
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"audio_embeds"},
                fields_factory=_glmasr_field_config,
            )
        return super()._parse_audio_data(data)


class GlmAsrMultiModalProcessor(AudioFlamingo3MultiModalProcessor):
    """Multimodal processor for GLM-ASR built on the AudioFlamingo3 one."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a GLM-ASR parser targeting the extractor's sampling rate."""
        feature_extractor = self.info.get_feature_extractor()
        return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _calculate_chunk_counts(
        self,
        audio_list: list[Any],
        feature_extractor: WhisperFeatureExtractor,
        processor: GlmAsrProcessor,
    ) -> list[int]:
        """Calculate chunk counts for each audio.

        Each audio is split into windows of ``sampling_rate * chunk_length``
        samples (rounded up, at least one), capped at the number of windows
        that fit into the processor's ``max_audio_len`` (defaulting to
        ``DEFAULT_MAX_AUDIO_LEN_S``).
        """
        sampling_rate = feature_extractor.sampling_rate
        chunk_length = feature_extractor.chunk_length
        max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)

        window_size = int(sampling_rate * chunk_length)
        # A maximum shorter than one window would otherwise yield a cap of 0
        # and report zero chunks for every audio; always allow one window.
        max_windows = max(1, int(max_audio_len // chunk_length))

        chunk_counts = []
        for audio in audio_list:
            n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
            n_chunks = max(1, (n_samples + window_size - 1) // window_size)
            chunk_counts.append(min(n_chunks, max_windows))
        return chunk_counts

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent HF processing and attach GLM-ASR specific outputs.

        Text-only inputs are tokenized directly. Otherwise the parent result
        has ``input_features_mask`` renamed to ``feature_attention_mask`` and
        gains a ``chunk_counts`` tensor computed by
        ``_calculate_chunk_counts``.
        """
        # Work on a shallow copy so renaming the deprecated key below does not
        # mutate the caller's dictionary.
        mm_data = dict(mm_data)

        # Normalize input: handle deprecated key and list conversion.
        if "audios" in mm_data:
            mm_data["audio"] = mm_data.pop("audios")

        # Avoid evaluating the truthiness of the raw audio object: a bare
        # multi-element array raises "truth value is ambiguous".
        audio = mm_data.get("audio")
        if audio is None:
            audio_list: list[Any] = []
        elif isinstance(audio, (list, tuple)):
            audio_list = list(audio)
        else:
            audio_list = [audio]

        # Early return for text-only.
        if not audio_list:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Get processor for chunk counts calculation
        processor = self.info.get_hf_processor(**mm_kwargs)

        # Call parent method (it will handle sampling_rate)
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # Postprocess: rename mask and add chunk counts.
        if "input_features_mask" in outputs:
            outputs["feature_attention_mask"] = outputs.pop("input_features_mask")

        # Override chunk counts calculation with GLM-ASR specific logic
        chunk_counts = self._calculate_chunk_counts(
            audio_list, processor.feature_extractor, processor
        )
        outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config shared with AudioFlamingo3."""
        return _glmasr_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement that expands each audio token per item.

        Every occurrence of the audio token is replaced by ``num_features``
        copies of its token id, where ``num_features`` is computed by
        ``_get_num_features_for_item``.

        Raises:
            ValueError: (from the replacement callback) when an audio item
                yields zero features.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
        config = self.info.get_hf_config()

        audio_token = getattr(processor, "audio_token", "<|pad|>")
        audio_token_id = vocab.get(audio_token)
        if audio_token_id is None:
            # Token string is not in the vocab; use the processor's id instead.
            audio_token_id = processor.audio_token_id

        merge_factor = getattr(config, "merge_factor", DEFAULT_MERGE_FACTOR)
        conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS)

        # These lookups do not depend on the item index, so do them once here
        # rather than on every callback invocation.
        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        chunk_counts = out_mm_data.get("chunk_counts")
        audio_embeds = out_mm_data.get("audio_embeds")

        def get_replacement_glmasr(item_idx: int):
            num_features = _get_num_features_for_item(
                feature_attention_mask,
                chunk_counts,
                item_idx,
                audio_embeds,
                merge_factor,
                conv_params,
            )

            if num_features == 0:
                raise ValueError("Audio is too short")

            audio_tokens = [audio_token_id] * int(num_features)
            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_glmasr,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    GlmAsrMultiModalProcessor,
    info=GlmAsrProcessingInfo,
    dummy_inputs=GlmAsrDummyInputsBuilder,
)
class GlmAsrForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
    """GLM-ASR: audio tower + projector feeding a language model.

    The audio tower is a ``GlmAsrEncoder``, its output is projected by
    ``GlmAsrMultiModalProjector``, and the language model is created through
    ``init_vllm_registered_model`` with the ``LlamaForCausalLM`` architecture.
    """

    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = GlmAsrEncoder(config.audio_config)
        self.multi_modal_projector = GlmAsrMultiModalProjector(
            config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["LlamaForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the audio placeholder string.

        Raises:
            ValueError: if ``modality`` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return "<|begin_of_audio|><|pad|><|end_of_audio|>"

        raise ValueError("Only audio modality is supported")

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module-name prefixes of the LM, connector and tower."""
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _parse_and_validate_audio_input(self, **kwargs: object) -> GlmAsrInputs | None:
        """Build the audio input schema from keyword arguments.

        ``audio_embeds`` takes precedence over ``input_features``. Returns
        ``None`` when neither is present.
        """
        audio_embeds = kwargs.pop("audio_embeds", None)
        if audio_embeds is not None:
            return GlmAsrEmbeddingInputs(type="audio_embeds", audio_embeds=audio_embeds)

        input_features = kwargs.pop("input_features", None)
        if input_features is None:
            return None

        return GlmAsrFeatureInputs(
            type="audio_features",
            input_features=input_features,
            feature_attention_mask=kwargs.pop("feature_attention_mask", None),
            chunk_counts=kwargs.pop("chunk_counts", None),
        )

    def _process_audio_input(
        self, audio_input: GlmAsrInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn parsed audio inputs into per-audio embeddings.

        Precomputed embeddings are returned as a tuple unchanged. Raw
        features are run through the audio tower and the projector, trimmed
        to their valid output lengths, split per chunk and regrouped per
        audio via ``_group_audio_embeddings``.
        """
        if audio_input["type"] == "audio_embeds":
            return tuple(audio_input["audio_embeds"])

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        # Normalize each input independently so that a list of features paired
        # with an already-stacked mask (or vice versa) is handled correctly.
        if isinstance(input_features, list):
            input_features = torch.cat(input_features, dim=0)
        if isinstance(feature_attention_mask, list):
            feature_attention_mask = torch.cat(feature_attention_mask, dim=0)

        num_chunks = input_features.shape[0]
        chunk_counts = _normalize_chunk_counts(
            audio_input.get("chunk_counts"), num_chunks=num_chunks
        )

        audio_hidden_states = self.audio_tower(input_features).last_hidden_state
        audio_hidden_states = audio_hidden_states.reshape(
            num_chunks,
            -1,
            self.config.audio_config.intermediate_size,
        )
        audio_features = self.multi_modal_projector(audio_hidden_states)

        merge_factor = getattr(self.config, "merge_factor", DEFAULT_MERGE_FACTOR)
        conv_params = getattr(self.config, "conv_params", DEFAULT_CONV_PARAMS)

        audio_output_lengths = _get_audio_output_lengths_for_tower(
            self.audio_tower,
            feature_attention_mask.sum(-1),
            merge_factor,
            conv_params,
        )

        masked_audio_features = _flatten_audio_features_by_length(
            audio_features, audio_output_lengths
        )

        chunk_embeddings = torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )
        return _group_audio_embeddings(chunk_embeddings, chunk_counts)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings for ``kwargs``, or ``[]`` without audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model's inner ``model`` and return its output.

        ``inputs_embeds`` is dropped whenever ``intermediate_tensors`` is
        provided.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping ``audio_tower.embed_positions`` entries."""
        skip_prefixes = ["audio_tower.embed_positions"]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)

    @classmethod
    def _get_audio_token(cls, model_config: ModelConfig) -> str:
        """Get the audio token from processor.

        Similar to get_placeholder_str but returns single token.
        """
        processor = cached_processor_from_config(model_config)
        return getattr(processor, "audio_token", "<|pad|>")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text config from the cached processor.

        The maximum clip length is the processor's ``max_audio_len`` (or
        ``DEFAULT_MAX_AUDIO_LEN_S``) and the sample rate comes from its
        feature extractor. ``task_type`` is not used here.
        """
        processor = cached_processor_from_config(model_config)
        feature_extractor = processor.feature_extractor
        max_audio_clip_s = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=feature_extractor.sampling_rate,
        )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests.

        Builds a single user message consisting of the audio token followed
        by a fixed instruction, renders it with the tokenizer's chat template
        and returns the token ids together with the audio. ``stt_config``,
        ``language`` and ``request_prompt`` are not used by this
        implementation.

        Raises:
            ValueError: if ``task_type`` is neither ``"transcribe"`` nor
                ``"translate"``.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_token = cls._get_audio_token(model_config)

        if task_type == "translate":
            # NOTE(review): a missing target language used to render the
            # literal text "None" into the prompt. Fall back to English, the
            # usual target of OpenAI-style translation requests - confirm this
            # default suits the serving layer.
            target_language = to_language if to_language is not None else "en"
            full_lang_name_to = cls.supported_languages.get(
                target_language, target_language
            )
            user_content = f"{audio_token}translate the speech to {full_lang_name_to}"
        elif task_type == "transcribe":
            user_content = (
                f"{audio_token}can you transcribe the speech into a written format?"
            )
        else:
            raise ValueError(f"Unsupported task type {task_type}")

        messages = [{"role": "user", "content": user_content}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        prompt_token_ids = tokenizer.encode(prompt)
        prompt_dict = {
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": {"audio": audio},
        }
        return cast(PromptType, prompt_dict)