Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
388fd314
Unverified
Commit
388fd314
authored
Dec 14, 2023
by
Joao Gante
Committed by
GitHub
Dec 14, 2023
Browse files
Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window (#28037)
parent
0ede7626
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
10 deletions
+28
-10
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+14
-5
src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mixtral/modeling_mixtral.py
+14
-5
No files found.
src/transformers/models/mistral/modeling_mistral.py
View file @
388fd314
...
@@ -363,6 +363,12 @@ class MistralFlashAttention2(MistralAttention):
...
@@ -363,6 +363,12 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len
=
key_states
.
shape
[
-
2
]
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
# Because the input can be padded, the absolute sequence length depends on the max position id.
# Because the input can be padded, the absolute sequence length depends on the max position id.
...
@@ -385,11 +391,16 @@ class MistralFlashAttention2(MistralAttention):
...
@@ -385,11 +391,16 @@ class MistralFlashAttention2(MistralAttention):
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if
getattr
(
self
.
config
,
"sliding_window"
,
None
)
is
not
None
and
kv_seq_len
>
self
.
config
.
sliding_window
:
cache_has_contents
=
past_key_value
.
get_seq_length
(
self
.
layer_idx
)
>
0
if
(
getattr
(
self
.
config
,
"sliding_window"
,
None
)
is
not
None
and
kv_seq_len
>
self
.
config
.
sliding_window
and
cache_has_contents
):
slicing_tokens
=
1
-
self
.
config
.
sliding_window
slicing_tokens
=
1
-
self
.
config
.
sliding_window
past_key
=
past_key_value
[
0
]
past_key
=
past_key_value
[
self
.
layer_idx
][
0
]
past_value
=
past_key_value
[
1
]
past_value
=
past_key_value
[
self
.
layer_idx
][
1
]
past_key
=
past_key
[:,
:,
slicing_tokens
:,
:].
contiguous
()
past_key
=
past_key
[:,
:,
slicing_tokens
:,
:].
contiguous
()
past_value
=
past_value
[:,
:,
slicing_tokens
:,
:].
contiguous
()
past_value
=
past_value
[:,
:,
slicing_tokens
:,
:].
contiguous
()
...
@@ -400,8 +411,6 @@ class MistralFlashAttention2(MistralAttention):
...
@@ -400,8 +411,6 @@ class MistralFlashAttention2(MistralAttention):
f
"
{
past_key
.
shape
}
"
f
"
{
past_key
.
shape
}
"
)
)
past_key_value
=
(
past_key
,
past_value
)
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
slicing_tokens
:]
attention_mask
=
attention_mask
[:,
slicing_tokens
:]
attention_mask
=
torch
.
cat
([
attention_mask
,
torch
.
ones_like
(
attention_mask
[:,
-
1
:])],
dim
=-
1
)
attention_mask
=
torch
.
cat
([
attention_mask
,
torch
.
ones_like
(
attention_mask
[:,
-
1
:])],
dim
=-
1
)
...
...
src/transformers/models/mixtral/modeling_mixtral.py
View file @
388fd314
...
@@ -414,6 +414,12 @@ class MixtralFlashAttention2(MixtralAttention):
...
@@ -414,6 +414,12 @@ class MixtralFlashAttention2(MixtralAttention):
kv_seq_len
=
key_states
.
shape
[
-
2
]
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
# Because the input can be padded, the absolute sequence length depends on the max position id.
# Because the input can be padded, the absolute sequence length depends on the max position id.
...
@@ -436,11 +442,16 @@ class MixtralFlashAttention2(MixtralAttention):
...
@@ -436,11 +442,16 @@ class MixtralFlashAttention2(MixtralAttention):
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
# Activate slicing cache only if the config has a value `sliding_windows` attribute
if
getattr
(
self
.
config
,
"sliding_window"
,
None
)
is
not
None
and
kv_seq_len
>
self
.
config
.
sliding_window
:
cache_has_contents
=
past_key_value
.
get_seq_length
(
self
.
layer_idx
)
>
0
if
(
getattr
(
self
.
config
,
"sliding_window"
,
None
)
is
not
None
and
kv_seq_len
>
self
.
config
.
sliding_window
and
cache_has_contents
):
slicing_tokens
=
1
-
self
.
config
.
sliding_window
slicing_tokens
=
1
-
self
.
config
.
sliding_window
past_key
=
past_key_value
[
0
]
past_key
=
past_key_value
[
self
.
layer_idx
][
0
]
past_value
=
past_key_value
[
1
]
past_value
=
past_key_value
[
self
.
layer_idx
][
1
]
past_key
=
past_key
[:,
:,
slicing_tokens
:,
:].
contiguous
()
past_key
=
past_key
[:,
:,
slicing_tokens
:,
:].
contiguous
()
past_value
=
past_value
[:,
:,
slicing_tokens
:,
:].
contiguous
()
past_value
=
past_value
[:,
:,
slicing_tokens
:,
:].
contiguous
()
...
@@ -451,8 +462,6 @@ class MixtralFlashAttention2(MixtralAttention):
...
@@ -451,8 +462,6 @@ class MixtralFlashAttention2(MixtralAttention):
f
"
{
past_key
.
shape
}
"
f
"
{
past_key
.
shape
}
"
)
)
past_key_value
=
(
past_key
,
past_value
)
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
slicing_tokens
:]
attention_mask
=
attention_mask
[:,
slicing_tokens
:]
attention_mask
=
torch
.
cat
([
attention_mask
,
torch
.
ones_like
(
attention_mask
[:,
-
1
:])],
dim
=-
1
)
attention_mask
=
torch
.
cat
([
attention_mask
,
torch
.
ones_like
(
attention_mask
[:,
-
1
:])],
dim
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment