Unverified Commit 44f08af3 authored by Eldar Kurtić's avatar Eldar Kurtić Committed by GitHub
Browse files

Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)


Signed-off-by: default avatarEldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: default avatareldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
parent 955b43a5
......@@ -202,7 +202,8 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale) {
const int block_size, const float* k_scale, const float* v_scale,
const int kv_scale_stride) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
......@@ -226,21 +227,23 @@ __global__ void reshape_and_cache_flash_kernel(
// this is true for the NHD layout where `head_stride == head_size`
const bool is_contiguous_heads = (head_stride == head_size);
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
if (is_contiguous_heads && kv_scale_stride == 0) {
// NHD layout and k/v_scales are [1] (i.e. single scale for all heads)
// kv cache: [num_blocks, block_size, num_heads, head_size]
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
if (is_contiguous_heads) {
// NHD layout
// kv cache: [num_blocks, block_size, num_heads, head_size]
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
blockDim.x, k_op);
vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
threadIdx.x, blockDim.x, v_op);
} else {
// HND layout OR k/v_scales are [num_heads] (i.e. per-attn-head)
// HND layout: heads are strided, but each head_size segment is contiguous
// kv cache: [num_blocks, num_heads, block_size, head_size]
const int lane = threadIdx.x & 31; // 0..31 within warp
......@@ -256,6 +259,16 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t* __restrict__ v_dst_h =
value_dst + static_cast<int64_t>(head) * head_stride;
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: k_scale[head * kv_scale_stride];
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
? 0.f
: v_scale[head * kv_scale_stride];
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
// within each head, let the 32 threads of the warp perform the vector
// copy
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
......@@ -605,7 +618,8 @@ void reshape_and_cache(
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
reinterpret_cast<const float*>(v_scale.data_ptr()), \
kv_scale_stride);
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
......@@ -614,8 +628,9 @@ void reshape_and_cache_flash(
torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, // [1] or [num_heads]
torch::Tensor& v_scale) { // [1] or [num_heads]
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
......@@ -638,6 +653,12 @@ void reshape_and_cache_flash(
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
TORCH_CHECK(k_scale.sizes() == v_scale.sizes(),
"k_scale and v_scale must have the same shape");
TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
"k_scale and v_scale must be of shape [1] or [num_heads]");
int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0;
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
......
# Quantized KV Cache
## FP8 KV Cache
## FP8 KV Cache Overview
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput.
Efficient memory usage is crucial for working with large language models. Quantizing the KV (Key-Value) cache to FP8 format can significantly reduce its memory footprint. This optimization enables you to store more tokens in memory, leading to improved throughput and support for longer context windows.
### FP8 Formats
> **Note:** When using the Flash Attention 3 backend with FP8 KV cache, attention operations are also performed in the quantized (FP8) domain. In this configuration, queries are quantized to FP8 in addition to keys and values.
[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats:
### Supported FP8 KV-Cache Quantization Schemes
- E5M2 (5 exponent bits and 2 mantissa bits)
- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened as E4M3)
vLLM supports two main quantization strategies for the FP8 KV-cache:
The E4M3 format offers higher precision compared to E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor.
- **Per-tensor quantization:**
A single scale is applied for each Q, K, and V tensor individually. (`q/k/v_scale = [1]`)
- **Per-attention-head quantization:**
Each scale corresponds to an attention head: `q_scale = [num_heads]`, `k/v_scale = [num_kv_heads]`.
### Current Limitations
> **Note:**
> Per-attention-head quantization is currently available **only with the Flash Attention backend** and requires the calibration pathway provided by **llm-compressor**.
For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
### Scale Calibration Approaches
### How FP8 KV Cache Works
You can configure how the quantization scales are computed in vLLM using three different approaches:
The FP8 KV cache implementation follows this workflow:
1. **No calibration (default scales):**
All quantization scales are set to `1.0`.
_Configure with:_
```python
kv_cache_dtype="fp8"
calculate_kv_scales=False
```
1. **Storage**: Key and Value tensors are quantized to FP8 format using scaling factors before being stored in the KV cache
2. **Retrieval**: When needed for attention computation, cached KV tensors are dequantized back to higher precision (FP16/BF16)
3. **Attention**: The attention-value multiplication (softmax output × V) is performed using the dequantized higher-precision V tensor
2. **Random token calibration (on-the-fly):**
Scales are automatically estimated from a single batch of random tokens during warmup and then fixed.
_Configure with:_
```python
kv_cache_dtype="fp8"
calculate_kv_scales=True
```
This means the final attention computation operates on dequantized values, not FP8 tensors. The quantization reduces memory usage during storage but maintains computation accuracy by using higher precision during the actual attention operations.
3. **[Recommended] Calibration with a dataset (via llm-compressor):**
Scales are estimated using a curated calibration dataset for maximum accuracy.
This requires the [llm-compressor](https://github.com/vllm-project/llm-compressor) library.
_See example below!_
### Performance Impact
#### Additional `kv_cache_dtype` Options
The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either:
- `kv_cache_dtype="auto"`: Use the model's default data type
- `kv_cache_dtype="fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPUs)
- `kv_cache_dtype="fp8_e5m2"`: Supported on CUDA 11.8+
- Processing longer context lengths for individual requests, or
- Handling more concurrent request batches
---
However, there are currently no latency improvements as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (fp32, fp16, bf16), this benefit is not yet fully realized.
## Examples
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy, making it a practical choice for throughput optimization.
### 1. No Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=False`)
## Usage Example
All quantization scales are set to 1.0.
Here is an example of how to enable FP8 quantization:
```python
from vllm import LLM, SamplingParams
??? code
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=False,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
```python
# To calculate kv cache scales on the fly enable the calculate_kv_scales
# parameter
---
from vllm import LLM, SamplingParams
### 2. Random Token Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=True`)
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
Scales are automatically estimated from a single batch of tokens during warmup.
```python
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(
model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
calculate_kv_scales=True,
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
The `kv_cache_dtype` argument specifies the data type for KV cache storage:
- `"auto"`: Uses the model's default "unquantized" data type
- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU)
- `"fp8_e5m2"`: Supported on CUDA 11.8+
)
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
## Calibrated Scales for Better Accuracy
---
For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. [LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process.
### 3. **[Recommended] Calibration Using a Dataset (with `llm-compressor`)**
### Installation
For the highest-quality quantization, we recommend calibrating against a dataset using `llm-compressor`. This enables advanced strategies such as per-attention-head quantization.
First, install the required dependencies:
#### Install the required package
```bash
pip install llmcompressor
```
### Example Usage
#### Example: Quantize Llama Attention & KV Cache to FP8
Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):
??? code
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
# Select model and load it
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Select calibration dataset
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
# Configure calibration parameters
NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point
MAX_SEQUENCE_LENGTH = 2048
# Load and preprocess dataset
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def process_and_tokenize(example):
```python
"""
Quantize Llama attention + KV cache to FP8 (choose either 'tensor' or 'attn_head' strategy)
using llm-compressor one-shot calibration.
"""
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs
# -----------------------------
# Config
# -----------------------------
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
STRATEGY = "tensor" # or "attn_head"
NUM_CALIB_SAMPLES = 512 # Good starting value
MAX_SEQ_LEN = 2048
# -----------------------------
# Helpers
# -----------------------------
def process_and_tokenize(example, tokenizer: AutoTokenizer):
"""Convert chat messages to tokens."""
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
return tokenizer(
text,
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
max_length=MAX_SEQ_LEN,
truncation=True,
add_special_tokens=False,
)
ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)
# Configure quantization settings
recipe = """
quant_stage:
quant_modifiers:
QuantizationModifier:
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
"""
# Apply quantization
def build_recipe(strategy: str) -> QuantizationModifier:
fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy)
return QuantizationModifier(
config_groups={
"attention": QuantizationScheme(
targets=["LlamaAttention"], # Quantize queries: q_scale
input_activations=fp8_args,
)
},
kv_cache_scheme=fp8_args, # Quantize KV cache: k/v_scale
)
# -----------------------------
# Main
# -----------------------------
def main():
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]")
ds = ds.shuffle(seed=42)
ds = ds.map(
lambda ex: process_and_tokenize(ex, tokenizer),
remove_columns=ds.column_names,
)
recipe = build_recipe(STRATEGY)
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
max_seq_length=MAX_SEQ_LEN,
num_calibration_samples=NUM_CALIB_SAMPLES,
)
# Save quantized model: Llama-3.1-8B-Instruct-FP8-KV
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales.
save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}"
model.save_pretrained(save_dir, save_compressed=True)
tokenizer.save_pretrained(save_dir)
When running the model you must specify `kv_cache_dtype="fp8"` in order to enable the kv cache quantization and use the scales.
```python
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
if __name__ == "__main__":
main()
```
For more detailed and up-to-date examples, see the [`llm-compressor` official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_kv_cache).
......@@ -8,6 +8,7 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_dequantize
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
......@@ -19,6 +20,7 @@ NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]
KV_SCALE_TYPES = ["tensor", "attn_head"]
# Parameters for MLA tests.
KV_LORA_RANKS = [512]
......@@ -170,6 +172,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@pytest.mark.parametrize("kv_scale_type", KV_SCALE_TYPES)
@pytest.mark.parametrize("implementation", RESHAPE_FLASH_IMPLEMENTATIONS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
......@@ -184,6 +187,7 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
kv_scale_type: str,
implementation: str,
) -> None:
set_random_seed(seed)
......@@ -193,6 +197,9 @@ def test_reshape_and_cache_flash(
if implementation == "triton" and kv_cache_layout == "HND":
pytest.skip("Triton implementation only supports NHD layout.")
if kv_scale_type == "attn_head" and implementation != "cuda":
pytest.skip("Only CUDA implementation supports attn_head scaling.")
# fp8 conversion requires continugous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens = num_tokens // 2
......@@ -220,8 +227,12 @@ def test_reshape_and_cache_flash(
del key_caches
del value_caches
if kv_scale_type == "tensor":
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
else: # "attn_head"
k_scale = (key.amax(dim=(0, 2)) / 64.0).to(torch.float32)
v_scale = (value.amax(dim=(0, 2)) / 64.0).to(torch.float32)
def permute_and_compact(x):
y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
......@@ -230,15 +241,27 @@ def test_reshape_and_cache_flash(
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
def convert_fp8_local(output, input, scale, kv_dtype):
fp8_input = input.view(torch.float8_e4m3fn)
if scale.numel() == 1: # per-tensor
result = scaled_dequantize(
fp8_input.flatten(0, 2), scale, group_shape=None, out_dtype=output.dtype
).reshape(*input.shape)
else: # per-head: broadcast scale along the head dimension
# Original code uses dim 2 for NHD, dim 1 for HND
if kv_cache_layout == "NHD":
result = fp8_input.to(output.dtype) * scale.view(1, 1, -1, 1)
else:
result = fp8_input.to(output.dtype) * scale.view(1, -1, 1, 1)
output.copy_(result)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype
)
convert_fp8_local(cloned_key_cache, key_cache_compact, k_scale, kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype
convert_fp8_local(
cloned_value_cache, value_cache_compact, v_scale, kv_cache_dtype
)
else:
cloned_key_cache = key_cache_compact.clone()
......@@ -289,15 +312,13 @@ def test_reshape_and_cache_flash(
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
ops.convert_fp8(
result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
)
convert_fp8_local(result_key_cache, key_cache_compact, k_scale, kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
ops.convert_fp8(
convert_fp8_local(
result_value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype,
v_scale,
kv_cache_dtype,
)
# Run the reference implementation.
......
......@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
# AITER only supports per-channel-per-channel INT8 gemm
# and per-tensor-per-tensor INT8 GEMM.
......@@ -360,9 +361,26 @@ def test_compressed_tensors_fp8(vllm_runner):
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, enforce_eager=True, kv_cache_dtype="fp8") as llm:
def test_compressed_tensors_kv_cache_fp8_per_tensor(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-kvcache-fp8-tensor"
with vllm_runner(model_path) as llm:
output = llm.generate_greedy("Hello world!", max_tokens=4)
assert output
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-kvcache-fp8-attn_head"
try:
fa_version = get_flash_attn_version()
except Exception:
pytest.skip("This test requires FlashAttention backend.")
if fa_version is None or fa_version < 3:
pytest.skip("This test requires FlashAttention version >= 3.")
with vllm_runner(model_path, attention_config={"backend": "FLASH_ATTN"}) as llm:
output = llm.generate_greedy("Hello world!", max_tokens=4)
assert output
......
......@@ -75,13 +75,16 @@ def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) ->
layer._v_scale_float = 1.0
layer._prob_scale_float = 1.0
# Initialize q/k/v range constants used by calc_kv_scales
layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
prefix: str,
kv_cache_dtype: str,
calculate_kv_scales: bool,
) -> None:
"""Initializes KV cache scaling factors and quantization method.
......@@ -94,16 +97,10 @@ def _init_kv_cache_quant(
layer: The attention layer instance to initialize.
quant_config: Optional quantization configuration.
prefix: Layer name prefix for quantization method lookup.
kv_cache_dtype: The KV cache data type string.
calculate_kv_scales: Whether to calculate KV scales dynamically.
"""
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
layer.kv_cache_dtype = kv_cache_dtype
layer.calculate_kv_scales = calculate_kv_scales
quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
# Note [Register q/k/v/prob scales in state dict]
# When calling model.to(device), only parameters/buffers in state dict are
......@@ -133,7 +130,7 @@ def _init_kv_cache_quant(
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if kv_cache_dtype == "fp8_e5m2":
if layer.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
......@@ -197,9 +194,20 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# llm-compressor mdls need to set cache_dtype to "fp8" manually.
if getattr(quant_config, "kv_cache_scheme", None) is not None:
kv_cache_dtype = "fp8"
calculate_kv_scales = False
if cache_config is not None:
cache_config.cache_dtype = "fp8"
cache_config.calculate_kv_scales = False
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, (
......@@ -208,15 +216,6 @@ class Attention(nn.Module, AttentionLayerBase):
self.quant_config = quant_config
self.layer_name = prefix
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self,
self.quant_config,
self.layer_name,
kv_cache_dtype,
calculate_kv_scales,
)
self.num_heads = num_heads
self.head_size = head_size
self.head_size_v = self.head_size if head_size_v is None else head_size_v
......@@ -318,18 +317,24 @@ class Attention(nn.Module, AttentionLayerBase):
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Initialize q/k/v range constants.
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
# for attn backends supporting query quantization
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input
if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
"fp8"
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
is_per_head = (
hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
)
block_size = self.head_size * self.num_heads // self.num_kv_heads
self.query_quant = QuantFP8(
static=True,
group_shape=GroupShape(-1, block_size)
if is_per_head
else GroupShape.PER_TENSOR,
)
def forward(
self,
......@@ -524,13 +529,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.quant_config = quant_config
# Initialize KV cache quantization attributes
_init_kv_cache_quant(
self,
self.quant_config,
self.layer_name,
kv_cache_dtype,
calculate_kv_scales,
)
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)
dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
......@@ -19,6 +20,10 @@ from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
......@@ -87,6 +92,8 @@ class CompressedTensorsConfig(QuantizationConfig):
kv_cache_scheme: dict[str, Any] | None = None,
config: dict[str, Any] | None = None,
transform_config: dict[str, Any] | None = None,
total_num_heads: int | None = None,
total_num_kv_heads: int | None = None,
):
super().__init__()
self.ignore = ignore
......@@ -97,6 +104,8 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_scheme_map = sparsity_scheme_map
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
self.total_num_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
if transform_config:
self.transform_config = TransformConfig.model_validate(transform_config)
......@@ -200,13 +209,29 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
# We keep only config groups which are not doing Attention quantization
# because Attention quantization on its own is not supported by vLLM.
# It is coupled with KV-cache quantization, and if scales are present in the
# checkpoint, they will be used properly.
grps_without_attn_quant = {}
for k, v in config["config_groups"].items():
# e.g. LlamaAttention, Qwen3Attention, etc.
if len(v["targets"]) == 1 and v["targets"][0].endswith("Attention"):
logger.warning(
"Skipping CompressedTensors config group for %s. Attention quant "
"is coupled with KV-cache quantization in vLLM.",
v["targets"][0],
)
continue
grps_without_attn_quant[k] = v
config["config_groups"] = grps_without_attn_quant
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config
)
transform_config = config.get("transform_config")
return cls(
target_scheme_map=target_scheme_map,
......@@ -215,7 +240,10 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
transform_config=config.get("transform_config"),
kv_cache_scheme=config.get("kv_cache_scheme"),
total_num_heads=config.get("total_num_heads"),
total_num_kv_heads=config.get("total_num_kv_heads"),
)
@classmethod
......@@ -791,22 +819,6 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
def get_cache_scale(self, name: str) -> str | None:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
weight_quant = scheme.get("weights")
......@@ -965,12 +977,16 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
f"received num_bits={num_bits}, type={type_}"
)
# TODO: delegate validation to compressed-tensors library so that we have a
# single source of truth. Right now this is not possible until the next release
# of compressed-tensors.
strategy = kv_cache_scheme.get("strategy")
if strategy != "tensor":
supported_strategies = ("tensor", "attn_head")
if strategy not in supported_strategies:
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f"Expected strategy: tensor, found strategy: {strategy}"
"Invalid strategy for compressed-tensors KV cache. "
f"Expected strategies: {supported_strategies}, found strategy:"
f" {strategy}"
)
is_symmetric = kv_cache_scheme.get("symmetric")
......@@ -980,3 +996,133 @@ class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"for compressed-tensors KV cache. "
f"However found symmetric: {is_symmetric}"
)
def create_weights(self, layer: torch.nn.Module):
"""
Initialize placeholder scales and zero points to enable loading of
quantized params from compressed-tensors checkpoints.
"""
strategy = None # for backward compatibility
if (
hasattr(self.quant_config, "kv_cache_scheme")
and self.quant_config.kv_cache_scheme is not None
):
strategy = self.quant_config.kv_cache_scheme["strategy"]
if strategy == "attn_head":
assert layer.impl.supports_per_head_quant_scales, (
f"Layer {layer.__class__.__name__} with implementation "
f"{layer.impl.__class__.__name__} does not support per-head scales."
)
n_scales = int(layer.num_kv_heads)
else:
n_scales = 1
layer.k_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
layer.v_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
layer.q_scale = torch.nn.Parameter(
torch.ones(n_scales, requires_grad=False, dtype=torch.float32)
)
# Zero points are not used in vLLM as currently only symmetric quantization is
# supported. We need to create them here to enable loading of llm-compressor
# checkpoints which contain them irrespective of the symmetric/asymmetric
# scheme used during quantization.
layer.k_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
layer.v_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
layer.q_zero_point = torch.nn.Parameter(
torch.zeros(n_scales, requires_grad=False)
)
# TP-aware loading for attn_head strategy follows attention head partitioning:
# - q_scale is partitioned over query heads.
# - k/v_scale is partitioned over kv heads when total_kv_heads >= tp_size,
# and replicated when total_kv_heads < tp_size.
if strategy == "attn_head":
def _tp_aware_loader(
param: torch.Tensor,
loaded_weight: torch.Tensor,
kind: Literal["q", "k", "v"],
param_type: Literal["scale", "zero_point"],
):
# Zero-points are not used as vLLM only supports symmetric quantization
if param_type == "zero_point":
return
# LLM-Compressor stores scales as 3D tensors of shape [num_heads, 1, 1]
loaded_weight = loaded_weight.flatten()
# FlashAttn expects [num_kv_heads] instead of [num_heads] for q_scale.
# We reduce by taking the max scale in each attention head group.
if kind == "q":
reduction_factor = (
self.quant_config.total_num_heads # type: ignore[attr-defined]
// self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
)
loaded_weight = torch.amax(
loaded_weight.view(-1, reduction_factor), dim=1
)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if layer.num_kv_heads * tp_size == self.quant_config.total_num_kv_heads: # type: ignore[attr-defined]
# heads evenly distributed
loaded_weight = loaded_weight[
tp_rank * layer.num_kv_heads : (tp_rank + 1)
* layer.num_kv_heads
]
else:
# heads replicated to match TP size
assert layer.num_kv_heads == 1
replicas = tp_size // self.quant_config.total_num_kv_heads # type: ignore[attr-defined]
shard_rank = tp_rank // replicas
loaded_weight = loaded_weight[shard_rank : shard_rank + 1]
param.data.copy_(loaded_weight.to(dtype=param.dtype))
layer.q_scale.weight_loader = partial(
_tp_aware_loader, kind="q", param_type="scale"
)
layer.k_scale.weight_loader = partial(
_tp_aware_loader, kind="k", param_type="scale"
)
layer.v_scale.weight_loader = partial(
_tp_aware_loader, kind="v", param_type="scale"
)
layer.q_zero_point.weight_loader = partial(
_tp_aware_loader, kind="q", param_type="zero_point"
)
layer.k_zero_point.weight_loader = partial(
_tp_aware_loader, kind="k", param_type="zero_point"
)
layer.v_zero_point.weight_loader = partial(
_tp_aware_loader, kind="v", param_type="zero_point"
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Override the default vLLM placeholder scales with the llm-compressor loaded
scales. Zero points are not used as only symmetric quantization is supported.
"""
layer._k_scale = layer.k_scale
layer._v_scale = layer.v_scale
layer._q_scale = layer.q_scale
# Discard all placeholders.
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.k_zero_point
del layer.v_zero_point
del layer.q_zero_point
......@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
group_broadcast,
prep_scale_for_group_broadcast,
)
from vllm.platforms import current_platform
......@@ -40,7 +41,7 @@ class QuantFP8(CustomOp):
"""
:param static: static or dynamic quantization
:param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR,
or arbitrary block size)
PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param column_major_scales: For group quantization, output scales in
......@@ -157,6 +158,8 @@ class QuantFP8(CustomOp):
x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
else:
scale = prep_scale_for_group_broadcast(scale, x, self.group_shape)
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division
......
......@@ -191,6 +191,51 @@ def group_broadcast(t, shape):
return t
def prep_scale_for_group_broadcast(
scale: torch.Tensor,
x: torch.Tensor,
group_shape: GroupShape | None,
) -> torch.Tensor:
"""
Prepare the input quantization scale for group broadcasting.
Args:
scale: The scale tensor (scalar or 1D).
x: Target tensor whose shape determines broadcast dimensions.
group_shape: GroupShape to broadcast over.
Returns:
scale reshaped for correct broadcasting.
"""
if scale.numel() == 1:
# For per-tensor quant, keep the scale as a scalar (not reshaped to (1, 1)).
# This avoids misclassifying it as channelwise quant in Fp8LinearOp.apply,
# where the "per_tensor_activations" check relies on "x_scale.dim() < 2":
# per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# For all other cases, reshape scalar scales to (1, 1) for broadcasting.
return (
scale
if group_shape is not None and group_shape.is_per_tensor()
else scale.reshape(1, 1)
)
if scale.ndim == 1:
assert group_shape is not None, (
"group_shape must be provided to correctly broadcast 1D scale"
)
rows, cols = _normalize_quant_group_shape(x, group_shape)
# Determine broadcasting dimension: either rows or columns match group size
if rows == x.shape[-2]:
scale = scale.unsqueeze(-2)
elif cols == x.shape[-1]:
scale = scale.unsqueeze(-1)
else:
raise ValueError(
f"1D scale with shape {scale.shape} cannot be broadcast to x with shape"
f" {x.shape}, group_shape={(rows, cols)}"
)
return scale
# Quantize assuming once scale per group of elements with shape group_shape,
# example group shapes:
# * (-1, -1) for per-tensor quantization
......@@ -241,7 +286,7 @@ def scaled_quantize(
_, fp8_max = get_fp8_min_max()
scale = fp8_max / amax
# Apply scale and convert form:
# Apply scale and convert from:
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
x_scl_sat = (
(x_blkd_permd * scale.unsqueeze(-1))
......@@ -261,29 +306,7 @@ def scaled_dequantize(
group_shape: GroupShape | None = None,
out_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
if x_s.numel() == 1: # scalar
x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1)
if x_s.ndim == 1:
if group_shape is None:
raise AssertionError(
"if x_s is 1D tensor, group_shape must be provided otherwise "
"its ambiguous which dimension to broadcast x_s to"
)
# unsqueeze the scales for the dimension where we want to broadcast
# across the full extent
if group_shape[0] == x_q.shape[-2]:
x_s = x_s.unsqueeze(-2)
elif group_shape[1] == x_q.shape[-1]:
x_s = x_s.unsqueeze(-1)
else:
raise AssertionError(
"if x_s is a vector we should be broadcasting it to the full "
"extent of one of the dimensions"
)
x_s = prep_scale_for_group_broadcast(x_s, x_q, group_shape)
if group_shape is not None:
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
......
......@@ -246,6 +246,23 @@ def get_quant_config(
# compressed-tensors uses a compressions_config
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
# Pipe information about heads to enable TP-aware loading of attn_head scales
if (
hf_quant_config is not None
and hf_quant_config.get("quant_method") == "compressed-tensors"
):
if hf_text_config is not None:
n_heads = getattr(hf_text_config, "num_attention_heads", None)
n_kv_heads = getattr(hf_text_config, "num_key_value_heads", None)
else:
n_heads = getattr(model_config.hf_config, "num_attention_heads", None)
n_kv_heads = getattr(model_config.hf_config, "num_key_value_heads", None)
hf_quant_config["total_num_heads"] = n_heads
hf_quant_config["total_num_kv_heads"] = (
n_kv_heads if n_kv_heads is not None else n_heads
)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
......@@ -1157,11 +1174,21 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
# .mixer.attn.{k,v}_scale
(r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"),
# Default format: .{k,v}_scale -> .attn.{k,v}_scale
(r"\.([kv])_scale$", r".attn.\1_scale"),
(r"\.([qkv])_scale$", r".attn.\1_scale"),
(r"\.([qkv])_zero_point$", r".attn.\1_zero_point"),
]
# Check if name ends with k_scale or v_scale
if name.endswith((".k_scale", ".v_scale")):
if name.endswith(
(
".k_scale",
".v_scale",
".q_scale",
".k_zero_point",
".v_zero_point",
".q_zero_point",
)
):
import regex as re
for pattern, replacement in scale_mapping_patterns:
......
......@@ -437,7 +437,7 @@ class ApertusModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
......
......@@ -303,7 +303,7 @@ class ArceeModel(nn.Module):
loaded_params.add(scale_name)
continue
if "scale" in name:
if "scale" in name or "zero_point" in name:
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is None:
continue
......
......@@ -465,8 +465,8 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale or zero point.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
......
......@@ -140,8 +140,8 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name FP8 kv-scale
if "scale" in name:
# Remapping the name FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
......
......@@ -238,8 +238,8 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remapping the name FP8 kv-scale
if "scale" in name:
# Remapping the name FP8 kv-scale or zero point.
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
......
......@@ -661,7 +661,7 @@ class NemotronHModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "scale" in name:
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
......
......@@ -342,7 +342,7 @@ class DeciModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
if "scale" in name or "zero_point" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
......
......@@ -620,6 +620,7 @@ class AttentionImpl(ABC, Generic[T]):
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
supports_per_head_quant_scales: bool = False
dcp_world_size: int
dcp_rank: int
......
......@@ -576,6 +576,11 @@ class FlashAttentionImpl(AttentionImpl):
)
self.supports_quant_query_input = True
self.supports_per_head_quant_scales = (
self.vllm_flash_attn_version >= 3
if self.vllm_flash_attn_version is not None
else False
)
def forward(
self,
......@@ -691,6 +696,10 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
if self.dcp_world_size > 1:
self._forward_with_dcp(
query[:num_actual_tokens],
......@@ -700,9 +709,9 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
output[:num_actual_tokens],
attn_metadata,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
return output
else:
......@@ -728,9 +737,9 @@ class FlashAttentionImpl(AttentionImpl):
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment