"docs/source/ko/tasks/token_classification.mdx" did not exist on "4f4e5ddbcbdcd9d6353fc27d0137ac887a7f2f25"
Unverified commit 0142aab7, authored by Wing Lian and committed by GitHub
Browse files

don't zero out the attention_mask when using sliding window with flash attention (#31670)

* don't zero out the attention_mask when using sliding window with flash attention

* chore: lint
parent 1c68f2ca
...@@ -602,6 +602,7 @@ GEMMA2_ATTENTION_CLASSES = { ...@@ -602,6 +602,7 @@ GEMMA2_ATTENTION_CLASSES = {
class Gemma2DecoderLayer(nn.Module): class Gemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int): def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__() super().__init__()
self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
...@@ -625,7 +626,9 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -625,7 +626,9 @@ class Gemma2DecoderLayer(nn.Module):
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding if (
self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
): # efficient SDPA and no padding
attention_mask = attention_mask * torch.tril( attention_mask = attention_mask * torch.tril(
torch.ones_like(attention_mask), diagonal=-self.sliding_window torch.ones_like(attention_mask), diagonal=-self.sliding_window
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment