Unverified Commit d918ab79 authored by Haohui Mai, committed by GitHub

Support NVFP4 quantized dense models on AMD CDNA2/CDNA3 GPUs (#7302)
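With `petit_kernel` installed (it is now pulled in by the `srt_hip` extra), a dense model quantized to NVFP4 with NVIDIA ModelOpt can be served on gfx90a/gfx942 by passing `--quantization petit_nvfp4` to `sglang.launch_server`. A minimal offline sketch, assuming the `sglang.Engine` API forwards `quantization` to `ServerArgs` and using a placeholder checkpoint path:

import sglang as sgl

# Placeholder path: any ModelOpt NVFP4 dense checkpoint that ships an
# hf_quant_config.json should be detected and routed to petit_nvfp4.
llm = sgl.Engine(
    model_path="/path/to/modelopt-nvfp4-checkpoint",
    quantization="petit_nvfp4",
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
llm.shutdown()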


Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Sai Enduri <saimanas.enduri@amd.com>
parent 3964b352
@@ -79,6 +79,7 @@ blackwell = [
 srt_hip = [
     "sglang[runtime_common]",
     "torch",
+    "petit_kernel",
 ]
 # xpu is not enabled in public vllm and torch whl,
...
@@ -391,6 +391,7 @@ class ModelConfig:
     "compressed-tensors",
     "fbgemm_fp8",
     "w8a8_fp8",
+    "petit_nvfp4",
 ]
 optimized_quantization_methods = [
     "fp8",
@@ -408,9 +409,11 @@ class ModelConfig:
     "moe_wna16",
     "qoq",
     "w4afp8",
+    "petit_nvfp4",
 ]
 compatible_quantization_methods = {
     "modelopt_fp4": ["modelopt"],
+    "petit_nvfp4": ["modelopt"],
     "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
     "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
 }
...
@@ -53,6 +53,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "ModelOptFp8LinearMethod",
     "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
+    "PetitNvFp4LinearMethod",
 ]
 _is_cpu = is_cpu()
...
@@ -58,6 +58,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
@@ -76,6 +77,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
+    "petit_nvfp4": PetitNvFp4Config,
 }
 # VLLM-dependent quantization methods
...
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
import logging
from typing import Any, Callable, Dict, List, Optional
import regex as re
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.petit_utils import (
apply_petit_nvfp4_linear,
prepare_nvfp4_layer_for_petit,
verify_petit_nvfp4_supported,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
# Initialize logger for the module
logger = logging.getLogger(__name__)
# Configuration class for NVFP4-quantized models produced by the ModelOpt quantization tool
class PetitNvFp4Config(QuantizationConfig):
"""Config class for Petit FP4."""
def __init__(
self,
is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        group_size: Optional[int] = None,
        exclude_modules: Optional[List[str]] = None,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
"Detected nvfp4 checkpoint. Please note that the "
"format is experimental and subject to change."
)
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def get_name(cls) -> str:
return "petit_nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
# Petit supports the gfx90a and gfx942 GPUs
return 90
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "PetitNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
group_size = quant_config.get("group_size", None)
verify_petit_nvfp4_supported(quant_method, group_size)
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
if not kv_cache_quant_algo:
kv_cache_quant_algo = "auto"
exclude_modules = quant_config.get("exclude_modules", None)
if not (group_size and kv_cache_quant_algo and (exclude_modules is not None)):
            logger.warning(
                f"group_size: {group_size}, "
                f"kv_cache_quant_algo: {kv_cache_quant_algo}, "
                f"exclude_modules: {exclude_modules}"
            )
            raise ValueError(
                "NVFP4 quantization requires group_size, kv_cache_quant_algo, "
                "and exclude_modules to be specified in hf_quant_config.json"
            )
return cls(
is_checkpoint_nvfp4_serialized,
kv_cache_quant_algo,
group_size,
exclude_modules,
)
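    # For reference, an illustrative "quantization" block from a ModelOpt
    # hf_quant_config.json that this parser accepts (field values are examples;
    # real checkpoints may carry additional keys):
    #
    #   "quantization": {
    #       "quant_algo": "NVFP4",
    #       "kv_cache_quant_algo": "auto",
    #       "group_size": 16,
    #       "exclude_modules": ["lm_head"]
    #   }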
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_petit_nvfp4_compatible(hf_quant_cfg)
if can_convert:
return cls.get_name()
return None
@classmethod
def is_petit_nvfp4_compatible(cls, quant_config: Dict[str, Any]) -> bool:
quant_method = quant_config.get("quant_method", "").lower()
return quant_method == "modelopt"
    def is_layer_excluded(self, prefix: str, exclude_modules: List[str]) -> bool:
for pattern in exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
if re.fullmatch(regex_str, prefix):
return True
return False
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
prefix, self.exclude_modules
):
return UnquantizedLinearMethod()
return PetitNvFp4LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class PetitNvFp4LinearMethod(LinearMethodBase):
    """Linear method for NVFP4.

    Supports loading NVFP4 checkpoints with the following structure:

    | Tensor Name    | datatype      | shape     |
    |----------------|---------------|-----------|
    | input_scale    | torch.float32 | scalar    |
    | weight         | NVFP4(SE2M1)  | [X, Y/2]  |
    | weight_scale   | FP8-E4M3      | [X, Y/16] |
    | weight_scale_2 | torch.float32 | scalar    |

    The weights are quantized per block of 16 elements.

    Args:
        quant_config: The ModelOpt quantization config.
    """
def __init__(self, quant_config: PetitNvFp4Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_nvfp4_serialized:
            raise ValueError(
                "NVFP4 quantization was selected, but dynamic quantization "
                "is not supported."
            )
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
if input_size_per_partition % 16 != 0:
            raise ValueError(
                "Unsupported model: input feature size must be a multiple of 16."
            )
weight_dtype = (
torch.float8_e4m3fn
if self.quant_config.is_checkpoint_nvfp4_serialized
else params_dtype
)
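        # NVFP4 checkpoints carry three kinds of scales:
        #   input_scale    - per-tensor activation scale (float32)
        #   weight_scale   - per-group (group_size elements) weight scales (FP8-E4M3)
        #   weight_scale_2 - per-tensor global weight scale (float32)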
weight = ModelWeightParameter(
data=torch.empty(
                # Two FP4 values are packed into one uint8 along the input dimension.
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
input_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale_2", weight_scale_2)
weight_scale = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.group_size,
dtype=weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
input_scale_2 = layer.input_scale.max().to(torch.float32)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
layer.alpha = Parameter(
layer.input_scale * layer.weight_scale_2, requires_grad=False
)
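        # Repack the FP4 weights and block scales into the petit kernel layout;
        # input_scale has been folded into alpha above, so the raw tensor can be
        # dropped afterwards.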
prepare_nvfp4_layer_for_petit(layer)
del layer.input_scale
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_petit_nvfp4_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
from typing import Optional
import torch
try:
from petit_kernel import mul_nvfp4_a16, process_nvfp4_scales, repack_nvfp4
except ImportError:
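    # petit_kernel is an optional dependency; when it is missing, expose stub
    # implementations that surface the install hint instead of a NameError.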
def _check_petit_nvfp4_supported(
quant_method: str, group_size: Optional[int]
) -> tuple[bool, Optional[str]]:
return (
False,
"Petit is not installed. Please install it with `pip install petit-kernel`.",
)
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
raise ValueError(
"Petit is not installed. Please install it with `pip install petit-kernel`."
)
def apply_petit_nvfp4_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise ValueError(
"Petit is not installed. Please install it with `pip install petit-kernel`."
)
else:

    def _check_petit_nvfp4_supported(
        quant_method: str, group_size: Optional[int]
    ) -> tuple[bool, Optional[str]]:
        if quant_method != "NVFP4":
            return (
                False,
                "Petit currently only supports NVFP4 quantization in sglang. "
                "Please check the `hf_quant_config.json` file for your model's "
                "quant configuration.",
            )
        if group_size is not None and group_size != 16:
            return (
                False,
                "Petit currently only supports group_size=16 quantization.",
            )
        return (True, None)

    def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
        # Repack the packed FP4 weights into the petit kernel's layout.
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        qweight = layer.weight.view(torch.int32).contiguous()
        petit_qweight = repack_nvfp4(qweight, size_n=part_size_n, size_k=part_size_k)
        layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)

        # Permute the block scales into the layout expected by the kernel.
        weight_scale = process_nvfp4_scales(
            scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n
        )
        layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)

    def apply_petit_nvfp4_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_scale_2: torch.Tensor,
        size_n: int,
        size_k: int,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = input.reshape(-1, input.shape[-1])
        out_shape = input.shape[:-1] + (size_n,)
        # TODO: use auto-tuning to find a performant solution_id.
        output = mul_nvfp4_a16(
            a=reshaped_x,
            b=weight,
            s=weight_scale,
            global_scale=weight_scale_2,
            size_m=reshaped_x.size(0),
            size_n=size_n,
            size_k=size_k,
            solution_id=-1,
        )
        if bias is not None:
            output.add_(bias)  # In-place add
        return output.reshape(out_shape)


def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None:
    supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size)
    if not supported:
        raise ValueError(error_msg)
@@ -766,6 +766,7 @@ class ServerArgs:
     "gguf",
     "modelopt",
     "modelopt_fp4",
+    "petit_nvfp4",
     "w8a8_int8",
     "w8a8_fp8",
     "moe_wna16",
...