Unverified Commit 1c37e8c1 authored by Pavel Iakubovskii, committed by GitHub

Add `sdpa` and FA2 for CLIP (#31940)



* Squashed commit of the following:

commit 102842cd477219b9f9bcb23a0bca3a8b92bd732f
Author: Pavel Iakubovskii <qubvel@gmail.com>
Date:   Fri Jul 12 18:23:52 2024 +0000

    Add model-specific sdpa tests

commit 60e4c88581abf89ec098da84ed8e92aa904c997d
Author: Pavel Iakubovskii <qubvel@gmail.com>
Date:   Fri Jul 12 18:20:53 2024 +0000

    Add fallback to eager (expensive operation)

commit c29033d30e7ffde4327e8a15cbbc6bee37546f80
Author: Pavel Iakubovskii <qubvel@gmail.com>
Date:   Thu Jul 11 17:09:55 2024 +0000

    Fix attn_implementation propagation

commit 783aed05f0f38cb2f99e758f81db6838ac55b9f8
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 09:05:27 2024 +0530

    style

commit e77e703ca75d00447cda277eca6b886cd32bddc0
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 09:04:57 2024 +0530

    add comment to explain why I had to touch forbidden codebase.

commit ab9d8849758e7773a31778ccba71588d18552623
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 09:03:02 2024 +0530

    fix: flax attribute access.

commit c570fc0abf9d1bd58c291aae3c7e384f995996d2
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 08:23:54 2024 +0530

    fix tensorflow attribute name.

commit 32c812871cfdb268d8a6e3e2c61c5c925c8ed47e
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 07:57:10 2024 +0530

    fix attribute access.

commit 4f41a0138b6c417aed9c9332278f8bcd979cb7c2
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Sat May 25 07:44:02 2024 +0530

    _from_config.

commit 35aed64ff602422adcf41d7f677a0a24bd9eccae
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 18:46:52 2024 +0530

    propagation of attn_implementation.

commit 4c25c19845438b1dc1d35a5adf9436151c8c5940
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 09:24:36 2024 +0530

    style again

commit 5f7dc5c5015c0f8116408f737e8c318d1802c80c
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 09:19:05 2024 +0530

    use from_config.

commit b70c409956d0359fa6ae5372275d2a20ba7e3389
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 24 09:13:43 2024 +0530

    quality

commit a7b63beff53d0fc754c6564e2a7b51731ddee49d
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 14:35:10 2024 +0200

    add benchmark numbers

commit 455b0eaea50862b8458c8f422b60fe60ae40fdcb
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:50:16 2024 +0200

    Revert "reflect feedback more"

    This reverts commit dc123e71eff60aae74d5f325f113d515d0d71117.

commit ca674829d28787349c2a9593a14e0f1d41f04ea4
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:50:05 2024 +0200

    Revert "fix"

    This reverts commit 37a1cb35b87acdc4cf7528b8b1ed6da27d244e52.

commit fab2dd8576c099eb1a3464958cb206a664d28247
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:47:46 2024 +0200

    fix

commit fbc6ae50fd6f2d36294d31e191761631b701d696
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 10 13:38:30 2024 +0200

    reflect feedback more

commit 87245bb020b2d60a89afe318a951df0159404fc9
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 08:54:34 2024 +0530

    fixes

commit 1057cc26390ee839251e7f8b3326c4207595fb23
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:49:03 2024 +0530

    don't explicit set attn_implementation in tests

commit e33f75916fc8a99f516b1cf449dbbe9d3aabda81
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:43:54 2024 +0530

    explicitly override attn_implementation in the towers.

commit 4cf41cb1bc885c39df7cb8f2a0694ebf23299235
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:38:42 2024 +0530

    import in one-line.

commit f2cc447ae9e74ccfacb448140cdf88259d4afc8c
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri May 3 07:34:58 2024 +0530

    move sdpa mention to usage tips.

commit 92884766c64dbb456926a3a84dd427be1349fa95
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Mon Apr 29 10:58:26 2024 +0530

    fix: memory allocation problem.

commit d7ffbbfe12f7750b7d0a361420f35c13e0ea787d
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Mon Apr 29 09:56:59 2024 +0530

    fix-copies

commit 8dfc3731cedd02e36acd3fe56bb2e6d61efd25d8
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Fri Apr 26 20:16:12 2024 +0530

    address arthur's comments.

commit d2ed7b4ce4ff15ae9aa4d3d0500f1544e3dcd9e9
Author: Sayak Paul <spsayakpaul@gmail.com>
Date:   Fri Apr 26 20:08:15 2024 +0530

    Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

commit 46e04361f37ded5c522ff05e9f725b9f82dce40e
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Wed Apr 24 09:55:27 2024 +0530

    add to docs.

commit 831629158ad40d34d8983f209afb2740ba041af2
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Wed Apr 24 09:33:10 2024 +0530

    styling.g

commit d263a119c77314250f4b4c8469caf42559197f22
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Wed Apr 24 09:15:20 2024 +0530

    up

commit d44f9d3d7633d4c241a737a1bc317f791f6aedb3
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 18:40:42 2024 +0530

    handle causal and attention mask

commit 122f1d60153df6666b634a94e38d073f3f260926
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 15:18:21 2024 +0530

    test fixes.

commit 4382d8cff6fa1dee5dbcf0d06b3e2841231e36f5
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 09:39:25 2024 +0530

    fix: scaling inside sdpa.

commit 0f629989efc48b7315cf19405a81e02955efe7e5
Author: Sayak Paul <spsayakpaul@gmail.com>
Date:   Tue Apr 23 08:14:58 2024 +0530

    Update src/transformers/models/clip/modeling_clip.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

commit 14367316877dc27ea40f767ad1aee38bbc97e4ce
Author: sayakpaul <spsayakpaul@gmail.com>
Date:   Mon Apr 22 16:21:36 2024 +0530

    add: sdpa support to clip.

* Remove fallback for empty attention mask (expensive operation)

* Fix typing in copies

* Add flash attention

* Add flash attention tests

* List CLIP in FA docs

* Fix embeddings attributes and tf

* [run-slow] clip

* Update clip documentation

* Remove commented code, skip compile dynamic for CLIPModel

* Fix doc

* Fix doc 2

* Remove double transpose

* Add torch version check for contiguous()

* Add comment to test mixin

* Fix copies

* Add comment for mask

* Update docs

* [run-slow] clip
parent b31d5950
...@@ -79,6 +79,123 @@ encode the text and prepare the images. The following example shows how to get t
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```
### Combining CLIP and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2.
```bash
pip install -U flash-attn --no-build-isolation
```
Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [flash-attn](https://github.com/Dao-AILab/flash-attention) repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).
<Tip warning={true}>
For small batch sizes, you might notice a slowdown in your model when using flash attention. Refer to the section [Expected speedups with Flash Attention and SDPA](#expected-speedups-with-flash-attention-and-sdpa) below and select an appropriate attention implementation.
</Tip>
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import CLIPProcessor, CLIPModel
>>> device = "cuda"
>>> torch_dtype = torch.float16
>>> model = CLIPModel.from_pretrained(
... "openai/clip-vit-base-patch32",
... attn_implementation="flash_attention_2",
... device_map=device,
... torch_dtype=torch_dtype,
... )
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
>>> inputs.to(device)
>>> with torch.no_grad():
... with torch.autocast(device):
... outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
>>> print(probs)
tensor([[0.9946, 0.0052]], device='cuda:0', dtype=torch.float16)
```
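To double-check which attention implementation was actually loaded, you can inspect the model config. Note that `_attn_implementation` is an internal attribute set by `from_pretrained`, so treat this as a quick sanity check rather than a public API:
```python
>>> # `_attn_implementation` is internal; printed here only to verify the loaded attention backend.
>>> print(model.config._attn_implementation)
flash_attention_2
```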
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16, attn_implementation="sdpa")
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
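The same flag is respected by the standalone text and vision towers as well. Below is a minimal illustrative sketch (checkpoint name as above; loading a single tower from the full CLIP checkpoint will warn about unused weights, which is expected):
```python
import torch
from transformers import CLIPTextModel, CLIPVisionModel

# Both towers accept the same `attn_implementation` argument as `CLIPModel`.
text_model = CLIPTextModel.from_pretrained(
    "openai/clip-vit-base-patch32", torch_dtype=torch.float16, attn_implementation="sdpa"
)
vision_model = CLIPVisionModel.from_pretrained(
    "openai/clip-vit-base-patch32", torch_dtype=torch.float16, attn_implementation="sdpa"
)
```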
### Expected speedups with Flash Attention and SDPA
On a local benchmark (NVIDIA A10G, PyTorch 2.3.1+cu121) with `float16`, we saw the following speedups during inference for the `"openai/clip-vit-large-patch14"` checkpoint ([code](https://gist.github.com/qubvel/ac691a54e54f9fae8144275f866a7ff8)):
#### CLIPTextModel
| Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
| 4 | 0.009 | 0.012 | 0.737 | 0.007 | 1.269 |
| 16 | 0.009 | 0.014 | 0.659 | 0.008 | 1.187 |
| 32 | 0.018 | 0.021 | 0.862 | 0.016 | 1.142 |
| 64 | 0.034 | 0.034 | 1.001 | 0.03 | 1.163 |
| 128 | 0.063 | 0.058 | 1.09 | 0.054 | 1.174 |
![clip_text_model_viz_3](https://github.com/user-attachments/assets/e9826b43-4e66-4f4c-952b-af4d90bd38eb)
#### CLIPVisionModel
| Image batch size | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|-------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
| 1 | 0.016 | 0.013 | 1.247 | 0.012 | 1.318 |
| 4 | 0.025 | 0.021 | 1.198 | 0.021 | 1.202 |
| 16 | 0.093 | 0.075 | 1.234 | 0.075 | 1.24 |
| 32 | 0.181 | 0.147 | 1.237 | 0.146 | 1.241 |
![clip_image_model_viz_3](https://github.com/user-attachments/assets/50a36206-e3b9-4adc-ac8e-926b8b071d63)
#### CLIPModel
| Image batch size | Num text labels | Eager (s/iter) | FA2 (s/iter) | FA2 speedup | SDPA (s/iter) | SDPA speedup |
|-------------------:|------------------:|-----------------:|---------------:|--------------:|----------------:|---------------:|
| 1 | 4 | 0.025 | 0.026 | 0.954 | 0.02 | 1.217 |
| 1 | 16 | 0.026 | 0.028 | 0.918 | 0.02 | 1.287 |
| 1 | 64 | 0.042 | 0.046 | 0.906 | 0.036 | 1.167 |
| 4 | 4 | 0.028 | 0.033 | 0.849 | 0.024 | 1.189 |
| 4 | 16 | 0.034 | 0.035 | 0.955 | 0.029 | 1.169 |
| 4 | 64 | 0.059 | 0.055 | 1.072 | 0.05 | 1.179 |
| 16 | 4 | 0.096 | 0.088 | 1.091 | 0.078 | 1.234 |
| 16 | 16 | 0.102 | 0.09 | 1.129 | 0.083 | 1.224 |
| 16 | 64 | 0.127 | 0.11 | 1.157 | 0.105 | 1.218 |
| 32 | 4 | 0.185 | 0.159 | 1.157 | 0.149 | 1.238 |
| 32 | 16 | 0.19 | 0.162 | 1.177 | 0.154 | 1.233 |
| 32 | 64 | 0.216 | 0.181 | 1.19 | 0.176 | 1.228 |
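The exact measurement code is in the gist linked above. For orientation, the sketch below shows the general shape of such a per-iteration timing loop; the helper name, warmup count, and iteration count are illustrative and not taken from the gist:
```python
import time

import torch


def measure_latency(model, inputs, n_warmup=3, n_iters=20):
    """Rough seconds-per-iteration estimate on GPU; a simplified stand-in for the linked benchmark."""
    with torch.no_grad():
        # Warmup iterations so one-time setup costs do not skew the timing.
        for _ in range(n_warmup):
            model(**inputs)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            model(**inputs)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters
```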
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with CLIP.
...
...@@ -40,6 +40,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
...@@ -200,6 +201,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
...
...@@ -749,7 +749,7 @@ class AltCLIPAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
...@@ -838,7 +838,6 @@ class AltCLIPMLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->AltCLIP
class AltCLIPEncoderLayer(nn.Module):
def __init__(self, config: AltCLIPConfig):
super().__init__()
...@@ -889,7 +888,6 @@ class AltCLIPEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->AltCLIP
class AltCLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
...@@ -1080,7 +1078,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
class AltCLIPVisionTransformer(nn.Module):
def __init__(self, config: AltCLIPVisionConfig):
super().__init__()
...
...@@ -26,17 +26,24 @@ from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
# General docstring
...@@ -254,7 +261,7 @@ class CLIPAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
...@@ -327,6 +334,173 @@ class CLIPAttention(nn.Module):
return attn_output, attn_weights_reshaped
class CLIPFlashAttention2(CLIPAttention):
"""
CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=causal_attention_mask is not None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
class CLIPSdpaAttention(CLIPAttention):
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`CLIPAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from CLIPAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask
bsz, tgt_len, embed_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
CLIP_ATTENTION_CLASSES = {
"eager": CLIPAttention,
"sdpa": CLIPSdpaAttention,
"flash_attention_2": CLIPFlashAttention2,
}
class CLIPMLP(nn.Module):
def __init__(self, config):
super().__init__()
...@@ -346,7 +520,7 @@ class CLIPEncoderLayer(nn.Module):
def __init__(self, config: CLIPConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = CLIPAttention(config)
self.self_attn = CLIP_ATTENTION_CLASSES[config._attn_implementation](config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
...@@ -401,6 +575,8 @@ class CLIPPreTrainedModel(PreTrainedModel):
config_class = CLIPConfig
base_model_prefix = "clip"
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""
...@@ -668,6 +844,9 @@ class CLIPTextTransformer(nn.Module):
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
# For attention mask, it differs between `flash_attention_2` and other attention implementations
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
def forward(
...@@ -702,8 +881,9 @@ class CLIPTextTransformer(nn.Module):
causal_attention_mask = _create_4d_causal_attention_mask(
input_shape, hidden_states.dtype, device=hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
if attention_mask is not None and not self._use_flash_attention_2:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
...@@ -957,8 +1137,11 @@ class CLIPModel(CLIPPreTrainedModel):
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = CLIPTextTransformer(text_config)
self.vision_model = CLIPVisionTransformer(vision_config)
text_model = CLIPTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
self.text_model = text_model.text_model
vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
self.vision_model = vision_model.vision_model
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
...@@ -1173,7 +1356,8 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformer(config)
text_model = CLIPTextModel._from_config(config, attn_implementation=config._attn_implementation)
self.text_model = text_model.text_model
self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
...@@ -1253,7 +1437,8 @@ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
def __init__(self, config: CLIPVisionConfig):
super().__init__(config)
self.vision_model = CLIPVisionTransformer(config)
vision_model = CLIPVisionModel._from_config(config, attn_implementation=config._attn_implementation)
self.vision_model = vision_model.vision_model
self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
...@@ -1332,7 +1517,10 @@ class CLIPForImageClassification(CLIPPreTrainedModel):
super().__init__(config)
self.num_labels = config.num_labels
self.vision_model = CLIPVisionTransformer(config.vision_config)
vision_model = CLIPVisionModel._from_config(
config.vision_config, attn_implementation=config._attn_implementation
)
self.vision_model = vision_model.vision_model
# Classifier head
self.classifier = (
...
...@@ -266,7 +266,7 @@ class CLIPSegAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
...@@ -355,7 +355,7 @@ class CLIPSegMLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->CLIPSeg
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
class CLIPSegEncoderLayer(nn.Module):
def __init__(self, config: CLIPSegConfig):
super().__init__()
...@@ -554,7 +554,7 @@ CLIPSEG_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->CLIPSeg
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
class CLIPSegEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
...@@ -653,7 +653,6 @@ class CLIPSegEncoder(nn.Module):
class CLIPSegTextTransformer(nn.Module):
# Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.__init__ with CLIP->CLIPSeg
def __init__(self, config: CLIPSegTextConfig):
super().__init__()
self.config = config
...@@ -667,7 +666,7 @@ class CLIPSegTextTransformer(nn.Module):
@add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
# Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
# Adapted from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
...@@ -806,7 +805,7 @@ class CLIPSegTextModel(CLIPSegPreTrainedModel):
class CLIPSegVisionTransformer(nn.Module):
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIP->CLIPSeg
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIP->CLIPSeg
def __init__(self, config: CLIPSegVisionConfig):
super().__init__()
self.config = config
...@@ -1149,7 +1148,7 @@ class CLIPSegDecoderLayer(nn.Module):
self-attention/MLP, rather than before.
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer.__init__ with CLIP->CLIPSeg
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer.__init__ with AltCLIP->CLIPSeg
def __init__(self, config: CLIPSegConfig):
super().__init__()
self.embed_dim = config.hidden_size
...
...@@ -632,7 +632,7 @@ class GitVisionMLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPAttention
# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->GitVision
class GitVisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
...@@ -664,7 +664,7 @@ class GitVisionAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
...@@ -737,7 +737,7 @@ class GitVisionAttention(nn.Module):
return attn_output, attn_weights_reshaped
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GitVision
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
class GitVisionEncoderLayer(nn.Module):
def __init__(self, config: GitVisionConfig):
super().__init__()
...@@ -788,7 +788,7 @@ class GitVisionEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->GitVision, CLIPConfig
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->GitVision, CLIPConfig
class GitVisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
...@@ -903,7 +903,7 @@ GIT_VISION_INPUTS_DOCSTRING = r"""
class GitVisionTransformer(nn.Module):
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIPEncoder->GitVisionEncoder, CLIP->Git
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPEncoder->GitVisionEncoder, AltCLIP->Git
def __init__(self, config: GitVisionConfig):
super().__init__()
self.config = config
...
...@@ -688,7 +688,7 @@ class GroupViTAttention(nn.Module):
return attn_output, attn_weights_reshaped
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GroupViT
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
class GroupViTEncoderLayer(nn.Module):
def __init__(self, config: GroupViTConfig):
super().__init__()
...@@ -1034,7 +1034,6 @@ class GroupViTTextEncoder(nn.Module):
)
# Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder, CLIP_TEXT->GROUPVIT_TEXT
class GroupViTTextTransformer(nn.Module):
def __init__(self, config: GroupViTTextConfig):
super().__init__()
...@@ -1081,6 +1080,7 @@ class GroupViTTextTransformer(nn.Module):
causal_attention_mask = _create_4d_causal_attention_mask(
input_shape, hidden_states.dtype, device=hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
...
...@@ -192,7 +192,7 @@ class IdeficsVisionAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
...@@ -281,7 +281,7 @@ class IdeficsVisionMLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision
class IdeficsVisionEncoderLayer(nn.Module):
def __init__(self, config: IdeficsVisionConfig):
super().__init__()
...@@ -332,7 +332,7 @@ class IdeficsVisionEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->IdeficsVision
class IdeficsVisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
...
...@@ -444,7 +444,7 @@ class Kosmos2VisionAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
...@@ -533,7 +533,7 @@ class Kosmos2VisionMLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Kosmos2Vision
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision
class Kosmos2VisionEncoderLayer(nn.Module):
def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
...@@ -584,7 +584,7 @@ class Kosmos2VisionEncoderLayer(nn.Module):
return outputs
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Kosmos2Vision
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Kosmos2Vision
class Kosmos2VisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
...@@ -684,7 +684,7 @@ class Kosmos2VisionEncoder(nn.Module):
# Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward`
class Kosmos2VisionTransformer(nn.Module):
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIPVision->Kosmos2Vision,CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2Vision
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPVision->Kosmos2Vision,ALTCLIP_VISION->KOSMOS2_VISION,AltCLIP->Kosmos2Vision
def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
self.config = config
...
...@@ -459,7 +459,7 @@ class Owlv2MLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Owlv2
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Owlv2
class Owlv2EncoderLayer(nn.Module):
def __init__(self, config: Owlv2Config):
super().__init__()
...
...@@ -451,7 +451,7 @@ class OwlViTMLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->OwlViT
class OwlViTEncoderLayer(nn.Module):
def __init__(self, config: OwlViTConfig):
super().__init__()
...
...@@ -829,7 +829,7 @@ SIGLIP_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
...
...@@ -199,7 +199,7 @@ class XCLIPAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
...@@ -288,7 +288,7 @@ class XCLIPMLP(nn.Module):
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->XCLIP
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->XCLIP
class XCLIPEncoderLayer(nn.Module):
def __init__(self, config: XCLIPConfig):
super().__init__()
...@@ -609,7 +609,7 @@ X_CLIP_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->XCLIP
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->XCLIP
class XCLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
...
...@@ -18,21 +18,33 @@ import inspect
import os
import tempfile
import unittest
from typing import Optional, Tuple
import numpy as np
import requests
from parameterized import parameterized
from pytest import mark
import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.testing_utils import (
is_flax_available,
is_pt_flax_cross_test,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import (
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_torch_sdpa_available,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
...@@ -40,6 +52,7 @@ from ...test_modeling_common import (
_config_zero_init,
floats_tensor,
ids_tensor,
is_flaky,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
...@@ -59,6 +72,10 @@ if is_torch_available():
)
if is_torch_sdpa_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
if is_vision_available():
from PIL import Image
...@@ -167,8 +184,180 @@ class CLIPVisionModelTester:
return config, inputs_dict
class CLIPModelTesterMixin(ModelTesterMixin):
"""
Subclass of ModelTesterMixin with methods specific to testing CLIP models.
The SDPA equivalence test is overridden here because CLIP models may have text/vision/text+vision inputs,
different output logits, and are not supposed to be used or tested with padding_side="left".
"""
def test_eager_matches_sdpa_inference(
self,
torch_dtype: str,
use_attention_mask_options: Tuple[Optional[str], ...] = (None, "left", "right"),
logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Convert to torch dtype
dtypes = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
torch_dtype = dtypes[torch_dtype]
atols = {
torch.float32: 1e-5,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
rtols = {
torch.float32: 1e-4,
torch.bfloat16: 3e-2,
torch.float16: 5e-3,
}
atol = atols[torch_dtype]
rtol = rtols[torch_dtype]
def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
# but it would be nicer to have an efficient way to use parameterized.expand
cases = [
(use_mask, output_attentions, sdpa_backend, batch_size)
for use_mask in use_attention_mask_options
for output_attentions in [True, False]
for sdpa_backend in [
[SDPBackend.MATH],
[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
]
for batch_size in [1, 5]
]
fail_cases = []
for use_mask, output_attentions, sdpa_backend, batch_size in cases:
processed_inputs = inputs_dict.copy()
# convert to torch_dtype
if "pixel_values" in processed_inputs:
processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)
# slice for different batch sizes
for key in ["pixel_values", "input_ids", "attention_mask"]:
if key in processed_inputs:
processed_inputs[key] = processed_inputs[key][:batch_size]
# set attention mask with left padding
if not use_mask:
processed_inputs.pop("attention_mask", None)
elif use_mask == "left":
dummy_attention_mask = processed_inputs["attention_mask"]
dummy_attention_mask[:] = 1
dummy_attention_mask[:, :1] = 0
processed_inputs["attention_mask"] = dummy_attention_mask
elif use_mask == "right":
dummy_attention_mask = processed_inputs["attention_mask"]
dummy_attention_mask[:] = 1
dummy_attention_mask[:, -1:] = 0
processed_inputs["attention_mask"] = dummy_attention_mask
else:
raise ValueError(f"Invalid value for use_mask={use_mask}")
processed_inputs["output_attentions"] = output_attentions
processed_inputs["output_hidden_states"] = True
current_case = f"use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
with torch.no_grad():
try:
with sdpa_kernel(sdpa_backend):
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
except Exception as e:
fail_cases.append(f"{current_case}: {e}")
continue
keys = set(logit_keys) & set(outputs_eager.keys())
self.assertTrue(
keys, f"None of the keys {logit_keys} were found in the outputs. Available keys: {outputs_eager.keys()}"
)
for key in keys:
try:
eager_logits = outputs_eager[key]
sdpa_logits = outputs_sdpa[key]
except KeyError:
raise KeyError(f"Key {key} not found in outputs. Available keys: {outputs_eager.keys()}")
if "hidden_state" in key and use_mask == "left":
eager_logits = eager_logits[:, 1:]
sdpa_logits = sdpa_logits[:, 1:]
elif "hidden_state" in key and use_mask == "right":
eager_logits = eager_logits[:, :-1]
sdpa_logits = sdpa_logits[:, :-1]
is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
if not is_close:
fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
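# --- Not part of the original diff: a minimal, hypothetical usage sketch of what this PR enables. ---
# The checkpoint name is only illustrative; `attn_implementation` accepts "eager", "sdpa"
# (picked automatically on recent PyTorch when supported), or "flash_attention_2" (which needs a
# CUDA GPU, fp16/bf16 weights, and the flash-attn package). torch and CLIPModel are already
# imported at the top of this test module.
def _example_load_clip_with_fast_attention():
    model_sdpa = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", attn_implementation="sdpa")
    model_fa2 = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
    )
    return model_sdpa, model_fa2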
@require_torch
class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as CLIP does not use input_ids, inputs_embeds,
attention_mask and seq_length.
...@@ -261,6 +450,17 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "visual_projection"))
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("last_hidden_state", "pooler_output", "image_embeds"),
use_attention_mask_options=(None,),
)
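# The CLIP vision tower takes no attention_mask (see the class docstring above), so only the
# `None` mask option is exercised for this model.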
class CLIPTextModelTester:
def __init__(
...@@ -361,7 +561,7 @@ class CLIPTextModelTester:
@require_torch
class CLIPTextModelTest(CLIPModelTesterMixin, unittest.TestCase):
all_model_classes = (CLIPTextModel, CLIPTextModelWithProjection) if is_torch_available() else ()
fx_compatible = True
test_pruning = False
...@@ -428,6 +628,21 @@ class CLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
self.assertTrue(hasattr(model, "text_projection"))
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("last_hidden_state", "pooler_output", "text_embeds"),
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
)
@require_torch_sdpa
def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="CLIPTextModel has two attention masks: `causal_attention_mask` and `attention_mask`")
class CLIPModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
...@@ -479,7 +694,7 @@ class CLIPModelTester:
@require_torch
class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
...@@ -746,6 +961,115 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = CLIPModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("logits_per_image", "logits_per_text"),
use_attention_mask_options=(None, "right"), # "left" is not supported for text model
)
@require_torch_sdpa
def test_sdpa_can_dispatch_on_flash(self):
self.skipTest(reason="CLIP text tower has two attention masks: `causal_attention_mask` and `attention_mask`")
@require_torch_sdpa
def test_sdpa_can_compile_dynamic(self):
self.skipTest(reason="CLIP model can't be compiled with dynamic shapes; `clip_loss` raises an error")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
dummy_input_ids = inputs_dict["input_ids"]
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
)
self.assertTrue(
torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
)
self.assertTrue(
torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
)
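# Flash Attention 2 only runs in fp16/bf16, hence torch.bfloat16 above; the logits are
# therefore compared with correspondingly loose 4e-2 tolerances.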
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
model.to(torch_device)
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
dummy_input_ids = inputs_dict["input_ids"]
dummy_pixel_mask = inputs_dict["attention_mask"]
# right padding
dummy_pixel_mask[:] = 1
dummy_pixel_mask[:, -1:] = 0
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
outputs_fa = model_fa(
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
)
logits_per_image_eager = outputs.logits_per_image[:, :-1]
logits_per_text_eager = outputs.logits_per_text[:, :-1]
logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1]
logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1]
self.assertTrue(
torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2),
f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}",
)
self.assertTrue(
torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
)
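# Same bfloat16 setup and loose 4e-2 tolerances as the test above; the [:, :-1] slices
# presumably exclude the entries most affected by the masked (right-padded) final token.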
class CLIPForImageClassificationModelTester(CLIPModelTester):
def __init__(self, parent):
...@@ -769,7 +1093,7 @@ class CLIPForImageClassificationModelTester(CLIPModelTester):
@require_torch
class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CLIPForImageClassification,) if is_torch_available() else ()
pipeline_model_mapping = {"image-classification": CLIPForImageClassification} if is_torch_available() else {}
fx_compatible = False
...@@ -805,6 +1129,17 @@ class CLIPForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin,
def test_initialization(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@is_flaky()
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
super().test_eager_matches_sdpa_inference(
torch_dtype=torch_dtype,
logit_keys=("logits",),
use_attention_mask_options=(None,),
)
# We will verify our results on an image of cute cats
def prepare_img():
...