Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
609a1767
Unverified
Commit
609a1767
authored
Feb 15, 2024
by
Arthur
Committed by
GitHub
Feb 15, 2024
Browse files
[`CLeanup`] Revert SDPA attention changes that got in the static kv cache PR (#29027)
* revert unrelated changes that got in * style
parent
7a0fccc6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
48 deletions
+33
-48
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+11
-16
src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mixtral/modeling_mixtral.py
+11
-16
src/transformers/models/qwen2/modeling_qwen2.py
src/transformers/models/qwen2/modeling_qwen2.py
+11
-16
No files found.
src/transformers/models/mistral/modeling_mistral.py
View file @
609a1767
...
...
@@ -659,34 +659,28 @@ class MistralSdpaAttention(MistralAttention):
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
kv_seq_len
=
key_states
.
shape
[
-
2
]
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
# add what was seen
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
past_seen_tokens
=
kv_seq_len
-
key_states
.
shape
[
-
2
]
new_cache_positions
=
torch
.
arange
(
past_seen_tokens
,
past_seen_tokens
+
q_len
,
device
=
key_states
.
device
)
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"position_ids"
:
new_cache_positions
}
# Specific to RoPE models
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
}
# Specific to RoPE models
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
=
repeat_kv
(
key_states
,
self
.
num_key_value_groups
)
value_states
=
repeat_kv
(
value_states
,
self
.
num_key_value_groups
)
if
(
attention_mask
is
not
None
and
not
torch
.
all
(
attention_mask
[...,
0
]
==
1
)
and
q_len
!=
1
):
# user defined causal mask
causal_mask
=
attention_mask
[:,
:,
past_seen_tokens
:
past_seen_tokens
+
q_len
,
:
key_states
.
shape
[
-
2
]]
# this one liner is equivalent to the pad_unpad function
causal_mask
.
mul_
(
~
torch
.
eq
(
causal_mask
,
causal_mask
.
min
()).
all
(
dim
=-
1
)[...,
None
])
else
:
causal_mask
=
None
if
attention_mask
is
not
None
:
if
attention_mask
.
size
()
!=
(
bsz
,
1
,
q_len
,
kv_seq_len
):
raise
ValueError
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
q_len
,
kv_seq_len
)
}
, but is
{
attention_mask
.
size
()
}
"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if
query_states
.
device
.
type
==
"cuda"
and
causal
_mask
is
not
None
:
if
query_states
.
device
.
type
==
"cuda"
and
attention
_mask
is
not
None
:
query_states
=
query_states
.
contiguous
()
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
...
...
@@ -695,9 +689,10 @@ class MistralSdpaAttention(MistralAttention):
query_states
,
key_states
,
value_states
,
attn_mask
=
causal
_mask
,
attn_mask
=
attention
_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
is_causal
=
causal_mask
is
None
and
q_len
>
1
,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal
=
self
.
is_causal
and
attention_mask
is
None
and
q_len
>
1
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
...
...
src/transformers/models/mixtral/modeling_mixtral.py
View file @
609a1767
...
...
@@ -736,34 +736,28 @@ class MixtralSdpaAttention(MixtralAttention):
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
kv_seq_len
=
key_states
.
shape
[
-
2
]
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
# add what was seen
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
past_seen_tokens
=
kv_seq_len
-
key_states
.
shape
[
-
2
]
new_cache_positions
=
torch
.
arange
(
past_seen_tokens
,
past_seen_tokens
+
q_len
,
device
=
key_states
.
device
)
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"position_ids"
:
new_cache_positions
}
# Specific to RoPE models
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
}
# Specific to RoPE models
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
=
repeat_kv
(
key_states
,
self
.
num_key_value_groups
)
value_states
=
repeat_kv
(
value_states
,
self
.
num_key_value_groups
)
if
(
attention_mask
is
not
None
and
not
torch
.
all
(
attention_mask
[...,
0
]
==
1
)
and
q_len
!=
1
):
# user defined causal mask
causal_mask
=
attention_mask
[:,
:,
past_seen_tokens
:
past_seen_tokens
+
q_len
,
:
key_states
.
shape
[
-
2
]]
# this one liner is equivalent to the pad_unpad function
causal_mask
.
mul_
(
~
torch
.
eq
(
causal_mask
,
causal_mask
.
min
()).
all
(
dim
=-
1
)[...,
None
])
else
:
causal_mask
=
None
if
attention_mask
is
not
None
:
if
attention_mask
.
size
()
!=
(
bsz
,
1
,
q_len
,
kv_seq_len
):
raise
ValueError
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
q_len
,
kv_seq_len
)
}
, but is
{
attention_mask
.
size
()
}
"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if
query_states
.
device
.
type
==
"cuda"
and
causal
_mask
is
not
None
:
if
query_states
.
device
.
type
==
"cuda"
and
attention
_mask
is
not
None
:
query_states
=
query_states
.
contiguous
()
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
...
...
@@ -772,9 +766,10 @@ class MixtralSdpaAttention(MixtralAttention):
query_states
,
key_states
,
value_states
,
attn_mask
=
causal
_mask
,
attn_mask
=
attention
_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
is_causal
=
causal_mask
is
None
and
q_len
>
1
,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal
=
self
.
is_causal
and
attention_mask
is
None
and
q_len
>
1
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
...
...
src/transformers/models/qwen2/modeling_qwen2.py
View file @
609a1767
...
...
@@ -669,34 +669,28 @@ class Qwen2SdpaAttention(Qwen2Attention):
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
kv_seq_len
=
key_states
.
shape
[
-
2
]
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
# add what was seen
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
past_seen_tokens
=
kv_seq_len
-
key_states
.
shape
[
-
2
]
new_cache_positions
=
torch
.
arange
(
past_seen_tokens
,
past_seen_tokens
+
q_len
,
device
=
key_states
.
device
)
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"position_ids"
:
new_cache_positions
}
# Specific to RoPE models
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
}
# Specific to RoPE models
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
=
repeat_kv
(
key_states
,
self
.
num_key_value_groups
)
value_states
=
repeat_kv
(
value_states
,
self
.
num_key_value_groups
)
if
(
attention_mask
is
not
None
and
not
torch
.
all
(
attention_mask
[...,
0
]
==
1
)
and
q_len
!=
1
):
# user defined causal mask
causal_mask
=
attention_mask
[:,
:,
past_seen_tokens
:
past_seen_tokens
+
q_len
,
:
key_states
.
shape
[
-
2
]]
# this one liner is equivalent to the pad_unpad function
causal_mask
.
mul_
(
~
torch
.
eq
(
causal_mask
,
causal_mask
.
min
()).
all
(
dim
=-
1
)[...,
None
])
else
:
causal_mask
=
None
if
attention_mask
is
not
None
:
if
attention_mask
.
size
()
!=
(
bsz
,
1
,
q_len
,
kv_seq_len
):
raise
ValueError
(
f
"Attention mask should be of size
{
(
bsz
,
1
,
q_len
,
kv_seq_len
)
}
, but is
{
attention_mask
.
size
()
}
"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if
query_states
.
device
.
type
==
"cuda"
and
causal
_mask
is
not
None
:
if
query_states
.
device
.
type
==
"cuda"
and
attention
_mask
is
not
None
:
query_states
=
query_states
.
contiguous
()
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
...
...
@@ -705,9 +699,10 @@ class Qwen2SdpaAttention(Qwen2Attention):
query_states
,
key_states
,
value_states
,
attn_mask
=
causal
_mask
,
attn_mask
=
attention
_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
is_causal
=
causal_mask
is
None
and
q_len
>
1
,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal
=
self
.
is_causal
and
attention_mask
is
None
and
q_len
>
1
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment