This guide provides a brief overview of how bitsandbytes supports storing quantized weights for FSDP training.

## Quantized data storage
FSDP only supports sharding float data types, which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the underlying storage data type. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes, with a default of `torch.uint8` to maintain backward compatibility with the codebase. With the `quant_storage` parameter, you can select any of the FSDP-supported data types to shard [`~nn.Linear4bit`] with, such as bfloat16, float16, or float32.
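
At the bitsandbytes level, this is just a keyword argument on [`~nn.Linear4bit`]. Here is a minimal sketch; the layer sizes are placeholders, and `quant_storage=torch.bfloat16` is one of the FSDP-friendly choices described above.

```py
import torch
import bitsandbytes as bnb

# Placeholder layer sizes for illustration.
input_features, output_features = 4096, 4096

# quant_storage sets the dtype of the tensor holding the packed 4-bit weights.
# torch.uint8 is the backward-compatible default; a float dtype such as
# torch.bfloat16 makes the parameter shardable by FSDP.
layer = bnb.nn.Linear4bit(
    input_features,
    output_features,
    quant_type="fp4",
    quant_storage=torch.bfloat16,
)
```
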
When loading a model with Transformers, you can configure this option through [`transformers.BitsAndBytesConfig`] by setting the `bnb_4bit_quant_storage` parameter.
```py
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

# Store the packed 4-bit weights in bfloat16 so FSDP can shard them.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
```
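
To confirm the storage dtype took effect, you can inspect a quantized weight after loading. This is a small sketch assuming a Llama-style module layout; the exact attribute path (`model.model.layers[0].self_attn.q_proj`) varies by architecture.

```py
# Quantized weights are bitsandbytes Params4bit tensors whose dtype matches
# bnb_4bit_quant_storage (torch.bfloat16 here); a float storage dtype is what
# lets FSDP shard them like any other float parameter.
weight = model.model.layers[0].self_attn.q_proj.weight
print(type(weight).__name__, weight.dtype)  # e.g. Params4bit torch.bfloat16
```
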
## Training