Commit 79052e70 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev-qwen2audio' into 'v0.15.1-dev'

invoke flash_attn in the Qwen2AudioEncoder (transformers)

See merge request dcutoolkit/deeplearing/vllm!508
parents 9ce8b1a3 3a45ab97
......@@ -337,6 +337,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
self.quant_config = quant_config
with self._mark_tower_model(vllm_config, "audio"):
config.audio_config._attn_implementation = "flash_attention_2"
self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size
......@@ -422,6 +423,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
)
audio_attention_mask[audio_attention_mask_] = float("-inf")
attn_impl = getattr(self.audio_tower.config, "_attn_implementation", "eager")
if attn_impl in ("flash_attention_2", "flash_attention_3"):
audio_attention_mask = (~padding_mask).to(dtype=torch.int32)
audio_outputs = self.audio_tower(
input_features, attention_mask=audio_attention_mask
)
......@@ -473,4 +477,4 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this module via ``AutoWeightsLoader``.

    Args:
        weights: Iterable of ``(str, torch.Tensor)`` pairs; presumably
            (parameter name, tensor) -- confirm against the caller.

    Returns:
        The value returned by ``AutoWeightsLoader.load_weights``, annotated
        as ``set[str]`` (presumably the names of the weights that were
        loaded -- confirm against ``AutoWeightsLoader``).
    """
    # The loader is built around the whole model (``self``), not an
    # individual submodule, and all of the loading work is delegated to it.
    loader = AutoWeightsLoader(self)
    return loader.load_weights(weights)
\ No newline at end of file
return loader.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment