Commit 944a8aab authored by zhuwenwen's avatar zhuwenwen
Browse files

feat: w8a8_marlin 接入,通过-q slimquant_marlin开启,优化w4a8_marlin代码

parent 1f526c04
...@@ -892,8 +892,8 @@ class ModelConfig: ...@@ -892,8 +892,8 @@ class ModelConfig:
optimized_quantization_methods = [ optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "slimquant_w4a8",
"slimquant_w4a8","slimquant_w4a8_marlin" "slimquant_w4a8_marlin", "slimquant_compressed_tensors_marlin"
] ]
if self.quantization is not None: if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods, self.quantization = cast(me_quant.QuantizationMethods,
...@@ -920,7 +920,8 @@ class ModelConfig: ...@@ -920,7 +920,8 @@ class ModelConfig:
"awq_marlin", "awq_marlin",
"ipex", "ipex",
"moe_wna16", "moe_wna16",
"slimquant_w4a8_marlin" "slimquant_w4a8_marlin",
"slimquant_compressed_tensors_marlin"
] ]
quantization_methods = [ quantization_methods = [
q for q in supported_quantization if q not in overrides q for q in supported_quantization if q not in overrides
......
...@@ -38,7 +38,8 @@ QuantizationMethods = Literal[ ...@@ -38,7 +38,8 @@ QuantizationMethods = Literal[
"rtn", "rtn",
"blockwise_int8", "blockwise_int8",
"slimquant_w4a8", "slimquant_w4a8",
"slimquant_w4a8_marlin" "slimquant_w4a8_marlin",
"slimquant_compressed_tensors_marlin",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -97,6 +98,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -97,6 +98,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .bitsandbytes import BitsAndBytesConfig from .bitsandbytes import BitsAndBytesConfig
from .compressed_tensors.compressed_tensors import ( # noqa: E501 from .compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig) CompressedTensorsConfig)
from .compressed_tensors.compressed_tensors_marlin import (
SlimQuantCompressedTensorsMarlinConfig)
from .deepspeedfp import DeepSpeedFPConfig from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config from .fbgemm_fp8 import FBGEMMFp8Config
...@@ -154,6 +157,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -154,6 +157,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"slimquant_w4a8":SlimQuantW4A8Int8Config, "slimquant_w4a8":SlimQuantW4A8Int8Config,
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig, "slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
"slimquant_compressed_tensors_marlin":SlimQuantCompressedTensorsMarlinConfig,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig, CompressedTensorsLinearMethod, CompressedTensorsKVCacheMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin import (
CompressedTensorsMarlinMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
import os
from vllm import _custom_ops as ops
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
def __init__(
self,
target_scheme_map: dict[str, Any],
ignore: list[str],
quant_format: str,
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
):
super().__init__(
target_scheme_map,
ignore,
quant_format,
sparsity_scheme_map,
sparsity_ignore_list,
kv_cache_scheme,
config
)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if hf_quant_cfg.get("quant_method") == "compressed-tensors" \
and user_quant == "slimquant_marlin":
return cls.get_name()
return None
@classmethod
def get_name(cls) -> QuantizationMethods:
return "slimquant_compressed_tensors_marlin"
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
return None
\ No newline at end of file
...@@ -90,8 +90,6 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -90,8 +90,6 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config) return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config) return CompressedTensorsW8A8Int8MoEMethod(quant_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
else: else:
raise RuntimeError( raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
from enum import Enum
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import (QuantizationStrategy)
from vllm.logger import init_logger
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights)
try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
__all__ = [
"CompressedTensorsW8A8Int8MarlinMoEMethod",
]
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
) -> "CompressedTensorsMarlinMoEMethod":
# are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations")
if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
else:
raise RuntimeError(
f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsMarlinConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.int8
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
hidden_size,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias)
return fused_experts_impl_int8_marlin(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
\ No newline at end of file
...@@ -110,7 +110,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): ...@@ -110,7 +110,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \ if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
and user_quant == "slimquant_w4a8_marlin": and user_quant in ("slimquant_w4a8_marlin", "slimquant_marlin"):
return cls.get_name() return cls.get_name()
return None return None
def get_quant_method( def get_quant_method(
...@@ -347,7 +347,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -347,7 +347,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.") "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
......
...@@ -25,6 +25,21 @@ USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() ...@@ -25,6 +25,21 @@ USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
and torch.__version__[0:3] >= "2.7" and torch.__version__[0:3] >= "2.7"
and current_platform.has_device_capability(94)) and current_platform.has_device_capability(94))
def get_w8a8_int8_marlin_weights(
weight,
k_tile=64):
# 7168, 512
weight = weight.T
size_k, size_n = weight.shape
assert size_k // k_tile
weight = weight.reshape(size_k // k_tile, k_tile, size_n)
weight = weight.transpose(1, 2)
weight = weight.reshape(size_k // k_tile, size_n * k_tile)
return weight
def sparse_cutlass_supported() -> bool: def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
return False return False
......
...@@ -185,7 +185,8 @@ class RocmPlatform(Platform): ...@@ -185,7 +185,8 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","slimquant_w4a8","awq_marlin","slimquant_w4a8_marlin" "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","slimquant_w4a8","awq_marlin",
"slimquant_w4a8_marlin","slimquant_compressed_tensors_marlin"
] ]
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment