# SPDX-License-Identifier: Apache-2.0
from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig

from vllm.transformers_utils.configs.step import Step1Config


def _merge_eos_token_ids(defaults, eos_token_id):
    """Combine the built-in EOS ids with an optional caller-supplied value.

    Args:
        defaults: List of EOS ids that must always be present.
        eos_token_id: ``None``, a single id, or a list/tuple of ids.

    Returns:
        A new list containing every id from ``defaults`` plus the supplied
        id(s). When a list/tuple is supplied the result is deduplicated
        through a ``set``, so its ordering is not guaranteed.
    """
    if eos_token_id is None:
        return list(defaults)
    if isinstance(eos_token_id, (list, tuple)):
        return list(set(list(defaults) + list(eos_token_id)))
    # A scalar that is already one of the defaults must not be appended
    # again, otherwise the resulting list would contain a duplicate.
    if eos_token_id in defaults:
        return list(defaults)
    return [*defaults, eos_token_id]


class Step1fAudioEncoderConfig(PretrainedConfig):
    """Configuration of the audio encoder (``model_type = "stepasr_encoder"``).

    Every named constructor argument is stored unchanged as an attribute of
    the same name; remaining keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "stepasr_encoder"

    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        n_codebook_size: int = 4096,
        llm_dim: int = 3072,
        kernel_size: int = 3,
        adapter_stride: int = 2,
        adapter_state: int = 2048,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.n_mels = n_mels
        self.n_audio_ctx = n_audio_ctx
        self.n_audio_state = n_audio_state
        self.n_audio_head = n_audio_head
        self.n_audio_layer = n_audio_layer
        self.n_codebook_size = n_codebook_size
        self.llm_dim = llm_dim
        self.kernel_size = kernel_size
        self.adapter_stride = adapter_stride
        self.adapter_state = adapter_state


def _build_audio_encoder_config(audio_encoder_config):
    """Normalize ``audio_encoder_config`` to a config object or ``None``.

    ``None`` and ready-made ``Step1fAudioEncoderConfig`` instances are
    returned as-is; anything else is treated as a mapping of keyword
    arguments for ``Step1fAudioEncoderConfig``.
    """
    if audio_encoder_config is None:
        return None
    if isinstance(audio_encoder_config, Step1fAudioEncoderConfig):
        return audio_encoder_config
    return Step1fAudioEncoderConfig(**audio_encoder_config)


class Step1AudioConfig(PretrainedConfig):
    """Top-level configuration for the ``step1_audio`` model type.

    Besides storing its own hyperparameters, the constructor builds two
    nested objects:

    * ``audio_encoder_config`` - a ``Step1fAudioEncoderConfig`` (or ``None``
      when no encoder settings are given).
    * ``text_config`` - a ``Step1Config`` populated from the same language
      model hyperparameters, with ``architectures=["Step1ForCausalLM"]``.

    Token ids 2 and 3 are always part of ``eos_token_id``; a caller-supplied
    id or list of ids is merged into them.
    """

    # for step1.5t
    model_type = "step1_audio"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        audio_token_id: int = 29,
        eos_token_id=None,
        audio_encoder_config=None,
        **kwargs,
    ) -> None:
        super().__init__(
            eos_token_id=_merge_eos_token_ids([2, 3], eos_token_id),
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.max_seq_len = max_seq_len
        self.rms_norm_eps = rms_norm_eps
        self.audio_token_id = audio_token_id
        self.audio_encoder_config = _build_audio_encoder_config(
            audio_encoder_config)

        # "bfloat16" is used only if no ``torch_dtype`` attribute exists
        # after the base-class initializer has run.
        self.text_config = Step1Config(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_attention_heads=num_attention_heads,
            num_attention_groups=num_attention_groups,
            num_hidden_layers=num_hidden_layers,
            max_seq_len=max_seq_len,
            vocab_size=vocab_size,
            rms_norm_eps=rms_norm_eps,
            architectures=["Step1ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )


class StepAudioQwen2Config(PretrainedConfig):
    """Top-level configuration for the ``step_audio_qwen2`` model type.

    Besides storing its own hyperparameters, the constructor builds two
    nested objects:

    * ``audio_encoder_config`` - a ``Step1fAudioEncoderConfig`` (or ``None``
      when no encoder settings are given).
    * ``text_config`` - a ``Qwen2Config`` populated from the same language
      model hyperparameters, with ``architectures=["Qwen2ForCausalLM"]``.

    Token ids 151643, 151645 and 151665 are always part of ``eos_token_id``;
    a caller-supplied id or list of ids is merged into them.

    ``num_attention_groups`` and ``num_key_value_heads`` must be equal; the
    constructor asserts this.
    """

    model_type = "step_audio_qwen2"

    def __init__(
        self,
        vocab_size=64012,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=48,
        num_attention_heads=32,
        num_attention_groups=4,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_theta=1000000.0,
        rope_scaling=None,
        audio_token_id=151690,
        eos_token_id=None,
        audio_encoder_config=None,
        **kwargs,
    ):
        super().__init__(
            eos_token_id=_merge_eos_token_ids(
                [151643, 151645, 151665], eos_token_id),
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_key_value_heads = num_key_value_heads
        assert self.num_attention_groups == self.num_key_value_heads, "num_attention_groups must be equal to num_key_value_heads"
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.audio_encoder_config = _build_audio_encoder_config(
            audio_encoder_config)
        self.audio_token_id = audio_token_id

        # "bfloat16" is used only if no ``torch_dtype`` attribute exists
        # after the base-class initializer has run.
        self.text_config = Qwen2Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            architectures=["Qwen2ForCausalLM"],
            torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
        )