chenpangpang / transformers · Commits

Commit 609a1767 (unverified), authored Feb 15, 2024 by Arthur, committed by GitHub on Feb 15, 2024

[`CLeanup`] Revert SDPA attention changes that got in the static kv cache PR (#29027)

* revert unrelated changes that got in
* style

Parent: 7a0fccc6
Showing 3 changed files with 33 additions and 48 deletions.
src/transformers/models/mistral/modeling_mistral.py (+11 −16)
src/transformers/models/mixtral/modeling_mixtral.py (+11 −16)
src/transformers/models/qwen2/modeling_qwen2.py (+11 −16)
src/transformers/models/mistral/modeling_mistral.py

@@ -659,34 +659,28 @@ class MistralSdpaAttention(MistralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

@@ -695,9 +689,10 @@ class MistralSdpaAttention(MistralAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
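The restored Mistral path checks that a user-supplied mask has shape (bsz, 1, q_len, kv_seq_len) and passes it straight to torch.nn.functional.scaled_dot_product_attention, using is_causal only when no mask is given and more than one query token is processed. The snippet below is a minimal, self-contained sketch of that call pattern; the toy shapes and the toy mask are illustrative assumptions, not code from the repository.

```python
# Minimal sketch of the restored SDPA call pattern: pass an explicit 4D additive mask,
# or fall back to is_causal, but never both. All shapes below are illustrative.
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, kv_seq_len, head_dim = 2, 4, 5, 5, 8
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# A user-defined additive mask: 0 at visible positions, a large negative value at masked ones.
attention_mask = torch.zeros(bsz, 1, q_len, kv_seq_len)
attention_mask[:, :, :, -1] = torch.finfo(attention_mask.dtype).min  # hide the last key position

if attention_mask is not None and attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
    raise ValueError(
        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
    )

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    # is_causal only applies when no explicit mask is given and q_len > 1.
    is_causal=attention_mask is None and q_len > 1,
)
print(attn_output.shape)  # torch.Size([2, 4, 5, 8])
```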
src/transformers/models/mixtral/modeling_mixtral.py

@@ -736,34 +736,28 @@ class MixtralSdpaAttention(MixtralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

@@ -772,9 +766,10 @@ class MixtralSdpaAttention(MixtralAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

         attn_output = attn_output.transpose(1, 2).contiguous()
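The unchanged context lines above call repeat_kv(key_states, self.num_key_value_groups) to expand the grouped key/value heads so they match the number of query heads before SDPA. Below is a hedged, stand-alone sketch of what such a helper does; it is written for illustration and may differ in detail from the repeat_kv defined in modeling_mixtral.py.

```python
# Sketch of a repeat_kv-style helper for grouped-query attention: each of the
# num_key_value_heads KV heads is repeated n_rep times along the head axis so
# key/value match the query's number of attention heads.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_key_value_heads, seq_len, head_dim) to
    (batch, num_key_value_heads * n_rep, seq_len, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Example: 8 query heads attending over 2 KV heads -> num_key_value_groups = 4.
key_states = torch.randn(1, 2, 6, 16)
print(repeat_kv(key_states, 4).shape)  # torch.Size([1, 8, 6, 16])
```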
src/transformers/models/qwen2/modeling_qwen2.py

@@ -669,34 +669,28 @@ class Qwen2SdpaAttention(Qwen2Attention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

@@ -705,9 +699,10 @@ class Qwen2SdpaAttention(Qwen2Attention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
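The branch removed from all three files sliced the user mask and then ran causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None]), which the deleted comment describes as equivalent to the pad_unpad function: rows whose key positions are all masked get multiplied by zero, so SDPA never softmaxes over a row of pure minimums. The toy example below only illustrates that removed one-liner; it assumes the additive mask uses torch.finfo(dtype).min for masked positions (as transformers does) rather than -inf, so the in-place multiply stays finite.

```python
# Toy illustration of the removed one-liner:
#     causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
import torch

min_val = torch.finfo(torch.float32).min
# Additive mask for 3 query positions over 3 key positions; row 0 is fully masked.
causal_mask = torch.tensor(
    [
        [min_val, min_val, min_val],
        [0.0, min_val, min_val],
        [0.0, 0.0, min_val],
    ]
)

# True only for rows where every entry equals the mask's minimum (fully masked rows).
fully_masked_rows = torch.eq(causal_mask, causal_mask.min()).all(dim=-1)  # tensor([True, False, False])

# Multiplying those rows by 0 turns them into all-visible rows; other rows are unchanged.
causal_mask.mul_(~fully_masked_rows[..., None])
print(causal_mask[0])  # tensor([0., 0., 0.])
print(causal_mask[1])  # original masking preserved for partially masked rows
```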