Commit ce363e89 authored by yiqa

Merge remote-tracking branch 'origin/v0.5.4_dev_yiqa' into v0.5.4_dev_yiqa

# Conflicts:
#	python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
parents 20241efa a34b0d3d
......@@ -616,6 +616,7 @@ class ModelConfig:
"mxfp4",
"slimquant_w4a8_marlin",
"w8a8_int8",
"slimquant_marlin",
]
optimized_quantization_methods = [
"fp8",
......@@ -636,6 +637,7 @@ class ModelConfig:
"w4afp8",
"petit_nvfp4",
"slimquant_w4a8_marlin",
"slimquant_marlin",
]
compatible_quantization_methods = {
"modelopt_fp4": ["modelopt"],
......
......@@ -167,6 +167,14 @@ class Envs:
# DCU Lightop
SGLANG_USE_LIGHTOP = EnvBool(False)
# Fused
SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD = EnvBool(False)
SGLANG_USE_OPT_CAT = EnvBool(False)
SGLANG_USE_FUSED_RMS_QUANT = EnvBool(False)
SGLANG_USE_FUSED_SILU_MUL_QUANT = EnvBool(False)
# Quantization
SGLANG_INT4_WEIGHT = EnvBool(False)
SGLANG_CPU_QUANTIZATION = EnvBool(False)
......
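For reference, a minimal sketch (not part of this commit) of how these opt-in flags are consumed: each defaults to False and the patched modules read them at import time through get_bool_env_var, so the variables must be exported before SGLang modules are imported (or before the server is launched).
# Illustrative only: enable the new fused paths before launch, e.g.
#   export SGLANG_USE_OPT_CAT=1
#   export SGLANG_USE_FUSED_RMS_QUANT=1
#   export SGLANG_USE_FUSED_SILU_MUL_QUANT=1
import os

os.environ.setdefault("SGLANG_USE_FUSED_SILU_MUL_QUANT", "1")  # must happen before sglang imports

from sglang.srt.utils import get_bool_env_var

assert get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")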
from __future__ import annotations
import warnings
import torch
from sglang.srt.utils import get_bool_env_var
_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")
if _USE_OPT_CAT:
try:
from lightop import ds_cat # type: ignore
except ImportError: # pragma: no cover
ds_cat = None
warnings.warn(
"SGLANG_USE_OPT_CAT 已开启但无法导入 lightop.ds_cat,退回 torch.cat"
)
else:
ds_cat = None
def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int) -> torch.Tensor:
    """Concatenate two 3-D decode tensors along dim=2 via lightop.ds_cat."""
    assert A.dim() == 3 and dim == 2, "tensors must be 3-D and the concat dim must be 2"
    if ds_cat is None:
        # lightop is unavailable: fall back to torch.cat, as the warning above promises.
        return torch.cat([A, B], dim=dim)
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    mode = 0  # mode flag expected by ds_cat (kept from the original implementation)
    ds_cat(A, B, C, mode)
    return C
\ No newline at end of file
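A hedged usage sketch for the new helper (shapes are illustrative): decode-time q tensors are 3-D and concatenated along the last dimension, and when lightop is unavailable the module falls back to torch.cat, so both paths are expected to agree.
import torch

from sglang.srt.layers.attention.lightop_concat import concat_decode_opt

# Illustrative decode-time shapes: (num_tokens, num_heads, head_dim).
q_nope = torch.randn(8, 16, 512)
q_pe = torch.randn(8, 16, 64)

q = concat_decode_opt(q_nope, q_pe, dim=2)
q_ref = torch.cat([q_nope, q_pe], dim=-1)

assert q.shape == (8, 16, 576)
assert torch.equal(q, q_ref)  # exact on the torch.cat fallback; ds_cat is assumed to match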
......@@ -44,6 +44,18 @@ _is_hip = is_hip()
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
"SGLANG_ROCM_DISABLE_LINEARQUANT"
)
_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
if _use_fused_rms_quant:
try:
from lmslim.quantize.quant_ops import lm_faster_rmsquant
except Exception as e:
print(f"Error: Import fused rmsquant error: {e}")
if _use_fused_silu_mul_quant:
try:
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
except Exception as e:
print(f"Error: Import fused silu_mul_quant error: {e}")
logger = logging.getLogger(__name__)
......@@ -1358,7 +1370,7 @@ class RowParallelLinear(LinearBase):
# It does not support additional parameters.
param.load_row_parallel_weight(loaded_weight)
def forward(self, input_, skip_all_reduce=False):
def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
if self.input_is_parallel:
input_parallel = input_
else:
......@@ -1372,9 +1384,19 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args
)
sm.tag(output_parallel)
else:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel)
......
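To clarify what the new use_fused_silu_mul_quant branch consumes, here is an unfused reference: a sketch assuming lm_fuse_silu_mul_quant fuses the SwiGLU activation (silu(gate) * up) with per-token symmetric INT8 quantization; the exact kernel semantics live in lmslim. The caller then hands the resulting (xq, xs) pair to quant_method.apply via silu_quant_args instead of quantizing inside the GEMM path.
import torch
import torch.nn.functional as F

def silu_mul_quant_reference(gate_up: torch.Tensor):
    # Assumed equivalent of lm_fuse_silu_mul_quant: split the gate_up projection,
    # apply SwiGLU (silu(gate) * up), then quantize per token (per row) to int8.
    gate, up = gate_up.chunk(2, dim=-1)
    y = F.silu(gate) * up
    scale = y.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 127.0
    y_q = torch.clamp(torch.round(y / scale), -128, 127).to(torch.int8)
    return y_q, scale  # analogous to the (xq, xs) pair used as silu_quant_args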
......@@ -42,6 +42,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.environ import envs
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import (
cpu_has_amx_support,
......@@ -58,6 +59,7 @@ if is_flashinfer_available():
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_user_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
# Try to import FP4 TRTLLM function if flashinfer is available
......@@ -221,6 +223,7 @@ class FusedMoE(torch.nn.Module):
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results
self.use_presharded_weights = use_presharded_weights
# self.global_num_experts = self.num_experts
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
......@@ -877,9 +880,21 @@ class FusedMoE(torch.nn.Module):
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
)
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput = None, shared_output: torch.Tensor = None, **kwargs):
origin_hidden_states_dim = hidden_states.shape[-1]
assert self.quant_method is not None
if _user_lightop_moe_sum_mul_add:
final_hidden_states = self.quant_method.apply_with_shared_output(
layer=self,
x=hidden_states,
                activation=(self.moe_runner_config.activation if getattr(self, "moe_runner_config", None) is not None else "silu"),
shared_output=shared_output,
topk_output=topk_output,
)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
dispatch_output = self.dispatcher.dispatch(
hidden_states=hidden_states, topk_output=topk_output
......
......@@ -58,6 +58,7 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
_is_mxfp_supported = mxfp_supported()
......@@ -84,7 +85,8 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"w4afp8": W4AFp8Config,
"petit_nvfp4": PetitNvFp4Config,
"fbgemm_fp8": FBGEMMFp8Config,
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
"slimquant_w4a8_marlin": SlimQuantW4A8Int8MarlinConfig,
"slimquant_marlin": SlimQuantCompressedTensorsMarlinConfig,
}
......
......@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
logger = logging.getLogger(__name__)
......@@ -636,3 +637,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from compressed-tensors
checkpoints.
"""
def __init__(self, quant_config: CompressedTensorsConfig):
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
"""
Validator for the kv cache scheme. Useful for controlling which
kv cache quantization schemes are supported.
:param kv_cache_scheme: the compressed-tensors kv cache scheme
"""
if kv_cache_scheme is None:
return
type_ = kv_cache_scheme.get("type")
num_bits = kv_cache_scheme.get("num_bits")
if type_ != "float" and num_bits != 8:
raise NotImplementedError(
"Currently supported kv cache quantization is "
"num_bits=8, type=float, however "
f"received num_bits={num_bits}, type={type_}")
strategy = kv_cache_scheme.get("strategy")
if strategy != "tensor":
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f"Expected strategy: tensor, found strategy: {strategy}")
is_symmetric = kv_cache_scheme.get("symmetric")
if not is_symmetric:
raise NotImplementedError(
"Only support symmetric scaling factor "
"for compressed-tensors KV cache. "
f"However found symmetric: {is_symmetric}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs
import logging
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig, CompressedTensorsLinearMethod, CompressedTensorsKVCacheMethod
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin import CompressedTensorsMarlinMoEMethod
from sglang.srt.layers.quantization.compressed_tensors.utils import (
should_ignore_layer)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
import os
# if TYPE_CHECKING:
# from vllm.model_executor.models.utils import WeightsMapper
logger = logging.getLogger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
def __init__(
self,
target_scheme_map: dict[str, Any],
ignore: list[str],
quant_format: str,
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
packed_modules_mapping: Optional[dict[str, list[str]]] = None,
):
super().__init__(
target_scheme_map,
ignore,
quant_format,
sparsity_scheme_map,
sparsity_ignore_list,
kv_cache_scheme,
config,
packed_modules_mapping,
)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[str]:
if hf_quant_cfg.get("quant_method") == "compressed-tensors" \
and user_quant == "slimquant_marlin":
return cls.get_name()
return None
@classmethod
def get_name(cls) -> str:
return "slimquant_marlin"
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE # Avoid circular import
# from sglang.srt.layers.radix_attention import RadixAttention
# Check if the layer is skipped for quantization.
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
            return UnquantizedEmbeddingMethod()  # rather than UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
                return UnquantizedEmbeddingMethod()  # rather than UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
# if isinstance(layer, RadixAttention):
# return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
return None
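A small sketch of how the override hook is exercised (the checkpoint metadata is illustrative): a compressed-tensors checkpoint combined with the user-selected slimquant_marlin method is rerouted to this config class, while any other user choice leaves the override untouched.
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import (
    SlimQuantCompressedTensorsMarlinConfig,
)

# Illustrative checkpoint metadata; real checkpoints carry a full compressed-tensors config.
hf_quant_cfg = {"quant_method": "compressed-tensors"}

assert SlimQuantCompressedTensorsMarlinConfig.override_quantization_method(
    hf_quant_cfg, user_quant="slimquant_marlin"
) == "slimquant_marlin"
assert SlimQuantCompressedTensorsMarlinConfig.override_quantization_method(
    hf_quant_cfg, user_quant="fp8"
) is None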
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import enum
from enum import Enum
from typing import Callable, Optional
import torch
from compressed_tensors.quantization import (QuantizationStrategy)
import logging
from torch.nn.parameter import Parameter
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import set_weight_attrs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = logging.getLogger(__name__)
__all__ = [
"CompressedTensorsW8A8Int8MarlinMoEMethod",
]
def get_w8a8_int8_marlin_weights(weight, k_tile=64):
    # Repack an (out_features, in_features) int8 expert weight (e.g. 7168 x 512
    # for the down projection) into the Marlin k-tile layout: transpose to (K, N),
    # split K into k_tile-sized row groups, and flatten each group, giving a
    # (K // k_tile, N * k_tile) matrix.
    weight = weight.T
    size_k, size_n = weight.shape
    assert size_k % k_tile == 0, f"size_k={size_k} must be divisible by k_tile={k_tile}"
    weight = weight.reshape(size_k // k_tile, k_tile, size_n)
    weight = weight.transpose(1, 2)
    weight = weight.reshape(size_k // k_tile, size_n * k_tile)
    return weight
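A quick worked example of the repacking above, using the helper just defined (sizes match the 7168 x 512 comment, values are random): the transform is a pure transpose/reshape, so only the layout changes and the element count is preserved.
import torch

# Illustrative down-projection expert weight: (out_features, in_features) = (7168, 512).
w = torch.randint(-128, 128, (7168, 512), dtype=torch.int8)
w_marlin = get_w8a8_int8_marlin_weights(w, k_tile=64)

# After transposing to (K=512, N=7168), K is split into 512 / 64 = 8 row tiles,
# each flattened into 7168 * 64 columns.
assert w_marlin.shape == (512 // 64, 7168 * 64)
assert w_marlin.numel() == w.numel()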
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501
layer: torch.nn.Module,
) -> "CompressedTensorsMarlinMoEMethod":
        # Check which quantization schemes are supported and whether the layer is ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations")
if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
else:
raise RuntimeError(
f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsMarlinConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
params_dtype = torch.int8
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
hidden_size,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
# def apply(
# self,
# layer: torch.nn.Module,
# x: torch.Tensor,
# router_logits: torch.Tensor,
# top_k: int,
# renormalize: bool,
# use_grouped_topk: bool = False,
# topk_group: Optional[int] = None,
# num_expert_group: Optional[int] = None,
# global_num_experts: int = -1,
# expert_map: Optional[torch.Tensor] = None,
# custom_routing_function: Optional[Callable] = None,
# scoring_func: str = "softmax",
# e_score_correction_bias: Optional[torch.Tensor] = None,
# apply_router_weight_on_input: bool = False,
# activation: str = "silu",
# enable_eplb: bool = False,
# use_nn_moe: Optional[bool] = False,
# routed_scaling_factor: Optional[float] = None,
# use_fused_gate: Optional[bool] = False,
# expert_load_view: Optional[torch.Tensor] = None,
# logical_to_physical_map: Optional[torch.Tensor] = None,
# logical_replica_count: Optional[torch.Tensor] = None,
# shared_output: Optional[torch.Tensor] = None,
# ) -> torch.Tensor:
# from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
# if enable_eplb:
# raise NotImplementedError(
# "EPLB not supported for "
# "`CompressedTensorsW8A8Int8MoEMethod` yet.")
# topk_weights, topk_ids = FusedMoE.select_experts(
# hidden_states=x,
# router_logits=router_logits,
# use_grouped_topk=use_grouped_topk,
# top_k=top_k,
# renormalize=renormalize,
# topk_group=topk_group,
# num_expert_group=num_expert_group,
# custom_routing_function=custom_routing_function,
# scoring_func=scoring_func,
# routed_scaling_factor=routed_scaling_factor,
# use_fused_gate=use_fused_gate,
# e_score_correction_bias=e_score_correction_bias)
# return fused_experts_impl_int8_marlin(
# hidden_states=x,
# w1=layer.w13_weight,
# w2=layer.w2_weight,
# topk_weights=topk_weights,
# topk_ids=topk_ids,
# inplace=True,
# activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
# use_int8_w8a8=True,
# per_channel_quant=True,
# global_num_experts=global_num_experts,
# expert_map=expert_map,
# w1_scale=layer.w13_weight_scale,
# w2_scale=layer.w2_weight_scale,
# a1_scale=layer.w13_input_scale,
# a2_scale=layer.w2_input_scale,
# use_nn_moe=False,
# shared_output=shared_output,
# routed_scaling_factor=routed_scaling_factor)
def apply(
self,
layer: torch.nn.Module,
dispatch_output,
    ):
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
output = fused_experts_impl_int8_marlin(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.moe_runner_config.activation,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=layer.moe_runner_config.num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
)
return StandardCombineInput(hidden_states=output)
\ No newline at end of file
......@@ -15,10 +15,11 @@ from sglang.srt.layers.parameter import (
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
from lmslim import quant_ops
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import int8_scaled_mm
......@@ -168,6 +169,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
# TODO: add cutlass_scaled_mm_azp support
x_q, x_scale = per_token_quant_int8(x)
return int8_scaled_mm(
return quant_ops.triton_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
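For context on what the swapped-in quant_ops.triton_scaled_mm is expected to compute, an unfused reference in plain torch follows. It is a sketch under the assumptions that activations carry per-token scales of shape (M, 1), the weight is a (K, N) int8 matrix with per-output-channel scales, and the epilogue simply rescales the int32 accumulator, matching the int8_scaled_mm call it replaces; accumulation is done in float64 only to keep the reference exact.
from typing import Optional

import torch

def scaled_mm_reference(
    x_q: torch.Tensor,        # (M, K) int8, per-token quantized activations
    w_q: torch.Tensor,        # (K, N) int8 weight
    x_scale: torch.Tensor,    # (M, 1) per-token activation scales
    w_scale: torch.Tensor,    # (N,) or (N, 1) per-output-channel weight scales
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    acc = x_q.to(torch.float64) @ w_q.to(torch.float64)  # exact integer accumulation
    out = acc * x_scale.to(torch.float64) * w_scale.to(torch.float64).reshape(1, -1)
    if bias is not None:
        out = out + bias.to(torch.float64)
    return out.to(out_dtype)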
......@@ -19,6 +19,9 @@ from vllm.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os
from sglang.srt.utils import get_bool_env_var
_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
......@@ -163,13 +166,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
):
# if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
# assert len(input_quant_args) == 2
# x_q, x_scale = input_quant_args
# elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
# x_q, x_scale = silu_quant_args
# else:
x_q, x_scale = per_token_quant_int8(x)
if _use_fused_rms_quant and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
elif _use_fused_silu_mul_quant and silu_quant_args is not None:
x_q, x_scale = silu_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1:
m=x_q.shape[0]
......
......@@ -95,7 +95,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[str]:
if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
and user_quant == "slimquant_w4a8_marlin":
and user_quant in ("slimquant_w4a8_marlin", "slimquant_marlin"):
return cls.get_name()
return None
def get_quant_method(
......@@ -213,7 +213,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
@torch._dynamo.disable()
def apply(
self,
......@@ -252,6 +252,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe=False,
)
return StandardCombineInput(hidden_states=output)
def apply_with_shared_output(
self,
layer: torch.nn.Module,
x: torch.Tensor,
activation: str = "silu",
shared_output: Optional[torch.Tensor] = None,
topk_output=None,
) -> torch.Tensor:
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
global_num_experts=layer.moe_runner_config.num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
)
# def _apply(
# self,
# layer: torch.nn.Module,
......@@ -273,7 +306,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# use_nn_moe: Optional[bool] = False,
# routed_scaling_factor: Optional[float] = None,
# use_fused_gate: Optional[bool] = False,
# **_
# **_
# ) -> torch.Tensor:
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -317,11 +350,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# a2_scale=layer.w2_input_scale,
# use_nn_moe=use_nn_moe,
# )
#
def apply_ep(self,
def apply_ep(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -361,8 +391,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
......
......@@ -141,6 +141,7 @@ from sglang.srt.utils import (
make_layers,
use_intel_amx_backend,
)
from sglang.srt.layers.attention.lightop_concat import concat_decode_opt
_is_hip = is_hip()
_is_cuda = is_cuda()
......@@ -151,8 +152,10 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
_is_gfx95_supported = is_gfx95_supported()
_user_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
_use_opt_cat_decode = get_bool_env_var("SGLANG_USE_OPT_CAT")
if _use_aiter_gfx95:
from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
......@@ -456,10 +459,13 @@ class DeepseekV2MLP(nn.Module):
x = (x, None, y)
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(
x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
)
if _use_fused_silu_mul_quant:
    # Hand the raw gate_up projection to down_proj; SiLU-mul and INT8 quant are fused inside.
    x, _ = self.down_proj(
        gate_up,
        skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
        use_fused_silu_mul_quant=True,
    )
else:
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(
        x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
    )
return x
......@@ -757,49 +763,58 @@ class DeepseekV2MoE(nn.Module):
self.shared_experts.gate_up_proj
):
return self.forward_cpu(hidden_states, should_allreduce_fusion)
if hidden_states.shape[0] > 0:
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
if _user_lightop_moe_sum_mul_add:
if hidden_states.shape[0] > 0:
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output, shared_output=shared_output)
else:
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
if self._fuse_shared_experts_inside_sbo:
shared_output = None
if hidden_states.shape[0] > 0:
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
else:
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
def _forward_shared_experts_and_put_results():
nonlocal shared_output
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
if self._fuse_shared_experts_inside_sbo:
shared_output = None
final_hidden_states = self.experts(
hidden_states,
topk_output,
**(
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
)
if self._fuse_shared_experts_inside_sbo
else {}
),
)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
def _forward_shared_experts_and_put_results():
nonlocal shared_output
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
final_hidden_states = self.experts(
hidden_states,
topk_output,
**(
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
)
if self._fuse_shared_experts_inside_sbo
else {}
),
)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion
......@@ -1696,7 +1711,10 @@ class DeepseekV2AttentionMLA(nn.Module):
self.rotary_emb.is_neox_style,
)
else:
q = torch.cat([q_nope_out, q_pe], dim=-1)
if _use_opt_cat_decode and q_nope_out.shape[0] < 1024:
q = concat_decode_opt(q_nope_out, q_pe, dim=2)
else:
q = torch.cat([q_nope_out, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)
attn_output = self.attn_mqa(
......
......@@ -94,6 +94,7 @@ QUANTIZATION_CHOICES = [
"mxfp4",
"compressed-tensors", # for Ktransformers
"slimquant_w4a8_marlin",
"slimquant_marlin",
]
ATTENTION_BACKEND_CHOICES = [
......