Unverified Commit d4bf5a85 authored by kk, committed by GitHub

Support OCP MXFP4 quantization on AMD GPUs (#8255)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
parent 7cb20754
......@@ -401,6 +401,8 @@ class ModelConfig:
"fbgemm_fp8",
"w8a8_fp8",
"petit_nvfp4",
"quark",
"mxfp4",
]
optimized_quantization_methods = [
"fp8",
......
......@@ -47,6 +47,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
)
from sglang.srt.utils import mxfp_supported
is_mxfp_supported = mxfp_supported()
if is_mxfp_supported:
from sglang.srt.layers.quantization.fp4 import MxFp4Config
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import (
GPTQConfig,
......@@ -84,7 +90,13 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"w4afp8": W4AFp8Config,
"petit_nvfp4": PetitNvFp4Config,
}
if is_mxfp_supported:
BASE_QUANTIZATION_METHODS.update(
{
"quark": MxFp4Config,
"mxfp4": MxFp4Config,
}
)
# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
......
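The registry only exposes the Quark/MXFP4 config when the platform reports MX support, so "quark" and "mxfp4" are simply absent from BASE_QUANTIZATION_METHODS elsewhere. A minimal self-contained sketch of that pattern (the stub classes and always-False probe below are illustrative stand-ins, not the real SGLang types):

```python
from typing import Dict, Type


class QuantizationConfig:  # stand-in for the SGLang base config class
    pass


class MxFp4Config(QuantizationConfig):  # stand-in for the real MxFp4Config
    pass


def mxfp_supported() -> bool:
    # The real check inspects torch.version.hip and the gfx95* arch name.
    return False


BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {"fp8": QuantizationConfig}
if mxfp_supported():
    BASE_QUANTIZATION_METHODS.update({"quark": MxFp4Config, "mxfp4": MxFp4Config})

print("mxfp4" in BASE_QUANTIZATION_METHODS)  # False on platforms without MX support
```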
# SPDX-License-Identifier: Apache-2.0
from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
__all__ = ["QuarkScheme", "QuarkW4A4MXFP4"]
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Optional
import torch
__all__ = ["QuarkScheme"]
class QuarkScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by Quark.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
include the layer, output partition sizes, input size per partition,
the parameter dtype, and the weight loader.
"""
raise NotImplementedError
@abstractmethod
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError
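To make the contract concrete, here is a minimal hypothetical scheme built on the QuarkScheme interface above. It is an illustration only (an unquantized pass-through), not a scheme shipped with SGLang:

```python
from typing import Optional

import torch
import torch.nn.functional as F

from sglang.srt.layers.quantization.quark.schemes import QuarkScheme


class PassthroughScheme(QuarkScheme):
    """Toy scheme: stores a plain weight and applies F.linear (no quantization)."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def create_weights(self, layer, output_partition_sizes, input_size_per_partition,
                       params_dtype, weight_loader, **kwargs):
        weight = torch.nn.Parameter(
            torch.empty(sum(output_partition_sizes), input_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        return  # nothing to repack for this toy scheme

    def apply_weights(self, layer, x, bias: Optional[torch.Tensor] = None):
        return F.linear(x, layer.weight, bias)
```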
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Optional
import aiter
import torch
import torch.nn.functional as F
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from aiter.utility import dtypes
from aiter.utility.fp4_utils import e8m0_shuffle
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.quark.schemes import QuarkScheme
from sglang.srt.utils import get_bool_env_var
__all__ = ["QuarkW4A4MXFP4"]
OCP_MX_BLOCK_SIZE = 32
class QuarkW4A4MXFP4(QuarkScheme):
def __init__(
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
@classmethod
def get_min_capability(cls) -> int:
return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
return
# Alternative aiter implementation (kept for reference):
# wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
# w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
# layer.weight = torch.nn.Parameter(wshuffle,
# requires_grad=False)
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
# requires_grad=False)
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=2,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out_dtype = x.dtype
# M = x.shape[0]
# N = layer.weight.shape[0]
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
# x, x_scales_shuffle = quant_func(x, shuffle=True)
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
# return out[:M]
# Triton implementation (the commented-out block above is the alternative aiter gemm_a4w4 path)
x_q, x_s = dynamic_mxfp4_quant(x)
y = torch.empty(
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
)
out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
return out
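For concreteness, `create_weights` packs two FP4 values per uint8 byte and keeps one E8M0 scale (also stored as uint8) per 32-element block of the input dimension. A pure-PyTorch shape check under assumed partition sizes (the 4096/8192 sizes are made up for illustration; no aiter needed):

```python
import torch

OCP_MX_BLOCK_SIZE = 32  # one shared-exponent scale per 32 elements

# Assumed partition sizes, for illustration only.
input_size_per_partition = 4096
output_partition_sizes = [8192]
output_size_per_partition = sum(output_partition_sizes)

# Two FP4 values are packed into each uint8, so the input dim is halved.
weight = torch.empty(
    output_size_per_partition, input_size_per_partition // 2, dtype=torch.uint8
)
# One E8M0 scale (stored as uint8) per 32-wide block of the input dimension.
weight_scale = torch.empty(
    output_size_per_partition,
    input_size_per_partition // OCP_MX_BLOCK_SIZE,
    dtype=torch.uint8,
)

assert weight.shape == (8192, 2048)
assert weight_scale.shape == (8192, 128)
```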
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Optional
def deep_compare(dict1: Any, dict2: Any) -> bool:
if type(dict1) is not type(dict2):
return False
if isinstance(dict1, dict):
if dict1.keys() != dict2.keys():
return False
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
elif isinstance(dict1, list):
return set(dict1) == set(dict2)
else:
return dict1 == dict2
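With `deep_compare` in scope, quantization-config fragments compare structurally, and list order is deliberately ignored. A quick usage sketch (the config fragments are hypothetical):

```python
cfg_a = {"dtype": "fp4", "group_size": 32, "targets": ["q_proj", "k_proj"]}
cfg_b = {"dtype": "fp4", "group_size": 32, "targets": ["k_proj", "q_proj"]}
cfg_c = {"dtype": "fp8", "group_size": 32, "targets": ["q_proj", "k_proj"]}

assert deep_compare(cfg_a, cfg_b)      # dicts recursed into, list order ignored
assert not deep_compare(cfg_a, cfg_c)  # any differing leaf value -> False
```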
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
if layer_name is None:
return False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name = layer_name.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in fused_mapping:
shard_proj_names = fused_mapping[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer = None
for shard_name in shard_names:
should_ignore_shard = check_equal_or_regex_match(
layer_name=shard_name, targets=ignore
)
# If shard_idx=0, set layer ignore to match shard.
if should_ignore_layer is None:
should_ignore_layer = should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif should_ignore_shard != should_ignore_layer:
raise ValueError(
f"Found a different quantization schemes for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme."
)
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else:
should_ignore_layer = check_equal_or_regex_match(
layer_name=layer_name, targets=ignore
)
assert should_ignore_layer is not None
return should_ignore_layer
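A usage sketch for `should_ignore_layer`, assuming the helpers above are in scope; the fused mapping and ignore list are hypothetical. Every shard of a fused module must be ignored consistently before the fused name itself counts as ignored:

```python
from types import MappingProxyType

fused_mapping = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
ignore = [
    "re:.*lm_head",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# All three shards of the fused layer are ignored, so the fused name is too.
assert should_ignore_layer("model.layers.0.self_attn.qkv_proj", ignore, fused_mapping)
# An unfused layer is matched directly against the ignore patterns.
assert not should_ignore_layer("model.layers.0.mlp.down_proj", ignore, fused_mapping)
```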
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
"""
Checks whether layer_name exactly equals any target in the list, or
matches it as a regex when the target starts with 're:'.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def _is_equal_or_regex_match(
value: str, target: str, check_contains: bool = False
) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False
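And the matching rule itself: a target prefixed with 're:' is treated as a regular expression anchored at the start of the name (via `re.match`); anything else must match exactly. A quick check, assuming `check_equal_or_regex_match` is in scope:

```python
assert check_equal_or_regex_match("model.lm_head", ["re:.*lm_head"])
assert check_equal_or_regex_match(
    "model.layers.3.mlp.gate_proj", ["re:model\\.layers\\.\\d+\\.mlp\\..*"]
)
# A bare target must be an exact match, not a substring.
assert not check_equal_or_regex_match("model.layers.3.mlp.gate_proj", ["gate_proj"])
```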
......@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
return None
return remapped_name
quark_scale_names = {
".q_proj.output_scale": ".attn.q_scale",
".k_proj.output_scale": ".attn.k_scale",
".v_proj.output_scale": ".attn.v_scale",
"self_attn.prob_output_scale": ".attn.prob_scale",
}
for quark_scale_name, sglang_scale_name in quark_scale_names.items():
if name.endswith(quark_scale_name):
return name.replace(quark_scale_name, sglang_scale_name)
# If there were no matches, return the untouched param name
return name
......
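The effect of the new remapping: Quark checkpoints store attention output scales under `*.output_scale`, which the loader now rewrites onto SGLang's attention scale names before lookup. A standalone illustration of the string rewrite (the weight name is made up):

```python
quark_scale_names = {
    ".q_proj.output_scale": ".attn.q_scale",
    ".k_proj.output_scale": ".attn.k_scale",
    ".v_proj.output_scale": ".attn.v_scale",
    "self_attn.prob_output_scale": ".attn.prob_scale",
}

name = "model.layers.0.self_attn.k_proj.output_scale"
for quark_scale_name, sglang_scale_name in quark_scale_names.items():
    if name.endswith(quark_scale_name):
        name = name.replace(quark_scale_name, sglang_scale_name)
        break
print(name)  # model.layers.0.self_attn.attn.k_scale
```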
......@@ -2061,6 +2061,8 @@ class DeepseekV2Model(nn.Module):
class DeepseekV2ForCausalLM(nn.Module):
# Populated for Quark checkpoint loading (see __init__).
packed_modules_mapping = {}
def __init__(
self,
......@@ -2069,6 +2071,18 @@ class DeepseekV2ForCausalLM(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
# For Quark checkpoint loading: fuse q_a_proj and kv_a_proj_with_mqa along the
# output dimension when q_lora_rank is not None.
self.fuse_qkv_a_proj = (
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
)
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
......
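The populated `packed_modules_mapping` lets quantization-side helpers expand the fused projection back into its unfused shard names (the same expansion `should_ignore_layer` performs). A standalone sketch with an assumed layer name:

```python
packed_modules_mapping = {
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
}

layer_name = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"
proj_name = layer_name.split(".")[-1]
shard_names = [
    layer_name.replace(proj_name, shard) for shard in packed_modules_mapping[proj_name]
]
print(shard_names)
# ['model.layers.0.self_attn.q_a_proj', 'model.layers.0.self_attn.kv_a_proj_with_mqa']
```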
......@@ -813,6 +813,7 @@ class ServerArgs:
"moe_wna16",
"qoq",
"w4afp8",
"mxfp4",
],
help="The quantization method.",
)
......
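In practice this just adds `mxfp4` to the accepted values of the existing `--quantization` flag, so launching the SGLang server with `--quantization mxfp4` selects the new path on supported hardware. A minimal stand-in parser showing only the relevant choices (the real choice list is longer and is built by ServerArgs):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantization",
    choices=["fp8", "moe_wna16", "qoq", "w4afp8", "mxfp4"],  # abbreviated list
    help="The quantization method.",
)
args = parser.parse_args(["--quantization", "mxfp4"])
print(args.quantization)  # mxfp4
```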
......@@ -2832,6 +2832,17 @@ def parse_module_path(module_path, function_name, create_dummy):
return final_module, None
def mxfp_supported():
"""
Returns whether the current platform supports MX types.
"""
if torch.version.hip:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return any(gfx in gcn_arch for gfx in ["gfx95"])
else:
return False
# LoRA-related constants and utilities
SUPPORTED_LORA_TARGET_MODULES = [
"q_proj",
......
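A caller-side sketch of how the new helper is meant to be used (the import path matches the registry change above; the fallback choice here is purely illustrative):

```python
from sglang.srt.utils import mxfp_supported

# Only request MXFP4 when the GPU actually exposes MX types (ROCm + gfx95*);
# otherwise fall back to no quantization.
quantization = "mxfp4" if mxfp_supported() else None
print(f"Selected quantization: {quantization}")
```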