Unverified Commit 3ee68590 authored by AllenDou's avatar AllenDou Committed by GitHub
Browse files

refactor funasr model. (#36108)


Signed-off-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 71963481
...@@ -51,7 +51,6 @@ from vllm.multimodal.processing import ( ...@@ -51,7 +51,6 @@ from vllm.multimodal.processing import (
) )
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
...@@ -611,6 +610,10 @@ class FunASRAudioInputs(TensorSchema): ...@@ -611,6 +610,10 @@ class FunASRAudioInputs(TensorSchema):
list[torch.Tensor] | None, list[torch.Tensor] | None,
TensorShape("b"), TensorShape("b"),
] ]
fake_token_lengths: Annotated[
list[torch.Tensor] | None,
TensorShape("b"),
]
class FunASREncoder(nn.Module): class FunASREncoder(nn.Module):
...@@ -732,9 +735,6 @@ class FunASRProcessingInfo(BaseProcessingInfo): ...@@ -732,9 +735,6 @@ class FunASRProcessingInfo(BaseProcessingInfo):
def get_target_channels(self) -> int: def get_target_channels(self) -> int:
return 1 return 1
def get_num_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions
class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]): class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
...@@ -798,7 +798,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): ...@@ -798,7 +798,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
return dict( return dict(
input_features=MultiModalFieldConfig.batched("audio"), input_features=MultiModalFieldConfig.batched("audio"),
speech_lengths=MultiModalFieldConfig.batched("audio"), speech_lengths=MultiModalFieldConfig.batched("audio"),
fake_token_len=MultiModalFieldConfig.batched("audio"), fake_token_lengths=MultiModalFieldConfig.batched("audio"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -812,22 +812,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): ...@@ -812,22 +812,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
out_mm_data = out_mm_kwargs.get_data() out_mm_data = out_mm_kwargs.get_data()
fake_token_len = out_mm_data.get("fake_token_len") fake_token_lengths = out_mm_data.get("fake_token_lengths")
if fake_token_len is None: if fake_token_lengths is None:
audio_output_lengths = [] audio_output_lengths = []
else: else:
assert isinstance(fake_token_len, torch.Tensor) assert isinstance(fake_token_lengths, torch.Tensor)
audio_output_lengths = fake_token_len.tolist() audio_output_lengths = fake_token_lengths.tolist()
def get_replacement_qwen2_audio(item_idx: int): def get_replacement_qwen2_audio(item_idx: int):
if audio_output_lengths:
num_features = audio_output_lengths[item_idx] num_features = audio_output_lengths[item_idx]
else:
audio_embeds = out_mm_data["audio_embeds"][item_idx]
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
num_features = audio_embeds.shape[0]
return [audio_token_id] * num_features return [audio_token_id] * num_features
return [ return [
...@@ -847,21 +841,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): ...@@ -847,21 +841,16 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
class FunASRForConditionalGeneration( class FunASRForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal nn.Module, SupportsTranscription, SupportsMultiModal
): ):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={ orig_to_new_substr={
"linear_q.": "q_proj.", "linear_q.": "q_proj.",
"linear_k.": "k_proj.", "linear_k.": "k_proj.",
"linear_v.": "v_proj.", "linear_v.": "v_proj.",
"linear_out.": "out_proj.", "linear_out.": "out_proj.",
"audio_adaptor.": "model.encoder.audio_adaptor.",
"audio_encoder.": "model.encoder.audio_encoder.",
"llm.model.": "model.decoder.",
"llm.lm_head": "lm_head",
} }
) )
...@@ -969,9 +958,6 @@ class FunASRForConditionalGeneration( ...@@ -969,9 +958,6 @@ class FunASRForConditionalGeneration(
) )
return decoder_outputs return decoder_outputs
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
...@@ -1002,15 +988,12 @@ class FunASRForConditionalGeneration( ...@@ -1002,15 +988,12 @@ class FunASRForConditionalGeneration(
def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs: def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs:
input_features = kwargs.pop("input_features", None) input_features = kwargs.pop("input_features", None)
speech_lengths = kwargs.pop("speech_lengths", None) speech_lengths = kwargs.pop("speech_lengths", None)
fake_token_lengths = kwargs.pop("fake_token_lengths", None)
if input_features is not None:
input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
if speech_lengths is not None:
speech_lengths = json_map_leaves(lambda x: x.to(self.dtype), speech_lengths)
return FunASRAudioInputs( return FunASRAudioInputs(
input_features=input_features, speech_lengths=speech_lengths input_features=input_features,
speech_lengths=speech_lengths,
fake_token_lengths=fake_token_lengths,
) )
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -1022,22 +1005,4 @@ class FunASRForConditionalGeneration( ...@@ -1022,22 +1005,4 @@ class FunASRForConditionalGeneration(
self, self,
) )
# add fake zeros bias for k_proj to state_dict
weights = _create_fake_bias_for_k_proj(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def _create_fake_bias_for_k_proj(
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
"""
Create full zeros bias for k_proj weight in self-attn and x-attn layers.
So that the bias for k_proj in qkv_proj can be initialized with zeros.
"""
for name, weight in weights:
if name.endswith(".k_proj.weight"):
bias = torch.zeros(weight.size(0))
bias_name = name.replace("weight", "bias")
yield from [(name, weight), (bias_name, bias)]
else:
yield name, weight
...@@ -1794,7 +1794,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1794,7 +1794,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
return [] return []
# The result multimodal_embeddings is tuple of tensors, with each # The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video). # tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = () multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary # NOTE: It is important to iterate over the keys in this dictionary
......
...@@ -370,7 +370,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor): ...@@ -370,7 +370,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
) )
olens = 1 + (speech_lengths - 3 + 2 * 1) // 2 olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2 olens = 1 + (olens - 3 + 2 * 1) // 2
fake_token_len = (olens - 1) // 2 + 1 fake_token_lengths = (olens - 1) // 2 + 1
if isinstance(input_features[0], list): if isinstance(input_features[0], list):
padded_inputs["input_features"] = [ padded_inputs["input_features"] = [
np.asarray(feature, dtype=np.float32) for feature in input_features np.asarray(feature, dtype=np.float32) for feature in input_features
...@@ -382,8 +382,10 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor): ...@@ -382,8 +382,10 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
if return_tensors is not None: if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors) padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
fake_token_lengths = torch.clamp(fake_token_lengths, min=1)
padded_inputs["speech_lengths"] = speech_lengths padded_inputs["speech_lengths"] = speech_lengths
padded_inputs["fake_token_len"] = fake_token_len padded_inputs["fake_token_lengths"] = fake_token_lengths
return padded_inputs return padded_inputs
...@@ -471,7 +473,7 @@ class FunASRProcessor(ProcessorMixin): ...@@ -471,7 +473,7 @@ class FunASRProcessor(ProcessorMixin):
for sample in text: for sample in text:
replace_str = [] replace_str = []
while self.audio_token in sample: while self.audio_token in sample:
num_audio_tokens = inputs["fake_token_len"].item() num_audio_tokens = inputs["fake_token_lengths"].item()
expanded_audio_token = self.audio_token * num_audio_tokens expanded_audio_token = self.audio_token * num_audio_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment