Unverified Commit 37555cf4 authored by OlivierDehaene, committed by GitHub

fix: max_past default value must be -1, not 0 (#1348)

parent 9b78a6ee
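
In flash-attn's API, `window_size_left = -1` conventionally means "no left window limit" (plain causal attention), while `0` would clamp the attention window to nothing. A model whose config has no `sliding_window` set therefore needs `-1`, not `0`, as its `max_past`, which is what this commit fixes. A minimal sketch of the intended semantics follows; the `sliding_window_mask` helper and its mask logic are illustrative only, not TGI's actual CUDA kernel:

```python
import torch

def sliding_window_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    """Boolean causal mask where each query may attend to at most
    `window_size_left` earlier positions; -1 means no limit.
    Illustrative only -- the real kernel does this inside flash-attn."""
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")
    q = torch.arange(seq_len).unsqueeze(1)  # query positions (column)
    k = torch.arange(seq_len).unsqueeze(0)  # key positions (row)
    causal = k <= q  # a token never attends to the future
    if window_size_left == -1:
        return causal  # unbounded history: ordinary causal attention
    # restrict each query to its window of recent keys
    return causal & (q - k <= window_size_left)
```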
@@ -149,7 +149,7 @@ class MistralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
...
@@ -204,7 +204,7 @@ class MixtralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
...
@@ -72,6 +72,9 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
+    if window_size_left <= 0 and window_size_left != -1:
+        raise ValueError("`window_size_left` must be > 0 or -1")
+
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
...
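
With the guard in place, passing the old default of `0` fails fast instead of silently producing an empty attention window. A hedged usage sketch against the illustrative `sliding_window_mask` helper above (names are assumptions, not TGI's API):

```python
sliding_window_mask(8, window_size_left=-1)  # OK: full causal attention
sliding_window_mask(8, window_size_left=4)   # OK: 4-token sliding window
sliding_window_mask(8, window_size_left=0)   # raises ValueError, as the new check intends
```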