Unverified Commit 2695ab05 authored by Yun Dai's avatar Yun Dai Committed by GitHub
Browse files

Fix loading KV quantization scale; Enable modelopt kv cache (#4686)


Co-authored-by: default avatarqingquansong <ustcsqq@gmail.com>
parent 88d6fd9a
......@@ -239,7 +239,7 @@ class ModelConfig:
# check if is modelopt model -- modelopt doesn't have corresponding field
# in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
is_local = os.path.isdir(self.model_path)
is_local = os.path.exists(self.model_path)
modelopt_quant_config = {"quant_method": "modelopt"}
if not is_local:
from huggingface_hub import HfApi
......
......@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
self.use_mla = (
model_runner.model_config.attention_arch == AttentionArch.MLA
......@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
if layer.sliding_window_size is not None
else (-1, -1)
)
k_descale, v_descale = None, None
if self.kv_cache_dtype_str != "auto":
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
causal = not layer.is_cross_attention
# Check if we should use local attention
......@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
# Do absorbed multi-latent attention
......@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
......@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
)
causal = not layer.is_cross_attention
k_descale, v_descale = None, None
if self.kv_cache_dtype_str != "auto":
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
if not self.use_mla:
# Do multi-head attention
......@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
# Do absorbed multi-latent attention
......@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
......
......@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
assert not (
model_runner.sliding_window_size is not None
......@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
......@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
o = prefill_wrapper_paged.forward(
......@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=logits_soft_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
k_scale=k_scale,
v_scale=v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
......@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
......@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
......@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
o = decode_wrapper.forward(
......@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
k_scale=k_scale,
v_scale=v_scale,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
......
......@@ -8,6 +8,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import is_hip
_is_hip = is_hip()
......@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
class BaseKVCacheMethod(QuantizeMethodBase):
"""
Quant method that adds `_k_scale` and `_v_scale` attributes to the
Quant method that adds `k_scale` and `v_scale` attributes to the
Attention layer to support loading those scaling factors from checkpoints.
The k/v_scale will be used to:
- quantize k/v_cache entries before saving them to the cache
......@@ -36,8 +37,12 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# Initialize the KV cache scales to -1.0, which is an invalid value.
# If the k/v_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
layer.k_scale = torch.nn.Parameter(
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
)
layer.v_scale = torch.nn.Parameter(
torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
)
@classmethod
def is_fp8_fnuz(cls) -> bool:
......@@ -47,52 +52,38 @@ class BaseKVCacheMethod(QuantizeMethodBase):
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = 1.0
v_scale = 1.0
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
raise ValueError(
"Only support per-tensor scaling factor " "for fp8 KV cache"
)
# These are used in the final Attention.forward()
layer._k_scale.copy_(k_scale)
layer._v_scale.copy_(v_scale)
layer._k_scale_float = k_scale
layer._v_scale_float = v_scale
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
logger.warning(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint."
)
del layer.k_scale
del layer.v_scale
def process_weights_after_loading(self, layer: RadixAttention) -> None:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale = 1.0
v_scale = 1.0
else:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert layer.k_scale > 0.0
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if _is_hip and self.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
if not isinstance(k_scale, float) or not isinstance(v_scale, float):
raise ValueError(
"Only support per-tensor scaling factor " "for fp8 KV cache"
)
# These are used in the final Attention.forward()
layer.k_scale.copy_(k_scale)
layer.v_scale.copy_(v_scale)
layer.k_scale_float = k_scale
layer.v_scale_float = v_scale
......@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
......@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.utils import (
convert_to_channelwise,
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
# Initialize logger for the module
logger = logging.getLogger(__name__)
......@@ -33,12 +33,19 @@ ACTIVATION_SCHEMES = ["static"]
class ModelOptFp8Config(QuantizationConfig):
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
kv_cache_quant_method: Optional[str] = None,
exclude_modules: Optional[List[str]] = None,
) -> None:
"""
Args:
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
"""
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
......@@ -63,6 +70,12 @@ class ModelOptFp8Config(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
"kv_cache_quant_algo"
)
exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
"exclude_modules"
)
if "FP8" not in quant_method:
raise ValueError(
......@@ -70,15 +83,23 @@ class ModelOptFp8Config(QuantizationConfig):
"Check the `hf_quant_config.json` file for your model's configuration."
)
return cls(is_checkpoint_fp8_serialized=True)
return cls(
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if self.exclude_modules and any(
module in prefix for module in self.exclude_modules
):
return None
if isinstance(layer, LinearBase):
return ModelOptFp8LinearMethod(self)
if isinstance(layer, AttentionBackend):
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
return None
......
......@@ -13,8 +13,12 @@
# ==============================================================================
"""Radix attention."""
from typing import Optional
from torch import nn
from sglang.srt.layers.linear import UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -34,6 +38,7 @@ class RadixAttention(nn.Module):
v_head_dim: int = -1,
sliding_window_size: int = -1,
is_cross_attention: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_irope: bool = False,
):
......@@ -49,9 +54,16 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention
self.use_irope = use_irope
self.k_scale = None
self.v_scale = None
self.use_irope = use_irope
self.k_scale_float = None
self.v_scale_float = None
self.quant_method = None
if quant_config is not None:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
if self.quant_method is not None:
self.quant_method.create_weights(self)
def forward(
self,
......
......@@ -178,6 +178,7 @@ class BaiChuanAttention(nn.Module):
scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
else:
......@@ -194,6 +195,7 @@ class BaiChuanAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -113,6 +113,7 @@ class GLMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -204,6 +204,7 @@ class CohereAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
if self.use_qk_norm:
......
......@@ -249,6 +249,7 @@ class DbrxAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -255,6 +255,7 @@ class DeepseekAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -489,6 +489,7 @@ class DeepseekV2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......@@ -669,6 +670,7 @@ class DeepseekV2AttentionMLA(nn.Module):
num_kv_heads=1,
layer_id=layer_id,
v_head_dim=self.kv_lora_rank,
quant_config=quant_config,
prefix=add_prefix("attn_mqa", prefix),
)
......@@ -679,6 +681,7 @@ class DeepseekV2AttentionMLA(nn.Module):
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
v_head_dim=self.v_head_dim,
quant_config=quant_config,
prefix=add_prefix("attn_mha", prefix),
)
......
......@@ -155,6 +155,7 @@ class ExaoneAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
)
def forward(
......
......@@ -137,6 +137,7 @@ class GemmaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -163,6 +163,7 @@ class Gemma2Attention(nn.Module):
if use_sliding_window
else None
),
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -193,6 +193,7 @@ class Gemma3Attention(nn.Module):
# Module must also define `get_attention_sliding_window_size` to correctly initialize
# attention backend in `ForwardBatch`.
sliding_window_size=self.sliding_window,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -78,6 +78,7 @@ class GPT2Attention(nn.Module):
scaling=self.scale,
num_kv_heads=total_num_heads,
layer_id=layer_id,
quant_config=quant_config,
)
def forward(
......
......@@ -87,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
scaling=self.scale,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -158,6 +158,7 @@ class GraniteAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
......
......@@ -215,6 +215,7 @@ class Grok1Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=logit_cap,
quant_config=quant_config,
)
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment