Unverified Commit b275a410 authored by Anton Vlasjuk, committed by GitHub

[`GPT2`] Add SDPA support (#31172)

* `gpt2` sdpa support

* fix (at least) one test, style, repo consistency

* fix sdpa mask in forward --> fixes generation

* test

* test2

* test3

* test4

* simplify shapes for attn mask creation and small comments

* hub fail test

* benchmarks

* flash attn 2 mask should not be inverted on enc-dec setup

* fix comment

* apply some suggestions from code review

- only save _attn_implementation once
- remove unnecessary comment

* change elif logic

* [run-slow] gpt2

* modify `test_gpt2_sample_max_time` to follow previous assertion patterns

parent 22b41b3f
@@ -127,6 +127,64 @@ Below is an expected speedup diagram that compares pure inference time between t
<img src="https://huggingface.co/datasets/EduardoPacheco/documentation-images/resolve/main/gpt2_flash_attention_2_speedup.jpg">
</div>

## Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
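The operator can also be called directly. A minimal sketch with toy tensors (the shapes below are illustrative and are not taken from the GPT-2 implementation):

```python
import torch
import torch.nn.functional as F

# Toy (batch, heads, seq_len, head_dim) tensors; shapes are illustrative only.
query = torch.randn(1, 12, 64, 64)
key = torch.randn(1, 12, 64, 64)
value = torch.randn(1, 12, 64, 64)

# PyTorch dispatches to a Flash Attention, memory-efficient, or math backend
# depending on the inputs and hardware; `is_causal=True` applies a causal mask.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
```
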
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```
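After loading, you can check which implementation was actually selected. Note that `_attn_implementation` is an internal attribute of the config rather than public API, so treat this only as a quick debugging aid:

```python
# Prints "sdpa" if SDPA was selected, otherwise e.g. "eager" (internal attribute, may change between versions)
print(model.config._attn_implementation)
```
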

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (rtx3080ti-16GB, PyTorch 2.2.1, OS Ubuntu 22.04) using `float16` with
[gpt2-large](https://huggingface.co/openai-community/gpt2-large), we saw the
following speedups during training and inference.
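The benchmarking script itself is not shown here; the following is a minimal sketch of how such a forward-pass latency comparison could be run. The checkpoint name, shapes, and iteration counts are illustrative assumptions, and a CUDA GPU is assumed:

```python
import time

import torch
from transformers import AutoModelForCausalLM


def time_forward(attn_implementation, batch_size=2, seq_len=512, steps=10):
    # Load the same checkpoint with a different attention implementation each time.
    model = AutoModelForCausalLM.from_pretrained(
        "openai-community/gpt2-large",
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
    ).to("cuda")
    model.eval()
    input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len), device="cuda")
    with torch.no_grad():
        for _ in range(3):  # warmup
            model(input_ids)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(input_ids)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps


eager = time_forward("eager")
sdpa = time_forward("sdpa")
print(f"eager: {eager * 1000:.1f} ms/step, sdpa: {sdpa * 1000:.1f} ms/step, speedup: {100 * (eager / sdpa - 1):.1f}%")
```
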
### Training
| Batch size | Seq len | Time per batch (Eager - s) | Time per batch (SDPA - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|-----------:|--------:|----------------------------:|--------------------------:|------------:|--------------------:|-------------------:|------------------:|
| 1 | 128 | 0.039 | 0.032 | 23.042 | 3482.32 | 3494.62 | -0.352 |
| 1 | 256 | 0.073 | 0.059 | 25.15 | 3546.66 | 3552.6 | -0.167 |
| 1 | 512 | 0.155 | 0.118 | 30.96 | 4230.1 | 3665.59 | 15.4 |
| 1 | 1024 | 0.316 | 0.209 | 50.839 | 8682.26 | 4881.09 | 77.875 |
| 2 | 128 | 0.07 | 0.06 | 15.324 | 3557.8 | 3545.91 | 0.335 |
| 2 | 256 | 0.143 | 0.122 | 16.53 | 3901.5 | 3657.68 | 6.666 |
| 2 | 512 | 0.267 | 0.213 | 25.626 | 7062.21 | 4876.47 | 44.822 |
| 2 | 1024 | OOM | 0.404 | / | OOM | 8096.35 | SDPA does not OOM |
| 4 | 128 | 0.134 | 0.128 | 4.412 | 3675.79 | 3648.72 | 0.742 |
| 4 | 256 | 0.243 | 0.217 | 12.292 | 6129.76 | 4871.12 | 25.839 |
| 4 | 512 | 0.494 | 0.406 | 21.687 | 12466.6 | 8102.64 | 53.858 |
| 4 | 1024 | OOM | 0.795 | / | OOM | 14568.2 | SDPA does not OOM |
### Inference
| Batch size | Seq len | Per token latency Eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|-----------:|--------:|-----------------------------:|----------------------------:|------------:|---------------:|--------------:|--------------:|
| 1 | 128 | 7.991 | 6.968 | 14.681 | 1685.2 | 1701.32 | -0.947 |
| 1 | 256 | 8.462 | 7.199 | 17.536 | 1745.49 | 1770.78 | -1.428 |
| 1 | 512 | 8.68 | 7.853 | 10.529 | 1907.69 | 1921.29 | -0.708 |
| 1 | 768 | 9.101 | 8.365 | 8.791 | 2032.93 | 2068.12 | -1.701 |
| 2 | 128 | 9.169 | 9.001 | 1.861 | 1803.84 | 1811.4 | -0.418 |
| 2 | 256 | 9.907 | 9.78 | 1.294 | 1907.72 | 1921.44 | -0.714 |
| 2 | 512 | 11.519 | 11.644 | -1.071 | 2176.86 | 2197.75 | -0.951 |
| 2 | 768 | 13.022 | 13.407 | -2.873 | 2464.3 | 2491.06 | -1.074 |
| 4 | 128 | 10.097 | 9.831 | 2.709 | 1942.25 | 1985.13 | -2.16 |
| 4 | 256 | 11.599 | 11.398 | 1.764 | 2177.28 | 2197.86 | -0.937 |
| 4 | 512 | 14.653 | 14.45 | 1.411 | 2753.16 | 2772.57 | -0.7 |
| 4 | 768 | 17.846 | 17.617 | 1.299 | 3327.04 | 3343.97 | -0.506 |
## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT2. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
...
@@ -201,6 +201,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
...
@@ -24,10 +24,12 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
@@ -42,6 +44,7 @@ from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    get_torch_version,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
@@ -557,6 +560,113 @@ class GPT2FlashAttention2(GPT2Attention):
        )

class GPT2SdpaAttention(GPT2Attention):
    """
    GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GPT2Attention` as the weights of the module stay untouched. The only changes are on the forward pass
    to adapt to the SDPA API.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        bsz, q_len, _ = hidden_states.size()

        # Initial attention projections
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Optional kv caching
        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
        # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
        # options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        # Reshape outputs
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.embed_dim)

        # Final projection
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, present, None

class GPT2MLP(nn.Module):
    def __init__(self, intermediate_size, config):
        super().__init__()
@@ -574,10 +684,7 @@ class GPT2MLP(nn.Module):
        return hidden_states


GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}


class GPT2Block(nn.Module):
@@ -673,6 +780,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
    _no_split_modules = ["GPT2Block"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
@@ -1021,11 +1129,24 @@ class GPT2Model(GPT2PreTrainedModel):
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Attention mask.
        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, -1)
            if self._attn_implementation == "flash_attention_2":
                attention_mask = attention_mask if 0 in attention_mask else None
            elif _use_sdpa:
                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask=attention_mask,
                    input_shape=(batch_size, input_shape[-1]),
                    inputs_embeds=inputs_embeds,
                    past_key_values_length=past_length,
                )
            else:
                # We create a 3D attention mask from a 2D tensor mask.
                # Sizes are [batch_size, 1, 1, to_seq_length]
@@ -1049,7 +1170,11 @@ class GPT2Model(GPT2PreTrainedModel):
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif not self._attn_implementation == "flash_attention_2":
                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None
@@ -1060,11 +1185,6 @@ class GPT2Model(GPT2PreTrainedModel):
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds
...
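For reference, the dispatch at the heart of the new `GPT2SdpaAttention.forward` above, reduced to a standalone sketch outside the Transformers class hierarchy (projections, KV cache, and dropout omitted; shapes are hypothetical):

```python
import torch
import torch.nn.functional as F


def sdpa_attention(query, key, value, attention_mask=None, is_cross_attention=False):
    # Same dispatch rule as the forward pass above: rely on `is_causal=True` only when
    # there is no explicit mask, more than one query token, and no cross-attention.
    q_len = query.shape[-2]
    is_causal = attention_mask is None and q_len > 1 and not is_cross_attention
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )


# Toy (batch, heads, seq_len, head_dim) tensors, not GPT-2's real shapes
q = torch.randn(1, 12, 8, 64)
k = torch.randn(1, 12, 8, 64)
v = torch.randn(1, 12, 8, 64)
out = sdpa_attention(q, k, v)  # causal self-attention path
```
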
@@ -832,7 +832,8 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

    @slow
    def test_contrastive_search_gpt2(self):
...