Commit 25796d05 authored by maxiao1

Adapt w8a8_marlin

parent ac7dcc2d
@@ -616,6 +616,7 @@ class ModelConfig:
             "mxfp4",
             "slimquant_w4a8_marlin",
             "w8a8_int8",
+            "slimquant_marlin",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -636,6 +637,7 @@ class ModelConfig:
             "w4afp8",
             "petit_nvfp4",
             "slimquant_w4a8_marlin",
+            "slimquant_marlin",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
......
@@ -58,6 +58,7 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

 _is_mxfp_supported = mxfp_supported()
@@ -84,7 +85,8 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
     "fbgemm_fp8": FBGEMMFp8Config,
-    "slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
+    "slimquant_w4a8_marlin": SlimQuantW4A8Int8MarlinConfig,
+    "slimquant_marlin": SlimQuantCompressedTensorsMarlinConfig,
 }
......
@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

 logger = logging.getLogger(__name__)
@@ -636,3 +637,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
         return scheme.apply_weights(layer, x, bias=bias)
+
+
+class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from compressed-tensors
+    checkpoints.
+    """
+
+    def __init__(self, quant_config: CompressedTensorsConfig):
+        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
+        super().__init__(quant_config)
+
+    @staticmethod
+    def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
+        """
+        Validator for the kv-cache scheme. Useful for restricting the
+        kv-cache quantization schemes that are supported.
+
+        :param kv_cache_scheme: the compressed-tensors kv-cache scheme
+        """
+        if kv_cache_scheme is None:
+            return
+
+        type_ = kv_cache_scheme.get("type")
+        num_bits = kv_cache_scheme.get("num_bits")
+
+        if type_ != "float" and num_bits != 8:
+            raise NotImplementedError(
+                "Currently supported kv cache quantization is "
+                "num_bits=8, type=float, however "
+                f"received num_bits={num_bits}, type={type_}")
+
+        strategy = kv_cache_scheme.get("strategy")
+        if strategy != "tensor":
+            raise NotImplementedError(
+                "Only support per-tensor scaling factor "
+                "for compressed-tensors KV cache. "
+                f"Expected strategy: tensor, found strategy: {strategy}")
+
+        is_symmetric = kv_cache_scheme.get("symmetric")
+        if not is_symmetric:
+            raise NotImplementedError(
+                "Only support symmetric scaling factor "
+                "for compressed-tensors KV cache. "
+                f"However found symmetric: {is_symmetric}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs

from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import (
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
    CompressedTensorsKVCacheMethod,
    CompressedTensorsLinearMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin import (
    CompressedTensorsMarlinMoEMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod

# if TYPE_CHECKING:
#     from vllm.model_executor.models.utils import WeightsMapper

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]


class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):

    def __init__(
        self,
        target_scheme_map: dict[str, Any],
        ignore: list[str],
        quant_format: str,
        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: list[str],
        kv_cache_scheme: Optional[dict[str, Any]] = None,
        config: Optional[dict[str, Any]] = None,
        packed_modules_mapping: Optional[dict[str, list[str]]] = None,
    ):
        super().__init__(
            target_scheme_map,
            ignore,
            quant_format,
            sparsity_scheme_map,
            sparsity_ignore_list,
            kv_cache_scheme,
            config,
            packed_modules_mapping,
        )

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[str]:
        if hf_quant_cfg.get("quant_method") == "compressed-tensors" \
                and user_quant == "slimquant_marlin":
            return cls.get_name()
        return None

    @classmethod
    def get_name(cls) -> str:
        return "slimquant_marlin"

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        # Avoid circular import.
        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

        # from sglang.srt.layers.radix_attention import RadixAttention

        # Check if the layer is skipped for quantization.
        if should_ignore_layer(prefix,
                               ignore=self.ignore,
                               fused_mapping=self.packed_modules_mapping):
            return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
        # if isinstance(layer, RadixAttention):
        #     return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
        return None
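A quick sketch of how the override hook above is expected to fire. The config dict is illustrative; in practice hf_quant_cfg comes from the checkpoint's quantization_config and user_quant from the user-selected quantization method.

    hf_quant_cfg = {"quant_method": "compressed-tensors"}
    name = SlimQuantCompressedTensorsMarlinConfig.override_quantization_method(
        hf_quant_cfg, user_quant="slimquant_marlin"
    )
    assert name == "slimquant_marlin"  # any other combination returns None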
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import enum
import logging
from enum import Enum
from typing import Callable, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn.parameter import Parameter

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import set_weight_attrs

try:
    from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
    print("INFO: Please install lmslim if you want to run quantized MoE models.\n")

logger = logging.getLogger(__name__)

__all__ = [
    "CompressedTensorsW8A8Int8MarlinMoEMethod",
]


def get_w8a8_int8_marlin_weights(
        weight,
        k_tile=64):
    # Repack one expert's int8 weight, stored as [size_n, size_k]
    # (e.g. (7168, 512)), into k_tile-sized groups along K for the
    # int8 marlin fused-MoE kernel.
    weight = weight.T
    size_k, size_n = weight.shape
    assert size_k % k_tile == 0
    weight = weight.reshape(size_k // k_tile, k_tile, size_n)
    weight = weight.transpose(1, 2)
    weight = weight.reshape(size_k // k_tile, size_n * k_tile)
    return weight
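
# Usage sketch for the repacking helper above. Shapes are illustrative only,
# following the (7168, 512) example in the comment:
#
#     w = torch.randint(-128, 127, (7168, 512), dtype=torch.int8)  # one expert
#     packed = get_w8a8_int8_marlin_weights(w, k_tile=64)
#     # After w.T the shape is (size_k=512, size_n=7168); rows are grouped into
#     # 512 // 64 = 8 k-tiles, so packed.shape == (512 // 64, 7168 * 64).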
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):

    @staticmethod
    def get_moe_method(
        quant_config: "SlimQuantCompressedTensorsMarlinConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module,
    ) -> "CompressedTensorsMarlinMoEMethod":
        # Check which quantization scheme the checkpoint uses and pick the
        # matching fused-MoE method.
        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
        input_quant = quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
        else:
            raise RuntimeError(
                "slimquant_marlin does not support the FusedMoE scheme: "
                f"{weight_quant}, {input_quant}")


class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):

    def __init__(
        self,
        quant_config: "SlimQuantCompressedTensorsMarlinConfig"  # type: ignore # noqa E501
    ):
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
        if not per_channel:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        params_dtype = torch.int8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts,
            hidden_size,
            1,
            dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Repack each expert's loaded int8 weights into the marlin k-tiled layout.
        w1_marlin_list = []
        for ii in range(layer.w13_weight.shape[0]):
            w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
            w1_marlin_list.append(w1_marlin_in)
        w1_marlin = torch.stack(w1_marlin_list, dim=0)

        w2_marlin_list = []
        for ii in range(layer.w2_weight.shape[0]):
            w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
            w2_marlin_list.append(w2_marlin_in)
        w2_marlin = torch.stack(w2_marlin_list, dim=0)

        layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
        layer.w2_weight = Parameter(w2_marlin, requires_grad=False)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config
    # def apply(
    #     self,
    #     layer: torch.nn.Module,
    #     x: torch.Tensor,
    #     router_logits: torch.Tensor,
    #     top_k: int,
    #     renormalize: bool,
    #     use_grouped_topk: bool = False,
    #     topk_group: Optional[int] = None,
    #     num_expert_group: Optional[int] = None,
    #     global_num_experts: int = -1,
    #     expert_map: Optional[torch.Tensor] = None,
    #     custom_routing_function: Optional[Callable] = None,
    #     scoring_func: str = "softmax",
    #     e_score_correction_bias: Optional[torch.Tensor] = None,
    #     apply_router_weight_on_input: bool = False,
    #     activation: str = "silu",
    #     enable_eplb: bool = False,
    #     use_nn_moe: Optional[bool] = False,
    #     routed_scaling_factor: Optional[float] = None,
    #     use_fused_gate: Optional[bool] = False,
    #     expert_load_view: Optional[torch.Tensor] = None,
    #     logical_to_physical_map: Optional[torch.Tensor] = None,
    #     logical_replica_count: Optional[torch.Tensor] = None,
    #     shared_output: Optional[torch.Tensor] = None,
    # ) -> torch.Tensor:
    #     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    #
    #     if enable_eplb:
    #         raise NotImplementedError(
    #             "EPLB not supported for "
    #             "`CompressedTensorsW8A8Int8MoEMethod` yet.")
    #
    #     topk_weights, topk_ids = FusedMoE.select_experts(
    #         hidden_states=x,
    #         router_logits=router_logits,
    #         use_grouped_topk=use_grouped_topk,
    #         top_k=top_k,
    #         renormalize=renormalize,
    #         topk_group=topk_group,
    #         num_expert_group=num_expert_group,
    #         custom_routing_function=custom_routing_function,
    #         scoring_func=scoring_func,
    #         routed_scaling_factor=routed_scaling_factor,
    #         use_fused_gate=use_fused_gate,
    #         e_score_correction_bias=e_score_correction_bias)
    #
    #     return fused_experts_impl_int8_marlin(
    #         hidden_states=x,
    #         w1=layer.w13_weight,
    #         w2=layer.w2_weight,
    #         topk_weights=topk_weights,
    #         topk_ids=topk_ids,
    #         inplace=True,
    #         activation=activation,
    #         apply_router_weight_on_input=apply_router_weight_on_input,
    #         use_int8_w8a8=True,
    #         per_channel_quant=True,
    #         global_num_experts=global_num_experts,
    #         expert_map=expert_map,
    #         w1_scale=layer.w13_weight_scale,
    #         w2_scale=layer.w2_weight_scale,
    #         a1_scale=layer.w13_input_scale,
    #         a2_scale=layer.w2_input_scale,
    #         use_nn_moe=False,
    #         shared_output=shared_output,
    #         routed_scaling_factor=routed_scaling_factor)
    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output,
    ):
        from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output

        x, topk_weights = apply_topk_weights_cpu(
            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
        )

        output = fused_experts_impl_int8_marlin(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=layer.moe_runner_config.activation,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=layer.moe_runner_config.num_experts,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=False,
        )
        return StandardCombineInput(hidden_states=output)
@@ -15,10 +15,11 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
 from sglang.srt.utils import is_cuda
+from lmslim import quant_ops

 _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
@@ -168,6 +169,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)

-        return int8_scaled_mm(
+        return quant_ops.triton_scaled_mm(
             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
         )
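For context on the swapped imports above, a minimal pure-PyTorch sketch of what a symmetric per-token int8 quantizer computes. This mirrors the usual contract assumed for per_token_quant_int8 (int8 values plus one float32 scale per row); it is a reference sketch, not the lmslim implementation.

    import torch

    def per_token_quant_int8_ref(x: torch.Tensor):
        # One scale per token (row), symmetric around zero.
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
        return x_q, scale.to(torch.float32)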
@@ -95,7 +95,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[str]:
         if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
-                and user_quant == "slimquant_w4a8_marlin":
+                and user_quant in ("slimquant_w4a8_marlin", "slimquant_marlin"):
             return cls.get_name()
         return None

     def get_quant_method(
......
@@ -94,6 +94,7 @@ QUANTIZATION_CHOICES = [
     "mxfp4",
     "compressed-tensors",  # for Ktransformers
     "slimquant_w4a8_marlin",
+    "slimquant_marlin",
 ]

 ATTENTION_BACKEND_CHOICES = [
......
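With the new choice registered above, the method can be selected at launch time. A rough sketch, assuming sglang's standard launcher flags; the checkpoint path is a placeholder for a compressed-tensors INT8 (w8a8) model.

    import subprocess

    subprocess.run([
        "python", "-m", "sglang.launch_server",
        "--model-path", "/path/to/compressed-tensors-w8a8-model",  # placeholder
        "--quantization", "slimquant_marlin",
    ])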