Unverified Commit 608fa549 authored by Jonatan Kłosko's avatar Jonatan Kłosko Committed by GitHub
Browse files

Make sliding window size inclusive in eager attention (#29519)

* Make sliding window size inclusive in eager attention

* Fix tests
parent f386c51a
@@ -164,10 +164,10 @@ class AttentionMaskConverter:
         # add lower triangular sliding window mask if necessary
         if sliding_window is not None:
-            diagonal = past_key_values_length - sliding_window + 1
-            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
-            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
+            diagonal = past_key_values_length - sliding_window - 1
+            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+            mask.masked_fill_(context_mask, torch.finfo(dtype).min)
         return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
......
@@ -1673,7 +1673,7 @@ class AttentionMaskTester(unittest.TestCase):
     def compute_num_context_mask(self, kv_len, context, q_len):
         # This function computes the # of attention tokens that are added for
         # the sliding window
-        c_mask_len = kv_len - context
+        c_mask_len = kv_len - context - 1
         num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
         cut_mask_len = max(c_mask_len - q_len, 0)
         num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment