Commit 7768afba authored by Zian(Andy) Zheng

Update flash_attention_patch.py

To stay compatible with a recent change in the Transformers library, where a new argument 'padding_mask' was added to the forward method of the attention layers.
https://github.com/huggingface/transformers/pull/25598
parent 611a5a80
@@ -65,6 +65,7 @@ def attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
     Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
...
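For context, a minimal sketch of the failure mode this commit fixes: newer Transformers versions pass `padding_mask` into the attention forward call, so a patched forward without a matching parameter raises a TypeError, while a `**kwargs` catch-all absorbs it. The function names below are hypothetical illustrations, not part of the patched file.

# Hypothetical, self-contained sketch; none of these names come from the patch.
def patched_forward_old(hidden_states, attention_mask=None, use_cache=False):
    # Pre-PR-25598 signature: no slot for `padding_mask`.
    return hidden_states

def patched_forward_new(hidden_states, attention_mask=None, use_cache=False, **kwargs):
    # `**kwargs` absorbs `padding_mask` (and any future keyword arguments),
    # so newer Transformers call sites no longer break.
    return hidden_states

hidden = [[0.0]]  # stand-in for a hidden-states tensor
patched_forward_new(hidden, attention_mask=None, padding_mask=None)   # OK
try:
    patched_forward_old(hidden, attention_mask=None, padding_mask=None)
except TypeError as err:
    print(err)  # unexpected keyword argument 'padding_mask'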