Unverified Commit c66b2c9c authored by Zhiyu, committed by GitHub

Add support for nvidia modelopt fp8 kv cache (#3223)

parent 20b765a2
@@ -5,12 +5,14 @@ from typing import Any, Dict, List, Optional

 import torch
 from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )

+from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -70,7 +72,13 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+        if isinstance(layer, AttentionBackend):
+            return ModelOptFp8KVCacheMethod(self)
+        return None

     def get_scaled_act_names(self) -> List[str]:
         return []
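
Not part of the diff, but a quick illustration of the new dispatch. The `Mock(spec=...)` objects are stand-ins so the `isinstance()` checks pass without constructing real engine layers, and the config is built the same way as in the unit test added at the end of this commit:

```python
from unittest.mock import Mock

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp8Config,
    ModelOptFp8KVCacheMethod,
    ModelOptFp8LinearMethod,
)

config = ModelOptFp8Config(is_checkpoint_fp8_serialized=True)

# Linear layers keep getting the FP8 linear method, as before.
assert isinstance(
    config.get_quant_method(Mock(spec=LinearBase), prefix=""),
    ModelOptFp8LinearMethod,
)

# New in this commit: attention backends get the KV-cache method.
assert isinstance(
    config.get_quant_method(Mock(spec=AttentionBackend), prefix=""),
    ModelOptFp8KVCacheMethod,
)

# Any other module type still gets no quant method.
assert config.get_quant_method(torch.nn.ReLU(), prefix="") is None
```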
@@ -171,3 +179,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
             bias=bias,
             cutlass_fp8_supported=self.cutlass_fp8_supported,
         )
+
+
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__(quant_config)
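
The new subclass is deliberately a thin wrapper: the actual work is inherited from vLLM's `BaseKVCacheMethod`. As a rough mental model only (an illustrative paraphrase, not vLLM's implementation), a KV-cache quant method registers per-attention-layer k/v scale placeholders that the checkpoint loader can fill, then finalizes them after loading:

```python
import torch

# Illustrative sketch only; the real logic lives in vLLM's BaseKVCacheMethod.
class SketchKVCacheMethod:
    def create_weights(self, layer: torch.nn.Module) -> None:
        # Placeholders the checkpoint loader can overwrite (e.g. with the
        # remapped ...self_attn.attn.k_scale / .v_scale tensors below).
        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Hand the attention kernel plain floats; fall back to 1.0 when the
        # checkpoint carried no FP8 kv-cache scales.
        k, v = layer.k_scale.item(), layer.v_scale.item()
        layer._k_scale = k if k > 0 else 1.0
        layer._v_scale = v if v > 0 else 1.0
```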
@@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name

     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            # Check and remap the name based on modelopt scale names
+            if any(
+                modelopt_scale_name in name
+                for modelopt_scale_name in modelopt_scale_names
+            ):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}",
+                )
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                 print_warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
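
To make the new branch concrete, here is how the remapping behaves on hypothetical Llama-style tensor names, assuming the attention module is registered as `self_attn.attn` (as in sglang's Llama model). The `params_dict` values are dummies, since only key membership matters in the code shown above:

```python
from sglang.srt.model_loader.weight_utils import maybe_remap_kv_scale_name

# Dummy params_dict: only the keys are consulted by the remapping check.
params_dict = {
    "model.layers.0.self_attn.attn.k_scale": None,
    "model.layers.0.self_attn.attn.v_scale": None,
}

# modelopt checkpoints attach the scales to the k/v projections...
assert (
    maybe_remap_kv_scale_name("model.layers.0.self_attn.k_proj.k_scale", params_dict)
    == "model.layers.0.self_attn.attn.k_scale"
)

# ...while the pre-existing branch still handles scales stored directly on self_attn.
assert (
    maybe_remap_kv_scale_name("model.layers.0.self_attn.v_scale", params_dict)
    == "model.layers.0.self_attn.attn.v_scale"
)
```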
@@ -47,6 +47,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
 )
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
@@ -457,6 +458,11 @@ class LlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
import unittest

from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod

from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp8Config,
    ModelOptFp8KVCacheMethod,
)


class TestModelOptFp8KVCacheMethod(unittest.TestCase):
    def test_kv_cache_method_initialization(self):
        """Test that ModelOptFp8KVCacheMethod can be instantiated and
        inherits from BaseKVCacheMethod."""
        # Create a ModelOptFp8Config object
        quant_config = ModelOptFp8Config(is_checkpoint_fp8_serialized=True)

        # Instantiate the KV cache method
        kv_cache_method = ModelOptFp8KVCacheMethod(quant_config)

        # Check inheritance
        self.assertIsInstance(kv_cache_method, BaseKVCacheMethod)

        # Check that the quant_config is stored
        self.assertEqual(kv_cache_method.quant_config, quant_config)


if __name__ == "__main__":
    unittest.main()