<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GPU inference

GPUs are the standard choice of hardware for machine learning because, unlike CPUs, they are optimized for memory bandwidth and parallelism. To keep up with the larger sizes of modern models, or to run these large models on existing and older hardware, there are several optimizations you can use to speed up GPU inference. In this guide, you'll learn how to use FlashAttention-2 (a more memory-efficient attention mechanism), BetterTransformer (a PyTorch native fastpath execution), and bitsandbytes to quantize your model to a lower precision. Finally, learn how to use 🤗 Optimum to accelerate inference with ONNX Runtime on Nvidia and AMD GPUs.

<Tip>

The majority of the optimizations described here also apply to multi-GPU setups!

</Tip>

## FlashAttention-2

<Tip>

FlashAttention-2 is experimental and may change considerably in future versions.

</Tip>

[FlashAttention-2](https://huggingface.co/papers/2307.08691) is a faster and more efficient implementation of the standard attention mechanism that can significantly speed up inference by:

1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Before you begin, make sure you have FlashAttention-2 installed.

<hfoptions id="install">
<hfoption id="NVIDIA">

```bash
pip install flash-attn --no-build-isolation
```

We strongly suggest referring to the detailed [installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) to learn more about supported hardware and data types!
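
After installing, you can optionally run a quick sanity check to confirm the package imports correctly and that your GPU is supported (at the time of writing, FlashAttention-2 targets Ampere or newer NVIDIA GPUs, i.e. compute capability 8.0 and above):

```python
import torch
import flash_attn

# if this prints without an ImportError, the flash-attn build succeeded
print(flash_attn.__version__)
# FlashAttention-2 kernels expect compute capability (8, 0) or higher
print(torch.cuda.get_device_capability())
```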

</hfoption>
<hfoption id="AMD">

FlashAttention-2 is also supported on AMD GPUs and current support is limited to **Instinct MI210**, **Instinct MI250** and **Instinct MI300**. We strongly suggest using this [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to use FlashAttention-2 on AMD GPUs.

</hfoption>
</hfoptions>

To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```

<Tip>

FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-2.
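
For example, a minimal sketch of the cast for an already-loaded `fp32` model (assuming `torch` is imported and `model` is your loaded model) is:

```python
# FlashAttention-2 expects the weights in half precision on a supported GPU
model = model.to(torch.bfloat16).to("cuda")
```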

<br>

You can also set `use_flash_attention_2=True` to enable FlashAttention-2 but it is deprecated in favor of `attn_implementation="flash_attention_2"`.

</Tip>

FlashAttention-2 can be combined with other optimization techniques like quantization to further speed up inference. For example, you can combine FlashAttention-2 with 8-bit or 4-bit quantization:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# load in 8bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=True,
    attn_implementation="flash_attention_2",
)

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_4bit=True,
    attn_implementation="flash_attention_2",
)
```

### Expected speedups

You can benefit from considerable speedups for inference, especially for inputs with long sequences. However, since FlashAttention-2 does not support computing attention scores with padding tokens, you must manually pad/unpad the attention scores for batched inference when a sequence contains padding tokens. This leads to a significant slowdown for batched generation with padding tokens.

To overcome this, you should use FlashAttention-2 without padding tokens in the sequence during training (by packing a dataset or [concatenating sequences](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L516) until reaching the maximum sequence length).
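
As a rough illustration, packing amounts to concatenating the tokenized examples and slicing the result into fixed-length blocks so that no padding is needed (the helper below is only a sketch, not part of Transformers):

```python
from itertools import chain

def pack_sequences(batch, block_size=4096):
    # concatenate every tokenized example in the batch into one long stream
    concatenated = list(chain.from_iterable(batch["input_ids"]))
    # drop the remainder so every block is exactly block_size tokens, with no padding
    total_length = (len(concatenated) // block_size) * block_size
    return {"input_ids": [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]}
```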

For a single forward pass on [tiiuae/falcon-7b](https://hf.co/tiiuae/falcon-7b) with a sequence length of 4096 and various batch sizes without padding tokens, the expected speedup is:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/falcon-7b-inference-large-seqlen.png">
</div>

For a single forward pass on [meta-llama/Llama-7b-hf](https://hf.co/meta-llama/Llama-7b-hf) with a sequence length of 4096 and various batch sizes without padding tokens, the expected speedup is:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-7b-inference-large-seqlen.png">
</div>

For sequences with padding tokens (generating with padding tokens), you need to unpad/pad the input sequences to correctly compute the attention scores. With a relatively small sequence length, a single forward pass creates overhead leading to a small speedup (in the example below, 30% of the input is filled with padding tokens):

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-small-seqlen-padding.png">
</div>

But for larger sequence lengths, you can expect even more speedup benefits:

<Tip>

FlashAttention is more memory efficient, meaning you can train on much larger sequence lengths without running into out-of-memory issues. You can potentially reduce memory usage up to 20x for larger sequence lengths. Take a look at the [flash-attention](https://github.com/Dao-AILab/flash-attention) repository for more details.

</Tip>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>

## PyTorch scaled dot product attention

PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
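
For example, to explicitly request SDPA when loading a supported model (`openai-community/gpt2` is just an example checkpoint here):

```python
import torch
from transformers import AutoModelForCausalLM

# SDPA is already the default for torch>=2.1.1 when the architecture supports it;
# passing the argument makes the choice explicit and errors out if SDPA is unavailable
model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```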

For now, Transformers supports SDPA inference and training for the following architectures:
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)


<Tip>

FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.

</Tip>

<Tip>

SDPA does not support certain sets of attention parameters, such as `head_mask` and `output_attentions=True`.
In that case, you should see a warning message and we will fall back to the (slower) eager implementation.

</Tip>

By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:

```diff
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

If you see a bug with the traceback below, try using the nightly version of PyTorch which may have broader coverage for FlashAttention:

```bash
RuntimeError: No available kernel. Aborting execution.

# install PyTorch nightly
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

## BetterTransformer

<Tip warning={true}>

Some BetterTransformer features are being upstreamed to Transformers with default support for native `torch.nn.scaled_dot_product_attention`. BetterTransformer still has a wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to natively support SDPA in Transformers.

</Tip>

<Tip>

Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.

</Tip>

BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:

1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
2. skipping unnecessary computation on padding tokens by exploiting their inherent sparsity with nested tensors

BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.

Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).

Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:

```python
model = model.to_bettertransformer()
```

You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:

```py
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```

## bitsandbytes

bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.

Make sure you have bitsandbytes and 🤗 Accelerate installed:

```bash
# these versions support 8-bit and 4-bit
pip install bitsandbytes>=0.39.0 accelerate>=0.20.0

# install Transformers
pip install transformers
```

### 4-bit

To load a model in 4-bit for inference, use the `load_in_4bit` parameter. The `device_map` parameter is optional, but we recommend setting it to `"auto"` to allow 🤗 Accelerate to automatically and efficiently allocate the model given the available resources in the environment.

```py
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"
model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
```
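
To verify the memory savings, you can compare the footprint reported by [`~PreTrainedModel.get_memory_footprint`] for the 4-bit model loaded above against a full-precision load of the same checkpoint:

```py
# size of the model's parameters and buffers in GB; the 4-bit load should be a fraction of the fp32 footprint
print(f"4-bit footprint: {model_4bit.get_memory_footprint() / 1024**3:.2f} GB")
```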

To load a model in 4-bit for inference with multiple GPUs, you can control how much GPU RAM you want to allocate to each GPU. For example, to distribute 600MB of memory to the first GPU and 1GB of memory to the second GPU:

```py
max_memory_mapping = {0: "600MB", 1: "1GB"}
model_name = "bigscience/bloom-3b"
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", load_in_4bit=True, max_memory=max_memory_mapping
)
```

### 8-bit

<Tip>

If you're curious and interested in learning more about the concepts underlying 8-bit quantization, read the [Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes](https://huggingface.co/blog/hf-bitsandbytes-integration) blog post.

</Tip>

To load a model in 8-bit for inference, use the `load_in_8bit` parameter. The `device_map` parameter is optional, but we recommend setting it to `"auto"` to allow 🤗 Accelerate to automatically and efficiently allocate the model given the available resources in the environment:

```py
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name = "bigscience/bloom-2b5"
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
```

If you're loading a model in 8-bit for text generation, you should use the [`~transformers.GenerationMixin.generate`] method instead of the [`Pipeline`] function, which is not optimized for 8-bit models and will be slower. Some sampling strategies, like nucleus sampling, are also not supported by the [`Pipeline`] for 8-bit models. You should also place all inputs on the same device as the model:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "bigscience/bloom-2b5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=BitsAndBytesConfig(load_in_8bit=True))

prompt = "Hello, my llama is cute"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
generated_ids = model_8bit.generate(**inputs)
outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```

To load a model in 8-bit for inference with multiple GPUs, you can control how much GPU RAM you want to allocate to each GPU. For example, to distribute 1GB of memory to the first GPU and 2GB of memory to the second GPU:

```py
max_memory_mapping = {0: "1GB", 1: "2GB"}
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
)
```

<Tip>

Feel free to try running an 11 billion parameter [T5 model](https://colab.research.google.com/drive/1YORPWx4okIHXnjW7MSAidXN29mPVNT7F?usp=sharing) or the 3 billion parameter [BLOOM model](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing) for inference on Google Colab's free tier GPUs!

</Tip>

## 🤗 Optimum

<Tip>

Learn more about using ORT with 🤗 Optimum in the [Accelerated inference on NVIDIA GPUs](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu#accelerated-inference-on-nvidia-gpus) and [Accelerated inference on AMD GPUs](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/amdgpu#accelerated-inference-on-amd-gpus) guides. This section only provides a brief and simple example.

</Tip>

ONNX Runtime (ORT) is a model accelerator that supports accelerated inference on Nvidia GPUs, and on AMD GPUs that use the [ROCm](https://www.amd.com/en/products/software/rocm.html) stack. ORT uses optimization techniques like fusing common operations into a single node and constant folding to reduce the number of computations performed and speed up inference. ORT also places the most computationally intensive operations on the GPU and the rest on the CPU to intelligently distribute the workload between the two devices.

ORT is supported by 🤗 Optimum, which can be used in 🤗 Transformers. You'll need to use an [`~optimum.onnxruntime.ORTModel`] for the task you're solving, and specify the `provider` parameter which can be set to either [`CUDAExecutionProvider`](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu#cudaexecutionprovider), [`ROCMExecutionProvider`](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/amdgpu) or [`TensorrtExecutionProvider`](https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu#tensorrtexecutionprovider). If you want to load a model that was not yet exported to ONNX, you can set `export=True` to convert your model on-the-fly to the ONNX format:

```py
from optimum.onnxruntime import ORTModelForSequenceClassification

ort_model = ORTModelForSequenceClassification.from_pretrained(
  "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
  export=True,
  provider="CUDAExecutionProvider",
)
```

Now you're free to use the model for inference:

```py
from optimum.pipelines import pipeline
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased-finetuned-sst-2-english")

pipeline = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer, device="cuda:0")
result = pipeline("Both the music and visual were astounding, not to mention the actors performance.")
```

## Combine optimizations

It is often possible to combine several of the optimization techniques described above to get the best inference performance possible for your model. For example, you can load a model in 4-bit, and then enable BetterTransformer with FlashAttention:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# load model in 4-bit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=quantization_config)

# enable BetterTransformer
model = model.to_bettertransformer()

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

# enable FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```