Unverified Commit ec8ab9d2 authored by Douglas Lehr's avatar Douglas Lehr Committed by GitHub
Browse files

[ROCm] Add dynamic mxfp4 quantization for DeepSeek V2 projection layers (#34157)


Signed-off-by: default avatarDoug Lehr <douglehr@amd.com>
Signed-off-by: default avatarDouglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Co-authored-by: default avatarDoug Lehr <douglehr@amd.com>
Co-authored-by: default avatarRohan Potdar <66227218+Rohan138@users.noreply.github.com>
Co-authored-by: default avatarGregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
parent 05972ea7
...@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import ( ...@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
) )
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
...@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig): ...@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
self.kv_cache_group = kv_cache_group self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self.pack_method = pack_method self.pack_method = pack_method
self.dynamic_mxfp4_quant = False
def maybe_update_config(self, model_name: str, revision: str | None = None):
self.hf_config = get_config(
model=model_name,
trust_remote_code=False, # or get from model_config if available
revision=revision,
config_format="auto",
)
quant_config = getattr(self.hf_config, "quantization_config", None)
if quant_config is not None:
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
model_type = self.hf_config.model_type
if quant_dtype == "fp4" and model_type == "deepseek_v3":
self.dynamic_mxfp4_quant = True
def get_linear_method(self) -> "QuarkLinearMethod": def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self) return QuarkLinearMethod(self)
...@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig): ...@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
if should_ignore_layer( if should_ignore_layer(
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
): ):
return UnquantizedLinearMethod() if (
"self_attn" not in prefix # only quantize attention projections
or not getattr(self, "dynamic_mxfp4_quant", False)
or not isinstance(layer, LinearBase) # Ignore other methods
):
return UnquantizedLinearMethod()
scheme = self.get_scheme(
layer=layer,
layer_name=prefix,
dynamic_mxfp4_quant=True,
)
layer.scheme = scheme
return QuarkLinearMethod(self)
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix) scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme layer.scheme = scheme
...@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig): ...@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
) )
return global_quant_config return global_quant_config
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": def _get_scheme_from_config(
self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"): if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError( raise NotImplementedError(
"Currently, Quark models with output_tensors " "Currently, Quark models with output_tensors "
...@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig): ...@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
input_symmetric=input_config.get("symmetric"), input_symmetric=input_config.get("symmetric"),
) )
elif self._is_w_ocp_mx_a_x(weight_config, input_config): elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config) return QuarkOCP_MX(
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
raise NotImplementedError( raise NotImplementedError(
"No quark compatible scheme was found. " "No quark compatible scheme was found. "
...@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig): ...@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
f"Input config: {input_config}" f"Input config: {input_config}"
) )
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": def get_scheme(
self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer) layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme # Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config) scheme = self._get_scheme_from_config(
layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
# Raise error if device does not support the scheme # Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace) # (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability()) self._check_scheme_supported(scheme.get_min_capability())
......
...@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( ...@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE, OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme, OCP_MX_Scheme,
) )
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PackedvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme from .quark_scheme import QuarkScheme
...@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError): ...@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
class QuarkOCP_MX(QuarkScheme): class QuarkOCP_MX(QuarkScheme):
def __init__( def __init__(
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any] self,
weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any],
dynamic_mxfp4_quant: bool = False,
): ):
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group" self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec self.input_quant_spec = input_quant_spec
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp") self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp") self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
...@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme): ...@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
layer.weight_scale.data, requires_grad=False layer.weight_scale.data, requires_grad=False
) )
else: else:
if self.rocm_use_aiter_fp4_asm_gemm: if self.dynamic_mxfp4_quant:
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
layer.weight_scale = torch.nn.Parameter(
w_s.T.contiguous(), requires_grad=False
)
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
elif self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale # shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data weight_scale_shuffle = layer.weight_scale.data
sm, sn = weight_scale_shuffle.shape sm, sn = weight_scale_shuffle.shape
...@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme): ...@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
weight_loader: Callable, weight_loader: Callable,
**kwargs, **kwargs,
): ):
output_size_per_partition = sum(output_partition_sizes) if self.dynamic_mxfp4_quant:
layer.logical_widths = output_partition_sizes weight = ModelWeightParameter(
data=torch.empty(
# WEIGHT sum(output_partition_sizes),
weight = PackedvLLMParameter( input_size_per_partition,
data=torch.empty( dtype=params_dtype,
output_size_per_partition, ),
self.get_packed_dim(input_size_per_partition, self.weight_dtype), input_dim=1,
dtype=torch.uint8, output_dim=0,
), weight_loader=weight_loader,
input_dim=1, )
output_dim=0,
packed_dim=1, layer.register_parameter("weight", weight)
packed_factor=self.packed_factor, set_weight_attrs(weight, kwargs)
weight_loader=weight_loader, else:
) output_size_per_partition = sum(output_partition_sizes)
layer.register_parameter("weight", weight) layer.logical_widths = output_partition_sizes
# WEIGHT SCALE # WEIGHT
weight_scale = GroupQuantScaleParameter( weight = PackedvLLMParameter(
data=torch.empty( data=torch.empty(
output_size_per_partition, output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE, self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8, dtype=torch.uint8,
), ),
input_dim=1, input_dim=1,
output_dim=0, output_dim=0,
weight_loader=weight_loader, packed_dim=1,
) packed_factor=self.packed_factor,
layer.register_parameter("weight_scale", weight_scale) weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply_weights( def apply_weights(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment