Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b8b16475
Unverified
Commit
b8b16475
authored
Feb 20, 2024
by
JB (Don)
Committed by
GitHub
Feb 20, 2024
Browse files
[Phi] Add support for sdpa (#29108)
parent
7688d8df
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
1 deletion
+129
-1
docs/source/en/perf_infer_gpu_one.md
docs/source/en/perf_infer_gpu_one.md
+1
-0
src/transformers/models/phi/modeling_phi.py
src/transformers/models/phi/modeling_phi.py
+128
-1
No files found.
docs/source/en/perf_infer_gpu_one.md
View file @
b8b16475
...
...
@@ -172,6 +172,7 @@ For now, Transformers supports SDPA inference and training for the following arc
*
[
GPTBigCode
](
https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel
)
*
[
Falcon
](
https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel
)
*
[
Llama
](
https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel
)
*
[
Phi
](
https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel
)
*
[
Idefics
](
https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel
)
*
[
Whisper
](
https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel
)
*
[
Mistral
](
https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel
)
...
...
src/transformers/models/phi/modeling_phi.py
View file @
b8b16475
...
...
@@ -22,12 +22,16 @@ from typing import List, Optional, Tuple, Union
import
torch
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
from
packaging
import
version
from
torch
import
nn
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
...activations
import
ACT2FN
from
...cache_utils
import
Cache
,
DynamicCache
from
...modeling_attn_mask_utils
import
_prepare_4d_causal_attention_mask
from
...modeling_attn_mask_utils
import
(
_prepare_4d_causal_attention_mask
,
_prepare_4d_causal_attention_mask_for_sdpa
,
)
from
...modeling_outputs
import
(
BaseModelOutputWithPast
,
CausalLMOutputWithPast
,
...
...
@@ -39,6 +43,7 @@ from ...utils import (
add_code_sample_docstrings
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
get_torch_version
,
is_flash_attn_2_available
,
is_flash_attn_greater_or_equal_2_10
,
logging
,
...
...
@@ -617,9 +622,121 @@ class PhiFlashAttention2(PhiAttention):
)
class
PhiSdpaAttention
(
PhiAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
require_contiguous_qkv
=
version
.
parse
(
get_torch_version
())
<
version
.
parse
(
"2.2.0"
)
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from PhiAttention.forward
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_value
:
Optional
[
Cache
]
=
None
,
output_attentions
:
bool
=
False
,
use_cache
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
if
output_attentions
:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger
.
warning_once
(
"PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return
super
().
forward
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_value
=
past_key_value
,
output_attentions
=
output_attentions
,
use_cache
=
use_cache
,
)
bsz
,
q_len
,
_
=
hidden_states
.
size
()
query_states
=
self
.
q_proj
(
hidden_states
)
key_states
=
self
.
k_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
if
self
.
qk_layernorm
:
query_states
=
self
.
q_layernorm
(
query_states
)
key_states
=
self
.
k_layernorm
(
key_states
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
# Partial rotary embedding
query_rot
,
query_pass
=
(
query_states
[...,
:
self
.
rotary_emb
.
dim
],
query_states
[...,
self
.
rotary_emb
.
dim
:],
)
key_rot
,
key_pass
=
(
key_states
[...,
:
self
.
rotary_emb
.
dim
],
key_states
[...,
self
.
rotary_emb
.
dim
:],
)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot
,
key_rot
=
apply_rotary_pos_emb
(
query_rot
,
key_rot
,
cos
,
sin
,
position_ids
)
# [batch_size, seq_length, num_heads, head_dim]
query_states
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key_states
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"partial_rotation_size"
:
self
.
rotary_emb
.
dim
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
=
repeat_kv
(
key_states
,
self
.
num_key_value_groups
)
value_states
=
repeat_kv
(
value_states
,
self
.
num_key_value_groups
)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if
self
.
require_contiguous_qkv
and
query_states
.
device
.
type
==
"cuda"
and
attention_mask
is
not
None
:
query_states
=
query_states
.
contiguous
()
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
attn_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
attn_mask
=
attention_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
is_causal
=
self
.
is_causal
and
attention_mask
is
None
and
q_len
>
1
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
hidden_size
)
attn_output
=
self
.
dense
(
attn_output
)
return
attn_output
,
None
,
past_key_value
PHI_ATTENTION_CLASSES
=
{
"eager"
:
PhiAttention
,
"flash_attention_2"
:
PhiFlashAttention2
,
"sdpa"
:
PhiSdpaAttention
,
}
...
...
@@ -714,6 +831,7 @@ class PhiPreTrainedModel(PreTrainedModel):
_no_split_modules
=
[
"PhiDecoderLayer"
]
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
...
...
@@ -821,7 +939,9 @@ class PhiModel(PhiPreTrainedModel):
[
PhiDecoderLayer
(
config
,
layer_idx
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)]
)
self
.
final_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
_use_flash_attention_2
=
config
.
_attn_implementation
==
"flash_attention_2"
self
.
_use_sdpa
=
config
.
_attn_implementation
==
"sdpa"
self
.
gradient_checkpointing
=
False
# Initialize weights and apply final processing
...
...
@@ -895,6 +1015,13 @@ class PhiModel(PhiPreTrainedModel):
if
self
.
_use_flash_attention_2
:
# 2d mask is passed through the layers
attention_mask
=
attention_mask
if
(
attention_mask
is
not
None
and
0
in
attention_mask
)
else
None
elif
self
.
_use_sdpa
and
not
output_attentions
:
attention_mask
=
_prepare_4d_causal_attention_mask_for_sdpa
(
attention_mask
,
(
batch_size
,
seq_length
),
inputs_embeds
,
past_key_values_length
,
)
else
:
# 4d mask is passed through the layers
attention_mask
=
_prepare_4d_causal_attention_mask
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment