Unverified Commit 1c21f48a authored by hyenal, committed by GitHub

add sdpa to ViT [follow up of #29325] (#30555)



remove blank line (+1 squashed commit)
Squashed commits:
[24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits)
Squashed commits:
[08bd27e7a] [run-slow]vit_msn,vision_encoder_decoder
[ec96a8db3] [run-slow]vit_msn
[ead817eca] fix vit msn multi gpu
[d12cdc8fd] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[3fdbfa88f] doc
[a3ff33e4a] finish implementation
[e20b7b7fb] Update test_modeling_common.py
[e290c5810] Update test_modeling_flax_common.py
[d3af86f46] comment
[ff7dd32d8] more comments
[59b137889] suggestion
[7e2ba6d67] attn_implementation as attribute of the class
[fe66ab71f] minor
[38642b568] Apply suggestions from code review

Accept comments
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[22cde7d52] Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[48e137cc6] Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[99f4c679f] Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[96cf20a6d] Update src/transformers/models/vit_msn/modeling_vit_msn.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[c59377d23] Update src/transformers/models/vit_mae/modeling_vit_mae.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[b70a47259] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[00c84d216] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[61f00ebb0] all tests are passing locally
[e9e0b82b7] vision encoder/decoder
[4d5076b56] test-vision (+20 squashed commits)
Squashed commits:
[d1add8db9] yolo
[9fde65716] fix flax
[986566c28] minor
[ca2f21d1f] vit
[3333efd7a] easy models change
[ebfc21402] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[b8b8603ed] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[48ecc7e26] all tests are passing locally
[bff7fc366] minor
[62f88306f] fix yolo and text_encoder tests
[121507555] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1064cae0a] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[b7f52ff3a] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[cffaa10dd] fix-copies
[ef6c511c4] test vit hybrid
[7d4ba8644] vit hybrid
[66f919033] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1fcc0a031] fixes
[cfde6eb21] fixup
[e77df1ed3] all except yolo end encoder decoder (+17 squashed commits)
Squashed commits:
[602913e22] vit + vit_mae are working
[547f6c4cc] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/  passes
[61a97dfa9] it s the complete opposite...
[aefab37d4] fix more tests
[71802a1b9] fix all torch tests
[40b12eb58] encoder - decoder tests
[941552b69] slow decorator where appropriate
[14d055d80] has_attentions to yolo and msn
[3381fa19f] add correct name
[e261316a7] repo consistency
[31c6d0c08] fixup
[9d214276c] minor fix
[11ed2e1b7] chore
[eca6644c4] add sdpa to vit-based models
[cffbf390b] make fix-copies result
[6468319b0] fix style
[d324cd02a] add sdpa for vit
Co-authored-by: Liubov Yaronskaya <luba.yaronskaya@gmail.com>
parent 9fd606db
...@@ -43,6 +43,34 @@ the authors compute the stats for a downstream dataset.
- Note that the AST needs a low learning rate (the authors use a 10 times smaller learning rate compared to their CNN model proposed in the
[PSLA paper](https://arxiv.org/abs/2102.01243)) and converges quickly, so please search for a suitable learning rate and learning rate scheduler for your task.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import ASTForAudioClassification

model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MIT/ast-finetuned-audioset-10-10-0.4593` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 27 | 6 | 4.5 |
| 2 | 12 | 6 | 2 |
| 4 | 21 | 8 | 2.62 |
| 8 | 40 | 14 | 2.86 |
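The numbers above depend heavily on hardware, input length, and batch composition. As a rough sketch of how such an eager-vs-SDPA comparison could be reproduced (the input shape, iteration counts, and helper below are illustrative assumptions, not the script used for the table):

```python
# Minimal benchmark sketch: compare eager vs. SDPA attention for AST on random inputs.
import torch
from transformers import ASTForAudioClassification


def time_forward(attn_implementation: str, steps: int = 50) -> float:
    model = ASTForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        attn_implementation=attn_implementation,
    ).to("cuda").eval()
    # AST consumes log-mel spectrograms of shape (batch, time_frames, mel_bins);
    # 1024 x 128 matches the default feature extractor output.
    inputs = torch.randn(1, 1024, 128, device="cuda")
    with torch.no_grad():
        for _ in range(5):  # warmup
            model(inputs)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(steps):
            model(inputs)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / steps  # average ms per forward pass


print("eager:", time_forward("eager"), "ms")
print("sdpa :", time_forward("sdpa"), "ms")
```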
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with the Audio Spectrogram Transformer.
......
...@@ -68,6 +68,34 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The Tenso
*facebook/deit-base-patch16-384*. Note that one should use [`DeiTImageProcessor`] in order to
prepare images for the model.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import DeiTForImageClassification

model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/deit-base-distilled-patch16-224` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 8 | 6 | 1.33 |
| 2 | 9 | 6 | 1.5 |
| 4 | 9 | 6 | 1.5 |
| 8 | 8 | 6 | 1.33 |
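If you want to confirm which attention implementation a loaded model actually ended up with (for example when the installed torch version cannot provide SDPA), the config records it. Note that `_attn_implementation` is an internal attribute, so treat the sketch below as a debugging aid rather than a stable API:

```python
import torch
from transformers import DeiTForImageClassification

model = DeiTForImageClassification.from_pretrained(
    "facebook/deit-base-distilled-patch16-224",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)
# Internal attribute set during from_pretrained(); "sdpa" here means the
# DeiTSdpaAttention path added in this commit is in use.
print(model.config._attn_implementation)  # -> "sdpa"
```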
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DeiT.
......
...@@ -33,6 +33,34 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/MCG-NJU/VideoMAE).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import VideoMAEForVideoClassification

model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MCG-NJU/videomae-base-finetuned-kinetics` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 37 | 10 | 3.7 |
| 2 | 24 | 18 | 1.33 |
| 4 | 43 | 32 | 1.34 |
| 8 | 84 | 60 | 1.4 |
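VideoMAE operates on clips rather than single images, so `pixel_values` carries an extra frame dimension. A minimal sketch with a random clip, assuming the default 16-frame, 224x224 configuration of this checkpoint:

```python
import torch
from transformers import VideoMAEForVideoClassification

model = VideoMAEForVideoClassification.from_pretrained(
    "MCG-NJU/videomae-base-finetuned-kinetics",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda")

# pixel_values: (batch, num_frames, channels, height, width); 16 frames of
# 224x224 is the default for this checkpoint.
pixel_values = torch.randn(1, 16, 3, 224, 224, dtype=torch.float16, device="cuda")
with torch.no_grad():
    logits = model(pixel_values).logits
print(logits.shape)  # (1, num_labels); 400 Kinetics classes for this checkpoint
```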
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with VideoMAE. If
......
...@@ -88,6 +88,34 @@ who already converted the weights from JAX to PyTorch. Credits go to him!
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
improvement of 2% to training from scratch, but still 4% behind supervised pre-training.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
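Putting the pieces together, a typical half-precision SDPA inference pass over a real image could look like the sketch below (the image path is a placeholder):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda")

image = Image.open("cat.jpg")  # placeholder path
inputs = processor(images=image, return_tensors="pt")
inputs = {k: v.to("cuda", torch.float16) for k, v in inputs.items()}

with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```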
## Resources
Demo notebooks regarding inference as well as fine-tuning ViT on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer).
......
...@@ -39,6 +39,34 @@ substantially fewer computational resources to train.*
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
found [here](https://github.com/google-research/vision_transformer).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import ViTHybridForImageClassification

model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-hybrid-base-bit-384` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 29 | 18 | 1.61 |
| 2 | 26 | 18 | 1.44 |
| 4 | 25 | 18 | 1.39 |
| 8 | 34 | 24 | 1.42 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT Hybrid.
......
...@@ -52,6 +52,34 @@ consists of Transformer blocks) takes as input. Each mask token is a shared, lea
sin/cos position embeddings are added both to the input of the encoder and the decoder.
- For a visual understanding of how MAEs work you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/).
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import ViTMAEModel

model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-mae-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 11 | 6 | 1.83 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
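The same flag applies to the pretraining head. A minimal sketch with a random input and `ViTMAEForPreTraining` (default ~75% masking ratio); the input and dtype choices here are only illustrative:

```python
import torch
from transformers import ViTMAEForPreTraining

model = ViTMAEForPreTraining.from_pretrained(
    "facebook/vit-mae-base",
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
)

pixel_values = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16)
with torch.no_grad():
    outputs = model(pixel_values)

# `mask` marks the ~75% of patches that were hidden from the encoder;
# `loss` is the pixel reconstruction loss on those patches.
print(outputs.loss, outputs.mask.float().mean())
```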
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViTMAE.
......
...@@ -49,6 +49,34 @@ use the [`ViTMSNForImageClassification`] class which is initialized from [`ViTMS
- MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K
labels when fine-tuned.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import ViTMSNForImageClassification

model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-msn-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT MSN.
......
...@@ -32,6 +32,34 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/hustvl/YOLOS).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import AutoModelForObjectDetection

model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `hustvl/yolos-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 106 | 76 | 1.39 |
| 2 | 154 | 90 | 1.71 |
| 4 | 222 | 116 | 1.91 |
| 8 | 368 | 168 | 2.19 |
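For detection, the raw outputs still need to be mapped back to boxes. A sketch using the checkpoint's image processor (the confidence threshold and image path are illustrative):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

processor = AutoImageProcessor.from_pretrained("hustvl/yolos-base")
model = AutoModelForObjectDetection.from_pretrained(
    "hustvl/yolos-base", attn_implementation="sdpa", torch_dtype=torch.float16
).to("cuda")

image = Image.open("street.jpg")  # placeholder path
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to("cuda", torch.float16)
with torch.no_grad():
    outputs = model(pixel_values)

# Map class logits + normalized boxes back to absolute (x0, y0, x1, y1) coordinates.
results = processor.post_process_object_detection(
    outputs, threshold=0.9, target_sizes=[image.size[::-1]]
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())
```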
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with YOLOS.
......
...@@ -192,10 +192,12 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
For now, Transformers supports SDPA inference and training for the following architectures:
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
...@@ -216,12 +218,18 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
<Tip>
......
...@@ -169,6 +169,38 @@ class ASTSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
class ASTSdpaSelfAttention(ASTSelfAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
class ASTSelfOutput(nn.Module):
"""
...@@ -228,6 +260,13 @@ class ASTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
class ASTSdpaAttention(ASTAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention = ASTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
class ASTIntermediate(nn.Module):
def __init__(self, config: ASTConfig) -> None:
...@@ -261,7 +300,13 @@ class ASTOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
AST_ATTENTION_CLASSES = {
"eager": ASTAttention,
"sdpa": ASTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
class ASTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
...@@ -269,7 +314,7 @@ class ASTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ASTAttention(config)
self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ASTIntermediate(config)
self.output = ASTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -366,6 +411,7 @@ class ASTPreTrainedModel(PreTrainedModel):
base_model_prefix = "audio_spectrogram_transformer"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
# Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
......
...@@ -190,6 +190,38 @@ class DeiTSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT
class DeiTSdpaSelfAttention(DeiTSelfAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
class DeiTSelfOutput(nn.Module):
"""
...@@ -249,6 +281,13 @@ class DeiTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT
class DeiTSdpaAttention(DeiTAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention = DeiTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
class DeiTIntermediate(nn.Module):
def __init__(self, config: DeiTConfig) -> None:
...@@ -282,7 +321,13 @@ class DeiTOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT
DEIT_ATTENTION_CLASSES = {
"eager": DeiTAttention,
"sdpa": DeiTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
class DeiTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
...@@ -290,7 +335,7 @@ class DeiTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = DeiTAttention(config)
self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = DeiTIntermediate(config)
self.output = DeiTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -388,6 +433,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["DeiTLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
......
...@@ -134,7 +134,6 @@ class VideoMAEEmbeddings(nn.Module):
# add position embeddings
embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
# only keep visible patches
# ~bool_masked_pos means visible
if bool_masked_pos is not None:
...@@ -268,6 +267,40 @@ class VideoMAESelfAttention(nn.Module):
return outputs
class VideoMAESdpaSelfAttention(VideoMAESelfAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
key_layer = self.transpose_for_scores(keys)
value_layer = self.transpose_for_scores(values)
query_layer = self.transpose_for_scores(queries)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
class VideoMAESelfOutput(nn.Module):
"""
...@@ -327,6 +360,13 @@ class VideoMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE
class VideoMAESdpaAttention(VideoMAEAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention = VideoMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
class VideoMAEIntermediate(nn.Module):
def __init__(self, config: VideoMAEConfig) -> None:
...@@ -360,7 +400,10 @@ class VideoMAEOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE
VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
class VideoMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
...@@ -368,7 +411,7 @@ class VideoMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VideoMAEAttention(config)
self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = VideoMAEIntermediate(config)
self.output = VideoMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -465,6 +508,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
base_model_prefix = "videomae"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
......
...@@ -336,8 +336,20 @@ class VisionEncoderDecoderModel(PreTrainedModel):
del tf_model
gc.collect()
attn_implementation = kwargs.get("attn_implementation", None)
kwargs_encoder_decoder = {}
if attn_implementation:
kwargs_encoder_decoder = {
"encoder_attn_implementation": attn_implementation,
"decoder_attn_implementation": attn_implementation,
}
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
encoder_dir,
decoder_dir,
encoder_from_tf=True,
decoder_from_tf=True,
**kwargs_encoder_decoder,
)
# This is only for copying some specific attributes of this particular model.
model.config = config
......
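The conversion helper above forwards a single `attn_implementation` value to both sub-models through `encoder_`/`decoder_`-prefixed kwargs. The same prefix routing applies when composing a model yourself with `from_encoder_decoder_pretrained`; a sketch under that assumption (the checkpoint names are only examples):

```python
from transformers import VisionEncoderDecoderModel

# The encoder_ prefix routes the flag to the ViT encoder's from_pretrained call,
# so the encoder uses the new SDPA attention path while the decoder keeps its default.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",
    "gpt2",
    encoder_attn_implementation="sdpa",
)
```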
...@@ -236,6 +236,37 @@ class ViTSelfAttention(nn.Module):
return outputs
class ViTSdpaSelfAttention(ViTSelfAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
...@@ -293,6 +324,12 @@ class ViTAttention(nn.Module):
return outputs
class ViTSdpaAttention(ViTAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention = ViTSdpaSelfAttention(config)
class ViTIntermediate(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
...@@ -324,6 +361,12 @@ class ViTOutput(nn.Module):
return hidden_states
VIT_ATTENTION_CLASSES = {
"eager": ViTAttention,
"sdpa": ViTSdpaAttention,
}
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
...@@ -331,7 +374,7 @@ class ViTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTAttention(config)
self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -428,6 +471,7 @@ class ViTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
......
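Because `ViTSdpaSelfAttention` reuses the projections of `ViTSelfAttention` and only swaps the attention kernel, the eager and SDPA paths should agree up to small numerical differences. A quick sanity check along those lines (the tolerance is illustrative):

```python
import torch
from transformers import ViTModel

eager = ViTModel.from_pretrained("google/vit-base-patch16-224", attn_implementation="eager").eval()
sdpa = ViTModel.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa").eval()

pixel_values = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out_eager = eager(pixel_values).last_hidden_state
    out_sdpa = sdpa(pixel_values).last_hidden_state

# Small kernel-level differences are expected; exact bitwise equality is not.
print(torch.allclose(out_eager, out_sdpa, atol=1e-4))
```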
...@@ -248,6 +248,38 @@ class ViTHybridSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTHybrid
class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTHybrid
class ViTHybridSelfOutput(nn.Module):
"""
...@@ -307,6 +339,13 @@ class ViTHybridAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTHybrid
class ViTHybridSdpaAttention(ViTHybridAttention):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__(config)
self.attention = ViTHybridSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTHybrid
class ViTHybridIntermediate(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
...@@ -340,6 +379,12 @@ class ViTHybridOutput(nn.Module):
return hidden_states
VIT_HYBRID_ATTENTION_CLASSES = {
"eager": ViTHybridAttention,
"sdpa": ViTHybridSdpaAttention,
}
class ViTHybridLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
...@@ -347,7 +392,7 @@ class ViTHybridLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTHybridAttention(config)
self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTHybridIntermediate(config)
self.output = ViTHybridOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -447,6 +492,7 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
......
...@@ -241,8 +241,8 @@ class ViTMAEEmbeddings(nn.Module):
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
...@@ -370,6 +370,38 @@ class ViTMAESelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE
class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
"""
...@@ -429,6 +461,13 @@ class ViTMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE
class ViTMAESdpaAttention(ViTMAEAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention = ViTMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
...@@ -462,7 +501,13 @@ class ViTMAEOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE
VITMAE_ATTENTION_CLASSES = {
"eager": ViTMAEAttention,
"sdpa": ViTMAESdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
...@@ -470,7 +515,7 @@ class ViTMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTMAEAttention(config)
self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTMAEIntermediate(config)
self.output = ViTMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
...@@ -567,6 +612,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
...@@ -764,7 +810,8 @@ class ViTMAEDecoder(nn.Module):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# unshuffle
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
......
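The new ViTMAE SDPA path is opt-in via `attn_implementation` at load time. A minimal sketch, assuming a recent transformers release and a PyTorch version that ships `scaled_dot_product_attention`; the checkpoint name and dummy input are illustrative, not part of the diff:

```python
import torch
from transformers import ViTMAEModel

# "sdpa" makes ViTMAELayer pick ViTMAESdpaAttention from VITMAE_ATTENTION_CLASSES
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation="sdpa")
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch; real inputs come from the image processor
with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)
```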
@@ -222,6 +222,38 @@ class ViTMSNSelfAttention(nn.Module):
        return outputs


+# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN
+class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention):
+    def __init__(self, config: ViTMSNConfig) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
class ViTMSNSelfOutput(nn.Module):
    """
@@ -281,6 +313,13 @@ class ViTMSNAttention(nn.Module):
        return outputs


+# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN
+class ViTMSNSdpaAttention(ViTMSNAttention):
+    def __init__(self, config: ViTMSNConfig) -> None:
+        super().__init__(config)
+        self.attention = ViTMSNSdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
class ViTMSNIntermediate(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
@@ -314,7 +353,10 @@ class ViTMSNOutput(nn.Module):
        return hidden_states


-# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN
+VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention}
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN
class ViTMSNLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""
@@ -322,7 +364,7 @@ class ViTMSNLayer(nn.Module):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = ViTMSNAttention(config)
+        self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = ViTMSNIntermediate(config)
        self.output = ViTMSNOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -419,7 +461,8 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
-    _no_split_modules = ["ViTMSNAttention"]
+    _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
+    _supports_sdpa = True

    # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
    # when creating pre-training scripts.
...
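The `ViTMSNSdpaSelfAttention.forward` above delegates the eager softmax(QK^T / sqrt(d))·V computation to PyTorch's fused kernel without changing the result. A minimal, self-contained sketch of that equivalence (the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Fused SDPA vs. the eager attention it replaces; shapes are illustrative.
batch, heads, seq, head_dim = 2, 4, 16, 8
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

scores = (q @ k.transpose(-2, -1)) / head_dim**0.5
manual = scores.softmax(dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-6))  # expected: True
```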
@@ -307,6 +307,38 @@ class YolosSelfAttention(nn.Module):
        return outputs


+# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos
+class YolosSdpaSelfAttention(YolosSelfAttention):
+    def __init__(self, config: YolosConfig) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
class YolosSelfOutput(nn.Module):
    """
@@ -366,6 +398,13 @@ class YolosAttention(nn.Module):
        return outputs


+# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos
+class YolosSdpaAttention(YolosAttention):
+    def __init__(self, config: YolosConfig) -> None:
+        super().__init__(config)
+        self.attention = YolosSdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
class YolosIntermediate(nn.Module):
    def __init__(self, config: YolosConfig) -> None:
@@ -399,7 +438,10 @@ class YolosOutput(nn.Module):
        return hidden_states


-# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos
+YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention}
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS
class YolosLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""
@@ -407,7 +449,7 @@ class YolosLayer(nn.Module):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = YolosAttention(config)
+        self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = YolosIntermediate(config)
        self.output = YolosOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -531,6 +573,7 @@ class YolosPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = []
+    _supports_sdpa = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
...
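With `_supports_sdpa` set, the eager and SDPA paths should agree numerically in inference mode, which is what the slow equivalence tests exercise. A small parity check, assuming a recent transformers version; the tiny config values are illustrative, not taken from the diff:

```python
import torch
from transformers import YolosConfig, YolosModel

# Two random-weight models that differ only in the attention implementation.
common = dict(hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
              intermediate_size=64, image_size=[32, 32], patch_size=8)
eager = YolosModel(YolosConfig(**common, attn_implementation="eager")).eval()
sdpa = YolosModel(YolosConfig(**common, attn_implementation="sdpa")).eval()
sdpa.load_state_dict(eager.state_dict())  # same parameter names in both attention classes

pixel_values = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    delta = (eager(pixel_values).last_hidden_state - sdpa(pixel_values).last_hidden_state).abs().max()
print(f"max |eager - sdpa| = {delta:.2e}")
```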
@@ -63,6 +63,7 @@ class ASTModelTester:
        scope=None,
        frequency_stride=2,
        time_stride=2,
+        attn_implementation="eager",
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -83,6 +84,7 @@ class ASTModelTester:
        self.scope = scope
        self.frequency_stride = frequency_stride
        self.time_stride = time_stride
+        self.attn_implementation = attn_implementation

        # in AST, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
        frequency_out_dimension = (self.num_mel_bins - self.patch_size) // self.frequency_stride + 1
@@ -117,6 +119,7 @@ class ASTModelTester:
            initializer_range=self.initializer_range,
            frequency_stride=self.frequency_stride,
            time_stride=self.time_stride,
+            attn_implementation=self.attn_implementation,
        )

    def create_and_check_model(self, config, input_values, labels):
...
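The tester's new `attn_implementation` argument is simply forwarded into `ASTConfig`, and the modeling code dispatches on `config._attn_implementation`. A quick sketch of that flow, assuming a recent transformers version and a PyTorch build with SDPA support; the tiny sizes and the expected class name are illustrative assumptions, not quoted from the diff:

```python
from transformers import ASTConfig, ASTModel

# attn_implementation ends up on the config; the layer reads config._attn_implementation
# to pick the attention class. Tiny sizes keep instantiation cheap.
config = ASTConfig(
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    attn_implementation="sdpa",
)
model = ASTModel(config)
print(type(model.encoder.layer[0].attention).__name__)  # expected: the AST SDPA attention subclass
```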
@@ -80,6 +80,8 @@ class DeiTModelTester:
        num_labels=3,
        scope=None,
        encoder_stride=2,
+        mask_ratio=0.5,
+        attn_implementation="eager",
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -99,10 +101,14 @@ class DeiTModelTester:
        self.initializer_range = initializer_range
        self.scope = scope
        self.encoder_stride = encoder_stride
+        self.attn_implementation = attn_implementation

        # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 2
+        self.mask_ratio = mask_ratio
+        self.num_masks = int(mask_ratio * self.seq_length)
+        self.mask_length = num_patches

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -130,6 +136,7 @@ class DeiTModelTester:
            is_decoder=False,
            initializer_range=self.initializer_range,
            encoder_stride=self.encoder_stride,
+            attn_implementation=self.attn_implementation,
        )

    def create_and_check_model(self, config, pixel_values, labels):
...
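The new `mask_ratio` / `num_masks` / `mask_length` fields describe how many patch positions get masked in the masked-image-modeling checks. A minimal sketch of a `bool_masked_pos` built from those formulas; the batch size, image size, and patch size below are illustrative, not the tester's defaults:

```python
import torch

# Mirror the tester's formulas with illustrative sizes.
batch_size, image_size, patch_size, mask_ratio = 2, 224, 16, 0.5
num_patches = (image_size // patch_size) ** 2      # mask_length
seq_length = num_patches + 2                       # + [CLS] and distillation tokens
num_masks = int(mask_ratio * seq_length)

bool_masked_pos = torch.zeros(batch_size, num_patches, dtype=torch.bool)
bool_masked_pos[:, :num_masks] = True              # mask the first num_masks patch positions
print(bool_masked_pos.shape, int(bool_masked_pos[0].sum()))
```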