Commit 79052e70 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev-qwen2audio' into 'v0.15.1-dev'

invoke flash_attn in the Qwen2AudioEncoder (transformers)

See merge request dcutoolkit/deeplearing/vllm!508
parents 9ce8b1a3 3a45ab97
...@@ -337,6 +337,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -337,6 +337,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
self.quant_config = quant_config self.quant_config = quant_config
with self._mark_tower_model(vllm_config, "audio"): with self._mark_tower_model(vllm_config, "audio"):
config.audio_config._attn_implementation = "flash_attention_2"
self.audio_tower = Qwen2AudioEncoder(config.audio_config) self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector( self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size config.audio_config.d_model, config.text_config.hidden_size
...@@ -422,6 +423,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -422,6 +423,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
) )
audio_attention_mask[audio_attention_mask_] = float("-inf") audio_attention_mask[audio_attention_mask_] = float("-inf")
attn_impl = getattr(self.audio_tower.config, "_attn_implementation", "eager")
if attn_impl in ("flash_attention_2", "flash_attention_3"):
audio_attention_mask = (~padding_mask).to(dtype=torch.int32)
audio_outputs = self.audio_tower( audio_outputs = self.audio_tower(
input_features, attention_mask=audio_attention_mask input_features, attention_mask=audio_attention_mask
) )
...@@ -473,4 +477,4 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports ...@@ -473,4 +477,4 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment