Unverified Commit 3042c63a authored by fxmarty, committed by GitHub

Add methods to PreTrainedModel to use PyTorch's BetterTransformer (#21259)



* fix mess

* better documentation

* typo

* fix doc

* update

* add test

* fix test

* more tests

* Update src/transformers/modeling_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* move to utils

* Apply suggestions from code review
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* nit

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
parent 0083b149
......@@ -51,6 +51,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/acc
# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install --no-cache-dir bitsandbytes
# For bettertransformer
RUN python3 -m pip install --no-cache-dir optimum
# For video model testing
RUN python3 -m pip install --no-cache-dir decord av==9.2.0
......
......@@ -11,11 +11,28 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
# Efficient Inference on a Single GPU
In addition to this guide, relevant information can also be found in [the guide for training on a single GPU](perf_train_gpu_one) and [the guide for inference on CPUs](perf_infer_cpu).
## Better Transformer: PyTorch-native transformer fastpath
The PyTorch-native [`nn.MultiHeadAttention`](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/) attention fastpath, called BetterTransformer, can be used with Transformers through its integration in the [🤗 Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview).
PyTorch's attention fastpath speeds up inference through kernel fusions and the use of [nested tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).
After installing the [`optimum`](https://github.com/huggingface/optimum) package, the relevant internal modules can be replaced to use Better Transformer during inference by calling [`~PreTrainedModel.to_bettertransformer`]:
```python
model = model.to_bettertransformer()
```
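For completeness, here is a minimal end-to-end sketch. The tiny checkpoint is borrowed from this PR's integration test; any supported architecture works the same way:
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Illustrative tiny checkpoint, the same one used in this PR's integration test.
model_id = "hf-internal-testing/tiny-random-t5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Swap the relevant internal modules for their BetterTransformer counterparts.
model = model.to_bettertransformer()

# Inference runs as usual, now on the PyTorch fastpath.
inputs = tokenizer("This is me", return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```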
The method [`~PreTrainedModel.reverse_bettertransformer`] reverts to the original modeling, and should be called before saving the model in order to use the canonical transformers modeling:
```python
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```
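Note that, as the integration test added in this PR checks, calling `save_pretrained` while the model is still in BetterTransformer mode raises a `ValueError`; the transformation should always be reversed first.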
As of PyTorch 2.0, the attention fastpath is supported for both encoders and decoders. The list of supported architectures can be found [here](https://huggingface.co/docs/optimum/bettertransformer/overview#supported-models).
## `bitsandbytes` integration for Int8 mixed-precision matrix decomposition
......
......@@ -718,6 +718,18 @@ For some applications, such as pretraining large language models, applying all t
Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many).
## Using PyTorch native attention
PyTorch 2.0 released the native [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA), which allows using fused GPU kernels such as [memory-efficient attention](https://arxiv.org/abs/2112.05682) and [flash attention](https://arxiv.org/abs/2205.14135).
After installing the [`optimum`](https://github.com/huggingface/optimum) package, the relevant internal modules can be replaced to use PyTorch's native attention with:
```python
model = model.to_bettertransformer()
```
Training can then be done as usual.
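As a sketch of what "as usual" means here (the checkpoint, optimizer, and hyperparameters below are illustrative assumptions, not part of this PR):
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative tiny checkpoint, assumed to be among the supported architectures.
model_id = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to_bettertransformer()

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
inputs = tokenizer("a training example", return_tensors="pt")

# One standard training step: nothing about the loop changes after the conversion.
loss = model(**inputs, labels=inputs["input_ids"]).loss
loss.backward()
optimizer.step()
```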
## Using torch.compile
PyTorch 2.0 introduces a new compile function; you can learn more about it [in their documentation](https://pytorch.org/get-started/pytorch-2.0/). It uses Python's frame evaluation API to automatically create a graph from existing PyTorch programs. After capturing the graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for a performance boost.
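In its simplest form, compiling a model is a one-liner. This is a hedged sketch using the default backend, which is only one of the options this section covers:
```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; compilation works the same way for any PyTorch model.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# torch.compile captures the model's graph and lowers it to an optimized backend
# (TorchInductor by default); compilation happens on the first forward call.
model = torch.compile(model)
```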
......
......@@ -64,6 +64,7 @@ from .utils import (
    is_accelerate_available,
    is_bitsandbytes_available,
    is_offline_mode,
    is_optimum_available,
    is_remote_url,
    is_safetensors_available,
    is_torch_tpu_available,
......@@ -3310,6 +3311,56 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        cls._auto_class = auto_class

    def to_bettertransformer(self) -> "PreTrainedModel":
        """
        Converts the model to use [PyTorch's native attention
        implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated into
        Transformers through the [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview).
        Only a subset of all Transformers models is supported.

        PyTorch's attention fastpath allows speeding up inference through kernel fusions and the use of [nested
        tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog
        post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).

        Returns:
            [`PreTrainedModel`]: The model converted to BetterTransformer.
        """
        if not is_optimum_available():
            raise ImportError("The package `optimum` is required to use Better Transformer.")

        from optimum.version import __version__ as optimum_version

        if version.parse(optimum_version) < version.parse("1.7.0"):
            raise ImportError(
                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
            )

        from optimum.bettertransformer import BetterTransformer

        return BetterTransformer.transform(self)

    def reverse_bettertransformer(self):
        """
        Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is
        used, for example in order to save the model.

        Returns:
            [`PreTrainedModel`]: The model converted back to the original modeling.
        """
        if not is_optimum_available():
            raise ImportError("The package `optimum` is required to use Better Transformer.")

        from optimum.version import __version__ as optimum_version

        if version.parse(optimum_version) < version.parse("1.7.0"):
            raise ImportError(
                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
            )

        from optimum.bettertransformer import BetterTransformer

        return BetterTransformer.reverse(self)
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
......
......@@ -65,6 +65,7 @@ from .utils import (
    is_librosa_available,
    is_natten_available,
    is_onnx_available,
    is_optimum_available,
    is_pandas_available,
    is_phonemizer_available,
    is_pyctcdecode_available,
......@@ -693,6 +694,13 @@ def require_bitsandbytes(test_case):
    return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case)


def require_optimum(test_case):
    """
    Decorator marking a test that requires optimum.
    """
    return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)


def require_phonemizer(test_case):
    """
    Decorator marking a test that requires phonemizer
    """
......
......@@ -121,6 +121,7 @@ from .import_utils import (
    is_natten_available,
    is_ninja_available,
    is_onnx_available,
    is_optimum_available,
    is_pandas_available,
    is_peft_available,
    is_phonemizer_available,
......
# coding=utf-8
# Copyright 2023 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import (
    is_torch_available,
    require_optimum,
    require_torch,
    slow,
)
if is_torch_available():
    import torch
@require_torch
@require_optimum
@slow
class BetterTransformerIntegrationTest(unittest.TestCase):
    # refer to the full test suite in Optimum library:
    # https://github.com/huggingface/optimum/tree/main/tests/bettertransformer
    def test_transform_and_reverse(self):
        r"""
        Classic tests to simply check if the conversion has been successful.
        """
        model_id = "hf-internal-testing/tiny-random-t5"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

        inp = tokenizer("This is me", return_tensors="pt")

        model = model.to_bettertransformer()

        self.assertTrue(any("BetterTransformer" in mod.__class__.__name__ for _, mod in model.named_modules()))

        output = model.generate(**inp)

        model = model.reverse_bettertransformer()

        self.assertFalse(any("BetterTransformer" in mod.__class__.__name__ for _, mod in model.named_modules()))

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            model_reloaded = AutoModelForSeq2SeqLM.from_pretrained(tmpdirname)

            self.assertFalse(
                any("BetterTransformer" in mod.__class__.__name__ for _, mod in model_reloaded.named_modules())
            )

            output_from_pretrained = model_reloaded.generate(**inp)
            self.assertTrue(torch.allclose(output, output_from_pretrained))

    def test_error_save_pretrained(self):
        r"""
        The save_pretrained method should raise a ValueError if the model is in BetterTransformer mode.
        All should be good if the model is reversed.
        """
        model_id = "hf-internal-testing/tiny-random-t5"
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

        model = model.to_bettertransformer()

        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(ValueError):
                model.save_pretrained(tmpdirname)

            model = model.reverse_bettertransformer()
            model.save_pretrained(tmpdirname)