# Quantized KV Cache

## FP8 KV Cache Overview

Efficient memory usage is crucial when serving large language models. Quantizing the KV (key-value) cache to FP8 stores each cached element in 1 byte instead of the 2 bytes used by FP16/BF16, roughly halving the cache's memory footprint. This lets you keep more tokens in memory, which improves throughput and allows longer context windows.
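A back-of-the-envelope estimate shows where the savings come from. The sketch below assumes Llama-2-7B-style dimensions (32 layers, 32 KV heads, head dimension 128) and a hypothetical 8 GiB KV-cache budget. It ignores vLLM's block allocation and the small overhead of storing scales, so treat the numbers as illustrative only.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, bytes_per_elem: int) -> int:
    # Factor of 2: one key vector and one value vector per KV head per layer.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


# Assumed Llama-2-7B-style shapes (no grouped-query attention).
LAYERS, KV_HEADS, HEAD_DIM = 32, 32, 128

bf16 = kv_cache_bytes_per_token(LAYERS, KV_HEADS, HEAD_DIM, bytes_per_elem=2)
fp8 = kv_cache_bytes_per_token(LAYERS, KV_HEADS, HEAD_DIM, bytes_per_elem=1)

budget = 8 * 1024**3  # hypothetical 8 GiB reserved for the KV cache

print(f"{bf16 // 1024} KiB/token in BF16 -> {budget // bf16} tokens")  # 512 KiB/token in BF16 -> 16384 tokens
print(f"{fp8 // 1024} KiB/token in FP8 -> {budget // fp8} tokens")    # 256 KiB/token in FP8 -> 32768 tokens
```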

> **Note:** When using the Flash Attention 3 backend with FP8 KV cache, attention operations are also performed in the quantized (FP8) domain. In this configuration, queries are quantized to FP8 in addition to keys and values.

### Supported FP8 KV-Cache Quantization Schemes

vLLM supports two main quantization strategies for the FP8 KV-cache:

- **Per-tensor quantization:**  
  A single scale is used for each of the Q, K, and V tensors (`q/k/v_scale` has shape `[1]`).
- **Per-attention-head quantization:**  
  Each attention head has its own scale: `q_scale` has shape `[num_heads]` and `k/v_scale` has shape `[num_kv_heads]`.
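To illustrate the two granularities, the sketch below computes FP8 scales with NumPy. It assumes the common `amax / 448` convention, which maps the largest absolute value onto the largest finite E4M3 value. The tensor shapes, the GQA head counts, and the helper names `per_tensor_scale` and `per_head_scale` are hypothetical and are not vLLM's API.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (OCP "fn" variant)


def per_tensor_scale(x: np.ndarray) -> np.ndarray:
    """One scale for the whole tensor; result has shape [1]."""
    return np.abs(x).reshape(1, -1).max(axis=1) / FP8_E4M3_MAX


def per_head_scale(x: np.ndarray) -> np.ndarray:
    """One scale per head for x shaped [num_tokens, num_heads, head_dim]; result has shape [num_heads]."""
    return np.abs(x).max(axis=(0, 2)) / FP8_E4M3_MAX


# Hypothetical GQA shapes: 32 query heads sharing 8 KV heads.
num_tokens, num_heads, num_kv_heads, head_dim = 16, 32, 8, 128
rng = np.random.default_rng(0)
q = rng.standard_normal((num_tokens, num_heads, head_dim)).astype(np.float32)
k = rng.standard_normal((num_tokens, num_kv_heads, head_dim)).astype(np.float32)
k[:, 3] *= 10.0  # give KV head 3 a much wider range than the others

print(per_tensor_scale(k).shape)  # (1,)
print(per_head_scale(q).shape)    # (32,)
print(per_head_scale(k).shape)    # (8,)
```

In this toy case, the single outlier head sets the per-tensor scale for every head. Per-head scales let the other heads keep a finer quantization step.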

> **Note:**  
> Per-attention-head quantization is currently available **only with the Flash Attention backend** and requires the calibration pathway provided by **llm-compressor**.

### Scale Calibration Approaches

vLLM supports three approaches for computing the quantization scales:

1. **No calibration (default scales):**  
   All quantization scales are set to `1.0`.  
   _Configure with:_  
   ```python
   kv_cache_dtype="fp8"
   calculate_kv_scales=False
   ```

2. **Random token calibration (on-the-fly):**  
   Scales are automatically estimated from a single batch of random tokens during warmup and then fixed.  
   _Configure with:_  
   ```python
   kv_cache_dtype="fp8"
   calculate_kv_scales=True
   ```

3. **[Recommended] Calibration with a dataset (via llm-compressor):**  
   Scales are estimated from a curated calibration dataset for the best accuracy.  
   This requires the [llm-compressor](https://github.com/vllm-project/llm-compressor) library.  
   _See the example below._
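A numeric comparison shows how the scale choice affects accuracy. The sketch below simulates E4M3 rounding in pure Python (round-to-nearest-even, saturating at ±448) and compares the worst-case error for a toy set of small activations under two scales: the default of 1.0, and a scale derived from the activations' absolute maximum. The rounding model, the saturation behaviour, and the sample values are assumptions for illustration. Real kernels may handle rounding and overflow differently.

```python
import math

E4M3_MAX = 448.0    # largest finite value of OCP FP8 E4M3 ("fn" variant)
E4M3_MIN_EXP = -6   # exponent of the smallest normal E4M3 value, 2**-6
E4M3_MANT_BITS = 3  # explicit mantissa bits


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(abs(x))              # abs(x) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, E4M3_MIN_EXP)         # exponent in 1.mmm * 2**exp form; subnormals share 2**-6
    step = 2.0 ** (exp - E4M3_MANT_BITS)   # spacing between neighbouring E4M3 values
    q = round(x / step) * step             # Python's round() rounds half to even
    return max(-E4M3_MAX, min(E4M3_MAX, q))


def max_abs_error(values: list[float], scale: float) -> float:
    """Quantize as x / scale, round to E4M3, dequantize with * scale; return the worst error."""
    return max(abs(v - round_to_e4m3(v / scale) * scale) for v in values)


values = [0.0012, -0.0007, 0.0031, 0.0025]            # toy activations with a small range
unit_scale = 1.0                                       # no calibration
calibrated = max(abs(v) for v in values) / E4M3_MAX    # map the observed amax onto 448

print(max_abs_error(values, unit_scale))  # comparable to the values themselves
print(max_abs_error(values, calibrated))  # roughly an order of magnitude smaller
```

With a scale of 1.0, these small values fall into the subnormal range, where E4M3 has only a few representable steps. A calibrated scale spreads them across the full format.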

#### Additional `kv_cache_dtype` Options

- `kv_cache_dtype="auto"`: Use the model's default data type
- `kv_cache_dtype="fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPUs)
- `kv_cache_dtype="fp8_e5m2"`: Supported on CUDA 11.8+
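The trade-off between the two FP8 formats can be worked out from their bit layouts. The sketch assumes the OCP FP8 definitions: E4M3 in its "fn" form, which has no infinities, and IEEE-style E5M2. Some AMD GPUs use "fnuz" variants with different limits (for example, a maximum of 240 for E4M3). Treat the numbers as illustrative rather than as what every backend uses. The helper `fp8_limits` is hypothetical.

```python
def fp8_limits(exp_bits: int, mant_bits: int, max_exp_field: int, max_mant_field: int):
    """Return (largest finite, smallest normal, smallest subnormal) for a simple FP8 layout."""
    bias = 2 ** (exp_bits - 1) - 1
    largest = (1 + max_mant_field / 2**mant_bits) * 2.0 ** (max_exp_field - bias)
    smallest_normal = 2.0 ** (1 - bias)
    smallest_subnormal = 2.0 ** (1 - bias - mant_bits)
    return largest, smallest_normal, smallest_subnormal


# E4M3 "fn": exponent field 1111 is usable; only mantissa 111 there encodes NaN.
e4m3 = fp8_limits(exp_bits=4, mant_bits=3, max_exp_field=15, max_mant_field=6)
# E5M2: exponent field 11111 is reserved for inf/NaN, as in IEEE 754.
e5m2 = fp8_limits(exp_bits=5, mant_bits=2, max_exp_field=30, max_mant_field=3)

print("E4M3:", e4m3)  # E4M3: (448.0, 0.015625, 0.001953125)
print("E5M2:", e5m2)  # E5M2: (57344.0, 6.103515625e-05, 1.52587890625e-05)
```

E4M3 keeps one more mantissa bit and therefore finer resolution. E5M2 gives up that bit for a much wider dynamic range.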

---

## Examples

### 1. No Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=False`)

All quantization scales are set to 1.0.

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",
    calculate_kv_scales=False,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```

---

### 2. Random Token Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=True`)

Scales are automatically estimated from a single batch of tokens during warmup.

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    kv_cache_dtype="fp8",
    calculate_kv_scales=True,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```

---

### 3. [Recommended] Calibration Using a Dataset (with `llm-compressor`)

For the highest-quality quantization, we recommend calibrating against a dataset using `llm-compressor`. This enables advanced strategies such as per-attention-head quantization.
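One intuition for why a dataset helps: statistics gathered over many samples cover a wider range of activations than a single batch. The sketch below is a simplified, hypothetical absolute-max observer in pure Python. It is not llm-compressor's implementation, whose observers and aggregation strategy may differ. The batch values are toy stand-ins for key activations.

```python
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


class AbsMaxObserver:
    """Hypothetical observer: tracks the largest |x| seen across calibration batches."""

    def __init__(self) -> None:
        self.amax = 0.0

    def update(self, batch: list[float]) -> None:
        self.amax = max(self.amax, max(abs(v) for v in batch))

    def scale(self) -> float:
        # Map the observed amax onto the FP8 maximum; fall back to 1.0 if nothing nonzero was seen.
        return self.amax / FP8_E4M3_MAX if self.amax > 0 else 1.0


# Toy stand-ins for key activations from three calibration batches.
calib_batches = [[0.5, -1.2, 0.8], [2.4, -0.3], [-3.6, 1.1, 0.2]]

one_batch = AbsMaxObserver()
one_batch.update(calib_batches[0])  # calibrate on a single batch only

dataset = AbsMaxObserver()
for batch in calib_batches:
    dataset.update(batch)           # calibrate over the whole set

# An activation of magnitude 3.6 overflows the single-batch scale and would be clipped to 448.
print(3.6 / one_batch.scale())  # about 1344, beyond the E4M3 range
print(3.6 / dataset.scale())    # about 448, just within range
```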

#### Install the required package

```bash
pip install llmcompressor
```

#### Example: Quantize Llama Attention & KV Cache to FP8

```python
"""
Quantize Llama attention + KV cache to FP8 (choose either 'tensor' or 'attn_head' strategy)
using llm-compressor one-shot calibration.
"""

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs

# -----------------------------
# Config
# -----------------------------
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
STRATEGY = "tensor"       # or "attn_head"
NUM_CALIB_SAMPLES = 512   # Good starting value
MAX_SEQ_LEN = 2048

# -----------------------------
# Helpers
# -----------------------------
def process_and_tokenize(example, tokenizer: AutoTokenizer):
    """Convert chat messages to tokens."""
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(
        text,
        padding=False,
        max_length=MAX_SEQ_LEN,
        truncation=True,
        add_special_tokens=False,
    )

def build_recipe(strategy: str) -> QuantizationModifier:
    fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy)
    return QuantizationModifier(
        config_groups={
            "attention": QuantizationScheme(
                targets=["LlamaAttention"],  # Quantize queries: q_scale
                input_activations=fp8_args,
            )
        },
        kv_cache_scheme=fp8_args,           # Quantize KV cache: k/v_scale
    )

# -----------------------------
# Main
# -----------------------------
def main():
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]")
    ds = ds.shuffle(seed=42)
    ds = ds.map(
        lambda ex: process_and_tokenize(ex, tokenizer),
        remove_columns=ds.column_names,
    )

    recipe = build_recipe(STRATEGY)
    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQ_LEN,
        num_calibration_samples=NUM_CALIB_SAMPLES,
    )

    save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}"
    model.save_pretrained(save_dir, save_compressed=True)
    tokenizer.save_pretrained(save_dir)

if __name__ == "__main__":
    main()
```

For more detailed and up-to-date examples, see the [`llm-compressor` official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_kv_cache).