<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# torchao

[![Open In Colab: Torchao Demo](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/quantization/torchao.ipynb)

[torchao](https://github.com/pytorch/ao) is a PyTorch architecture optimization library with support for custom high performance data types, quantization, and sparsity. It is composable with native PyTorch features such as [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) for even faster inference and training.

See the table below for additional torchao features.

| Feature | Description |
|--------|-------------|
| **Quantization Aware Training (QAT)** | Train quantized models with minimal accuracy loss (see [QAT README](https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md)) |
| **Float8 Training** | High-throughput training with float8 formats (see [torchtitan](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md) and [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/low_precision_training#configuring-torchao) docs) |
| **Sparsity Support** | Semi-structured (2:4) sparsity for faster inference (see [Accelerating Neural Network Training with Semi-Structured (2:4) Sparsity](https://pytorch.org/blog/accelerating-neural-network-training/) blog post) |
| **Optimizer Quantization** | Reduce optimizer state memory with 4 and 8-bit variants of Adam |
| **KV Cache Quantization** | Enables long context inference with lower memory (see [KV Cache Quantization](https://github.com/pytorch/ao/blob/main/torchao/_models/llama/README.md)) |
| **Custom Kernels Support** | Use your own `torch.compile` compatible ops |
| **FSDP2** | Composable with FSDP2 for training |

> [!TIP]
> Refer to the torchao [README.md](https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization) for more details about the library.

torchao supports the [quantization techniques](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) below.

- A8W8 Float8 Dynamic Quantization
- A16W8 Float8 WeightOnly Quantization
- A8W8 Int8 Dynamic Quantization
- A16W8 Int8 Weight Only Quantization
- A16W4 Int4 Weight Only Quantization
- A16W4 Int4 Weight Only Quantization + 2:4 Sparsity
- Autoquantization
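
In these names, `A` is the activation bit width and `W` is the weight bit width, so A16W8 keeps activations in 16-bit and stores weights in 8-bit. As a rough illustration of what 8-bit weight-only quantization does to a single row of weights, here is a minimal pure-Python sketch of symmetric int8 quantization. The function names are hypothetical and this is not torchao's implementation, which operates on tensors and fused kernels.

```python
def quantize_int8_symmetric(row):
    """Quantize one row of float weights to int8 with a single shared scale."""
    max_abs = max(abs(w) for w in row)
    # Map the largest magnitude to 127; guard against an all-zero row.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-128, min(127, round(w / scale))) for w in row]
    return q, scale


def dequantize(q, scale):
    """Recover approximate float weights from int8 values and the scale."""
    return [v * scale for v in q]


weights = [0.5, -1.27, 0.02, 1.0]
q, scale = quantize_int8_symmetric(weights)
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

The rounding error per weight is bounded by half the scale, which is why a row with one large outlier loses more precision on its small values.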

torchao also supports module-level configuration: pass a dictionary that maps the fully qualified name of a module to its quantization config. This lets you skip quantizing certain layers and use different quantization configs for different modules.

Check the table below to see if your hardware is compatible.

| Component | Compatibility |
|----------|----------------|
| CUDA Versions | ✅ cu118, cu126, cu128 |
| XPU Versions | ✅ pytorch2.8 |
| CPU | ✅ change `device_map="cpu"` (see examples below) |

Install torchao from PyPI or the PyTorch index with the following commands.

<hfoptions id="install torchao">
<hfoption id="PyPi">

```bash
# Update 🤗 Transformers to the latest version, because the example scripts below use the new auto compilation
# Stable release from PyPI, which defaults to CUDA 12.6
pip install --upgrade torchao transformers
```

</hfoption>
<hfoption id="PyTorch Index">
Stable Release from the PyTorch index

```bash
pip install torchao --index-url https://download.pytorch.org/whl/cu126 # options are cpu/cu118/cu126/cu128
```

</hfoption>
</hfoptions>

If your torchao version is below 0.10.0, upgrade it. Refer to the [deprecation notice](#deprecation-notice) for more details.
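
One way to check the installed version before running the examples is sketched below. It uses only the standard library; `parse_release` and `needs_upgrade` are hypothetical helper names, and the parsing is deliberately simple (it compares the leading numeric release only and ignores local tags such as `+cu126`).

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(version_string):
    """Return the leading numeric release as a tuple, e.g. '0.10.0+cu126' -> (0, 10, 0)."""
    parts = []
    for piece in version_string.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def needs_upgrade(installed, minimum="0.10.0"):
    """True if the installed release is older than the minimum release."""
    return parse_release(installed) < parse_release(minimum)


try:
    installed = version("torchao")
    print("torchao", installed, "needs upgrade:", needs_upgrade(installed))
except PackageNotFoundError:
    print("torchao is not installed")
```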

## Quantization examples

torchao provides a variety of quantization configurations. Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures.
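
To give a feel for what `group_size` controls, the sketch below quantizes a weight vector to a signed 4-bit range with one scale per group of `group_size` values. It is a simplified, pure-Python illustration under the assumption of symmetric quantization; the helper names are hypothetical and torchao's kernels differ in detail.

```python
def quantize_groupwise(weights, group_size, qmax=7):
    """Quantize to the range [-qmax - 1, qmax] with one scale per group, then dequantize."""
    if len(weights) % group_size != 0:
        raise ValueError("len(weights) must be divisible by group_size")
    restored, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        scale = max_abs / qmax if max_abs > 0 else 1.0
        scales.append(scale)
        for w in group:
            q = max(-qmax - 1, min(qmax, round(w / scale)))
            restored.append(q * scale)
    return restored, scales


def mean_abs_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


# One outlier (7.0) dominates the scale of whichever group contains it.
weights = [0.01, -0.02, 0.03, 0.02, 0.01, -0.03, 0.02, 7.0]
errors = {}
for group_size in (8, 4):
    restored, scales = quantize_groupwise(weights, group_size)
    errors[group_size] = mean_abs_error(weights, restored)
    print(group_size, len(scales), errors[group_size])
```

Smaller groups confine an outlier's effect to fewer weights at the cost of storing more scales, which is the trade-off behind choosing a value such as `group_size=128`.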

For a complete list of available configurations, see the [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).

You can manually choose the quantization types and settings or automatically select the quantization types.

Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize (for int8 weight only and int4 weight only). Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.

We'll show examples of the recommended quantization methods for different hardware, e.g. A100 GPU, H100 GPU, and CPU.

### H100 GPU

<hfoptions id="examples-H100-GPU">
<hfoption id="float8-dynamic-and-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig

quant_config = Float8DynamicActivationFloat8WeightConfig()
# or float8 weight only quantization
# quant_config = Float8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="int4-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import GemliteUIntXWeightOnlyConfig

# We integrated with gemlite, which optimizes for batch size N on A100 and H100
quant_config = GemliteUIntXWeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="int4-weight-only-24sparse">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout

quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuracy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
    "RedHatAI/Sparse-Llama-3.1-8B-2of4",
    dtype=torch.float16,
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>
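
The int4 weight-only + 2:4 sparsity example above loads a checkpoint that is already 2:4 sparse. In a 2:4 pattern, at most two of every four consecutive weights are nonzero. The sketch below, in plain Python with hypothetical helper names, shows magnitude-based 2:4 pruning of one row and a check for the pattern. It only illustrates the pattern itself; pruning a dense model this way without further training generally costs accuracy, which is why a sparse checkpoint is used.

```python
def prune_2_of_4(row):
    """Zero the two smallest-magnitude values in every block of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        block = row[start:start + 4]
        # Indices of the two largest magnitudes in this block.
        keep = sorted(range(4), key=lambda i: abs(block[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(block))
    return pruned


def is_2_of_4_sparse(row):
    """True if every block of four holds at most two nonzero values."""
    return len(row) % 4 == 0 and all(
        sum(1 for v in row[s:s + 4] if v != 0) <= 2 for s in range(0, len(row), 4)
    )


dense = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.1]
sparse = prune_2_of_4(dense)
print(sparse, is_2_of_4_sparse(dense), is_2_of_4_sparse(sparse))
```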

### A100 GPU

<hfoptions id="examples-A100-GPU">
<hfoption id="int8-dynamic-and-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>

<hfoption id="int4-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig

# For batch size N, we recommend gemlite, which may require autotuning
# default is 4 bit, 8 bit is also supported by passing `bit_width=8`
quant_config = GemliteUIntXWeightOnlyConfig(group_size=128)

# For batch size 1, we also have custom tinygemm kernel that's only optimized for this
# We can set `use_hqq` to `True` for better accuracy
# quant_config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)

quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="int4-weight-only-24sparse">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout

quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuracy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
    "RedHatAI/Sparse-Llama-3.1-8B-2of4",
    dtype=torch.float16,
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>

### Intel XPU

<hfoptions id="examples-Intel-XPU">
<hfoption id="int8-dynamic-and-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>

<hfoption id="int4-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain


quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>

### CPU

<hfoptions id="examples-CPU">
<hfoption id="int8-dynamic-and-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="int4-weight-only">

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout

quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>

### Per Module Quantization

#### 1. Skip quantization for certain layers

With `ModuleFqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"

from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig
config = Int4WeightOnlyConfig(group_size=128)

# set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config)
# lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device.type)
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```

#### 2. Quantizing different layers with different quantization configs (no regex)

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_id = "facebook/opt-125m"

from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType

weight_dtype = torch.int8
granularity = PerAxis(0)
mapping_type = MappingType.ASYMMETRIC
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=weight_dtype,
    granularity=granularity,
    mapping_type=mapping_type,
)
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
quant_config = ModuleFqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
# set `include_embedding` to True in order to include embedding in quantization
# when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", dtype=torch.bfloat16, quantization_config=quantization_config)
print("quantized model:", quantized_model)
# make sure embedding is quantized
print("embed_tokens weight:", quantized_model.model.decoder.embed_tokens.weight)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, cache_implementation="static")
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
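
The embedding config above combines `PerAxis(0)` granularity with `MappingType.ASYMMETRIC`. As a rough mental model, assuming that per-axis quantization along axis 0 means one scale and zero point per row, the pure-Python sketch below quantizes each row of a small matrix to int8 independently. The helper names are hypothetical and the arithmetic is a simplified stand-in for torchao's primitives.

```python
def quantize_rows_asymmetric(matrix, qmin=-128, qmax=127):
    """Per-row asymmetric quantization: each row gets its own scale and zero point."""
    quantized, params = [], []
    for row in matrix:
        lo, hi = min(row), max(row)
        scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
        # The zero point shifts the range so that `lo` maps to `qmin`.
        zero_point = qmin - round(lo / scale)
        q_row = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in row]
        quantized.append(q_row)
        params.append((scale, zero_point))
    return quantized, params


def dequantize_rows(quantized, params):
    return [[(q - zp) * scale for q in row] for row, (scale, zp) in zip(quantized, params)]


embedding = [[0.0, 0.5, 1.0], [-2.0, 0.0, 6.0]]
quantized, params = quantize_rows_asymmetric(embedding)
restored = dequantize_rows(quantized, params)
print(quantized)
print(params)
```

Because each row has its own range, a row with small values is not forced to share the coarse scale of a row with large values.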

#### 3. Quantizing different layers with different quantization configs (with regex)

We can also use a regex to specify the config for every module whose `module_fqn` matches it. Each regex must start with `re:`. For example, `re:layers\..*\.gate_proj` matches modules such as `layers.0.gate_proj`. See [here](https://github.com/pytorch/ao/blob/2fe0ca0899c730c528efdbec8886feaa38879f39/torchao/quantization/quant_api.py#L2392) for docs.

```py
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)

# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)

model_id = "facebook/opt-125m"

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    ModuleFqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)

float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))

qconfig_dict = {
    # highest priority
    "model.decoder.layers.3.self_attn.q_proj": int4wo,
    "model.decoder.layers.3.self_attn.k_proj": int4wo,
    "model.decoder.layers.3.self_attn.v_proj": int4wo,
    # vllm
    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,

    # raw strings keep the regex backslashes intact
    r"re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
    # this should not take effect and we'll fall back to _default
    # since there is no full match (missing `j` at the end)
    r"re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
    # vllm
    r"re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,

    "_default": intxwo,
}
quant_config = ModuleFqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)

# Manual Testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to("cuda")
# setting temperature to 0 to make sure result deterministic
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)

correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt) :])


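# Save the quantized model locally first (the directory name is a placeholder)
save_to = "opt-125m-torchao-per-module"
quantized_model.save_pretrained(save_to, safe_serialization=False)

# Load model from saved checkpoint
reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="cuda:0",
    dtype=torch.bfloat16,
    # quantization_config=quantization_config,
)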

generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt) :])

assert correct_output_text == output_text
```
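
The comments in the example above imply a lookup order: an exact fully qualified name wins, then `re:` patterns that match the whole name, then `_default`. The sketch below mimics that order with the standard `re` module. `resolve_config` is a hypothetical helper, the string values stand in for real config objects, and this is not torchao's actual implementation.

```python
import re


def resolve_config(module_fqn, fqn_to_config):
    """Return the config for `module_fqn`: exact name, then `re:` patterns, then `_default`."""
    if module_fqn in fqn_to_config:
        return fqn_to_config[module_fqn]
    for key, config in fqn_to_config.items():
        # Patterns must match the entire name, not just a prefix.
        if key.startswith("re:") and re.fullmatch(key[3:], module_fqn):
            return config
    return fqn_to_config.get("_default")


configs = {
    "model.decoder.layers.3.self_attn.q_proj": "int4wo",
    r"re:model\.decoder\.layers\..+\.self_attn\.q_proj": "float8dyn",
    r"re:model\.decoder\.layers\..+\.self_attn\.out_pro": "float8dyn",
    "_default": "intxwo",
}

for name in (
    "model.decoder.layers.3.self_attn.q_proj",
    "model.decoder.layers.0.self_attn.q_proj",
    "model.decoder.layers.0.self_attn.out_proj",
):
    print(name, "->", resolve_config(name, configs))
```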

### Autoquant

If you want to automatically choose a quantization type for quantizable layers (`nn.Linear`) you can use the [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API.

The `autoquant` API automatically chooses a quantization type by micro-benchmarking on input type and shape and compiling a single linear layer.

Note: autoquant is for GPU only right now.
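
The idea of choosing by micro-benchmark can be sketched without any GPU: time each candidate on a representative input and keep the fastest. The toy example below uses plain Python functions in place of quantized kernels. `time_call` and `pick_fastest` are hypothetical names and this is only an analogy for the selection step, not how `autoquant` is implemented.

```python
import time


def time_call(fn, arg, repeats=5):
    """Best-of-`repeats` wall-clock time for fn(arg)."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg)
        best = min(best, time.perf_counter() - start)
    return best


def pick_fastest(candidates, sample_input, measure=time_call):
    """Time every candidate on the sample input and return (best_name, timings)."""
    timings = {name: measure(fn, sample_input) for name, fn in candidates.items()}
    return min(timings, key=timings.get), timings


def sum_builtin(xs):
    return sum(xs)


def sum_loop(xs):
    total = 0
    for x in xs:
        total += x
    return total


candidates = {"builtin": sum_builtin, "loop": sum_loop}
best, timings = pick_fastest(candidates, list(range(10_000)))
print(best, timings)
```

The winner depends on the input's shape and type, which is why the benchmark has to run on representative inputs before the choice is finalized.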

Create a [`TorchAoConfig`] and set `quant_type` to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant", min_sqnr=None)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device.type)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# explicitly call `finalize_autoquant` (may be refactored and removed in the future)
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## Serialization

torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization do not work with torchao.

To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
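
The allowlisting idea is easiest to see with the standard library's `pickle`. The sketch below is an analogy only: it is not how `torch.load` is implemented, and `AllowlistUnpickler` and `safe_loads` are hypothetical names. It refuses to resolve any global that is not explicitly listed, so a payload that references an unexpected callable fails to load instead of running code.

```python
import collections
import io
import pickle

# Only these (module, name) pairs may be resolved while unpickling.
ALLOWED_GLOBALS = {("collections", "OrderedDict"): collections.OrderedDict}


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        try:
            return ALLOWED_GLOBALS[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"global '{module}.{name}' is not allowlisted"
            ) from None


def safe_loads(data):
    return AllowlistUnpickler(io.BytesIO(data)).load()


state = collections.OrderedDict(weight=[1, 2, 3])
print(safe_loads(pickle.dumps(state)))

try:
    safe_loads(pickle.dumps(print))  # references builtins.print, which is not allowed
except pickle.UnpicklingError as err:
    print("blocked:", err)
```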

<hfoptions id="serialization-examples">
<hfoption id="save-locally">

```py
# don't serialize model with Safetensors
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
```

</hfoption>
<hfoption id="push-to-huggingface-hub">

```py
# don't serialize model with Safetensors
USER_ID = "your_huggingface_user_id"
REPO_ID = "llama3-8b-int4wo-128"
quantized_model.push_to_hub(f"{USER_ID}/{REPO_ID}", safe_serialization=False)
tokenizer.push_to_hub(f"{USER_ID}/{REPO_ID}")
```

</hfoption>
</hfoptions>

## Loading quantized models

Loading a quantized model depends on the quantization scheme. For schemes such as int8 and float8, you can quantize the model on any device and load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA or XPU.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8WeightOnlyConfig

quant_config = Int8WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)
# save the quantized model
output_dir = "llama-3.1-8b-torchao-int8"
quantized_model.save_pretrained(output_dir, safe_serialization=False)

# reload the quantized model
reloaded_model = AutoModelForCausalLM.from_pretrained(
    output_dir,
    device_map="auto",
    dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(reloaded_model.device.type)

output = reloaded_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

```

For int4, the model can only be loaded on the same device it was quantized on because the layout is specific to the device. The example below demonstrates quantizing and loading a model on the CPU.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout

quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)
# save the quantized model
output_dir = "llama-3.1-8b-torchao-int4-cpu"
quantized_model.save_pretrained(output_dir, safe_serialization=False)

# reload the quantized model
reloaded_model = AutoModelForCausalLM.from_pretrained(
    output_dir,
    device_map="cpu",
    dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt")

output = reloaded_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

```
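The device rule described in this section can be checked before any weights are loaded. The sketch below is a hypothetical helper, not part of Transformers or torchao; the scheme labels (`"int4"`, `"int8"`, `"float8"`) and the function name are assumptions chosen for illustration.

```python
# Schemes whose packed layout is tied to the device they were quantized on,
# per the guidance in this section. Extend the set only for schemes you know
# behave the same way.
DEVICE_SPECIFIC_SCHEMES = {"int4"}


def resolve_load_device(scheme: str, quantized_on: str, requested: str) -> str:
    """Return the device to load on, or raise if the request cannot work.

    Hypothetical helper: `scheme` is a short label such as "int4", "int8",
    or "float8"; devices are plain strings such as "cpu" or "cuda".
    """
    if scheme in DEVICE_SPECIFIC_SCHEMES and requested != quantized_on:
        raise ValueError(
            f"{scheme} checkpoint was quantized on {quantized_on!r}; "
            f"it cannot be loaded on {requested!r}"
        )
    return requested


print(resolve_load_device("int8", quantized_on="cpu", requested="cuda"))  # cuda
print(resolve_load_device("int4", quantized_on="cpu", requested="cpu"))   # cpu
```

The returned string could then be passed as `device_map` when reloading with `from_pretrained`.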

## ⚠️ Deprecation Notice

> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
>
> Please use the new `AOBaseConfig`-based approach instead:
>
> ```python
> # Old way (deprecated)
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
>
> # New way (recommended)
> from torchao.quantization import Int4WeightOnlyConfig
> quant_config = Int4WeightOnlyConfig(group_size=128)
> quantization_config = TorchAoConfig(quant_type=quant_config)
> ```
>
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
>
> **Migration guide**
>
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
>
> | Old String API | New `AOBaseConfig` API |
> |----------------|------------------------|
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
>
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).
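When updating older scripts, the table can be turned into a small lookup. The sketch below is a hypothetical helper that only builds the replacement call as text; it does not import torchao or validate the keyword arguments, and the names `STRING_TO_CONFIG` and `migration_hint` are assumptions made for illustration.

```python
# Mapping taken from the migration table above.
STRING_TO_CONFIG = {
    "int4_weight_only": "Int4WeightOnlyConfig",
    "int8_weight_only": "Int8WeightOnlyConfig",
    "int8_dynamic_activation_int8_weight": "Int8DynamicActivationInt8WeightConfig",
}


def migration_hint(old_name: str, **kwargs) -> str:
    """Return the AOBaseConfig-style replacement for a deprecated string call.

    Hypothetical helper for illustration; it only formats text.
    """
    if old_name not in STRING_TO_CONFIG:
        raise KeyError(f"no known replacement for {old_name!r}")
    args = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
    return f"TorchAoConfig(quant_type={STRING_TO_CONFIG[old_name]}({args}))"


print(migration_hint("int4_weight_only", group_size=128))
# TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128))
```

The config classes named in the output are imported from `torchao.quantization`, as in the examples earlier in this guide.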

## Resources

For a better sense of expected performance, view the [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) for various models with CUDA and XPU backends. You can also run the code below to benchmark a model yourself.

```py
import torch
from torch._inductor.utils import do_bench_using_profiling
from transformers import AutoModelForCausalLM
from typing import Callable

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    elapsed = do_bench_using_profiling(no_args)
    return elapsed * 1e3

# `quantized_model`, `input_ids`, and `model_name` are assumed to be defined already
MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```
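If the profiler-based helper is unavailable in your environment, a plain wall-clock measurement is a rough substitute. The sketch below is an assumption-laden alternative, not the method used for the published benchmarks: `wall_clock_ms` and its `repeats` and `sync` parameters are hypothetical names, and wall-clock timing includes host overhead that a profiler-based measurement may exclude.

```python
import statistics
import time
from typing import Callable, Optional


def wall_clock_ms(
    func: Callable,
    *args,
    repeats: int = 5,
    sync: Optional[Callable[[], None]] = None,
    **kwargs,
) -> float:
    """Median wall-clock time of func(*args, **kwargs) in milliseconds."""
    samples = []
    for _ in range(repeats):
        if sync is not None:
            sync()  # drain pending device work before starting the timer
        start = time.perf_counter()
        func(*args, **kwargs)
        if sync is not None:
            sync()  # wait for device work launched by func to finish
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


# Stand-in workload so the sketch runs without a model.
print(f"{wall_clock_ms(sum, range(100_000)):.3f} ms")
```

On CUDA, passing `sync=torch.cuda.synchronize` keeps asynchronous kernel launches from making the measured time look shorter than it is.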

> [!TIP]
> For best performance, apply the recommended inductor settings by calling `torchao.quantization.utils.recommended_inductor_config_setter()`.

Refer to [Other Available Quantization Techniques](https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques) for more examples and documentation.

## Issues

If you encounter any issues with the Transformers integration, please open an issue on the [Transformers](https://github.com/huggingface/transformers/issues) repository. For issues directly related to torchao, please open an issue on the [torchao](https://github.com/pytorch/ao/issues) repository.