chenpangpang / transformers / Commits / b8b16475

Unverified commit b8b16475, authored Feb 20, 2024 by JB (Don), committed by GitHub on Feb 20, 2024

[Phi] Add support for sdpa (#29108)

Parent: 7688d8df

Showing 2 changed files with 129 additions and 1 deletion (+129 -1):

docs/source/en/perf_infer_gpu_one.md            +1    -0
src/transformers/models/phi/modeling_phi.py     +128  -1
docs/source/en/perf_infer_gpu_one.md  (view file @ b8b16475)
@@ -172,6 +172,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
 * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
 * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
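For context, the SDPA path added here is selected at load time through the `attn_implementation` argument of `from_pretrained`. A minimal usage sketch follows; the checkpoint id `microsoft/phi-2`, the prompt, and the availability of a CUDA device are assumptions for illustration, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load a Phi checkpoint with the SDPA attention implementation added by this commit.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",                      # assumed example checkpoint
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")                                # assumes a CUDA device is available
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

inputs = tokenizer("Scaled dot product attention is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))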
src/transformers/models/phi/modeling_phi.py  (view file @ b8b16475)
@@ -22,12 +22,16 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -39,6 +43,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    get_torch_version,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
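The new `get_torch_version` import feeds the torch-version gate used in `PhiSdpaAttention.__init__` below. Here is a minimal sketch of that check, using `torch.__version__` directly in place of the library helper (an assumption made only to keep the snippet self-contained):

import torch
from packaging import version

# True on torch < 2.2.0, where SDPA's memory-efficient backend requires contiguous
# q/k/v tensors when a custom attn_mask is passed (pytorch issue #112577).
require_contiguous_qkv = version.parse(torch.__version__) < version.parse("2.2.0")
print(require_contiguous_qkv)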
@@ -617,9 +622,121 @@ class PhiFlashAttention2(PhiAttention):
...
@@ -617,9 +622,121 @@ class PhiFlashAttention2(PhiAttention):
)
)
class
PhiSdpaAttention
(
PhiAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
require_contiguous_qkv
=
version
.
parse
(
get_torch_version
())
<
version
.
parse
(
"2.2.0"
)
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from PhiAttention.forward
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
past_key_value
:
Optional
[
Cache
]
=
None
,
output_attentions
:
bool
=
False
,
use_cache
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
if
output_attentions
:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger
.
warning_once
(
"PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return
super
().
forward
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
past_key_value
=
past_key_value
,
output_attentions
=
output_attentions
,
use_cache
=
use_cache
,
)
bsz
,
q_len
,
_
=
hidden_states
.
size
()
query_states
=
self
.
q_proj
(
hidden_states
)
key_states
=
self
.
k_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
if
self
.
qk_layernorm
:
query_states
=
self
.
q_layernorm
(
query_states
)
key_states
=
self
.
k_layernorm
(
key_states
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
self
.
layer_idx
is
None
:
raise
ValueError
(
f
"The cache structure has changed since version v4.36. If you are using
{
self
.
__class__
.
__name__
}
"
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len
+=
past_key_value
.
get_usable_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
# Partial rotary embedding
query_rot
,
query_pass
=
(
query_states
[...,
:
self
.
rotary_emb
.
dim
],
query_states
[...,
self
.
rotary_emb
.
dim
:],
)
key_rot
,
key_pass
=
(
key_states
[...,
:
self
.
rotary_emb
.
dim
],
key_states
[...,
self
.
rotary_emb
.
dim
:],
)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot
,
key_rot
=
apply_rotary_pos_emb
(
query_rot
,
key_rot
,
cos
,
sin
,
position_ids
)
# [batch_size, seq_length, num_heads, head_dim]
query_states
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key_states
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
if
past_key_value
is
not
None
:
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"partial_rotation_size"
:
self
.
rotary_emb
.
dim
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
=
repeat_kv
(
key_states
,
self
.
num_key_value_groups
)
value_states
=
repeat_kv
(
value_states
,
self
.
num_key_value_groups
)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if
self
.
require_contiguous_qkv
and
query_states
.
device
.
type
==
"cuda"
and
attention_mask
is
not
None
:
query_states
=
query_states
.
contiguous
()
key_states
=
key_states
.
contiguous
()
value_states
=
value_states
.
contiguous
()
attn_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query_states
,
key_states
,
value_states
,
attn_mask
=
attention_mask
,
dropout_p
=
self
.
attention_dropout
if
self
.
training
else
0.0
,
is_causal
=
self
.
is_causal
and
attention_mask
is
None
and
q_len
>
1
,
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
hidden_size
)
attn_output
=
self
.
dense
(
attn_output
)
return
attn_output
,
None
,
past_key_value
PHI_ATTENTION_CLASSES
=
{
PHI_ATTENTION_CLASSES
=
{
"eager"
:
PhiAttention
,
"eager"
:
PhiAttention
,
"flash_attention_2"
:
PhiFlashAttention2
,
"flash_attention_2"
:
PhiFlashAttention2
,
"sdpa"
:
PhiSdpaAttention
,
}
}
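For a standalone view of the `torch.nn.functional.scaled_dot_product_attention` call that the new class wraps, here is a small sketch with arbitrary example shapes (none of these sizes come from the Phi config):

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 2, 4, 8, 32   # arbitrary example sizes

query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# No explicit mask: let SDPA apply causal masking itself, mirroring the
# `is_causal=self.is_causal and attention_mask is None and q_len > 1` condition above.
out_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# Explicit additive mask passed via attn_mask, with is_causal left as False.
attn_mask = torch.zeros(bsz, 1, q_len, q_len)
attn_mask[..., -1] = float("-inf")              # e.g. hide the last key position
out_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0)

print(out_causal.shape, out_masked.shape)       # both torch.Size([2, 4, 8, 32])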
@@ -714,6 +831,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["PhiDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_sdpa = True
     _supports_cache_class = True

     def _init_weights(self, module):
@@ -821,7 +939,9 @@ class PhiModel(PhiPreTrainedModel):
             [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_sdpa = config._attn_implementation == "sdpa"
+
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
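After loading, the `sdpa` choice is visible both on the config and on the instantiated layers. A quick check sketch (again, `microsoft/phi-2` is only an assumed example checkpoint, and `_attn_implementation` is an internal attribute):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", attn_implementation="sdpa")
# The config records the requested implementation ...
print(model.config._attn_implementation)                  # "sdpa"
# ... and each decoder layer's self-attention is the PhiSdpaAttention class added above.
print(type(model.model.layers[0].self_attn).__name__)     # "PhiSdpaAttention"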
@@ -895,6 +1015,13 @@ class PhiModel(PhiPreTrainedModel):
...
@@ -895,6 +1015,13 @@ class PhiModel(PhiPreTrainedModel):
if
self
.
_use_flash_attention_2
:
if
self
.
_use_flash_attention_2
:
# 2d mask is passed through the layers
# 2d mask is passed through the layers
attention_mask
=
attention_mask
if
(
attention_mask
is
not
None
and
0
in
attention_mask
)
else
None
attention_mask
=
attention_mask
if
(
attention_mask
is
not
None
and
0
in
attention_mask
)
else
None
elif
self
.
_use_sdpa
and
not
output_attentions
:
attention_mask
=
_prepare_4d_causal_attention_mask_for_sdpa
(
attention_mask
,
(
batch_size
,
seq_length
),
inputs_embeds
,
past_key_values_length
,
)
else
:
else
:
# 4d mask is passed through the layers
# 4d mask is passed through the layers
attention_mask
=
_prepare_4d_causal_attention_mask
(
attention_mask
=
_prepare_4d_causal_attention_mask
(
...
...
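To make the 2d-vs-4d mask distinction above concrete, here is a hand-rolled sketch of the additive [batch, 1, query_len, key_len] causal mask shape that the `_prepare_4d_causal_attention_mask*` helpers produce; this illustrates the shape and semantics only and is not the helpers' actual implementation:

import torch

batch_size, seq_length = 2, 5
min_value = torch.finfo(torch.float32).min

# 2d padding mask as passed by the user: 1 = attend, 0 = padding.
padding_mask_2d = torch.tensor([[1, 1, 1, 1, 1],
                                [1, 1, 1, 0, 0]])

# Additive causal part: 0 on and below the diagonal, a large negative value above it.
causal = torch.full((seq_length, seq_length), min_value).triu(diagonal=1)

# Expand to [batch_size, 1, seq_length, seq_length] and fold in the padding mask.
mask_4d = causal[None, None, :, :].expand(batch_size, 1, seq_length, seq_length).clone()
mask_4d = mask_4d.masked_fill(padding_mask_2d[:, None, None, :] == 0, min_value)

print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])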