Commit 3a45ab97 authored by caihl's avatar caihl
Browse files

invoke flash_attn in the Qwen2AudioEncoder (transformers)

parent 9ce8b1a3
...@@ -337,6 +337,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -337,6 +337,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
self.quant_config = quant_config self.quant_config = quant_config
with self._mark_tower_model(vllm_config, "audio"): with self._mark_tower_model(vllm_config, "audio"):
config.audio_config._attn_implementation = "flash_attention_2"
self.audio_tower = Qwen2AudioEncoder(config.audio_config) self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector( self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size config.audio_config.d_model, config.text_config.hidden_size
...@@ -422,6 +423,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -422,6 +423,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
) )
audio_attention_mask[audio_attention_mask_] = float("-inf") audio_attention_mask[audio_attention_mask_] = float("-inf")
attn_impl = getattr(self.audio_tower.config, "_attn_implementation", "eager")
if attn_impl in ("flash_attention_2", "flash_attention_3"):
audio_attention_mask = (~padding_mask).to(dtype=torch.int32)
audio_outputs = self.audio_tower( audio_outputs = self.audio_tower(
input_features, attention_mask=audio_attention_mask input_features, attention_mask=audio_attention_mask
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment