Unverified commit 49b87774 authored by Cheng Wan, committed by GitHub

Refactor: move all quantization-related code to `srt/layer/quantization` (#7989)

parent 02404a1e
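In practical terms, the refactor relocates the method interfaces and the unquantized fallbacks: LinearMethodBase and FusedMoEMethodBase now live in sglang.srt.layers.quantization.base_config, while UnquantizedLinearMethod, UnquantizedFusedMoEMethod, and UnquantizedEPMoEMethod live in sglang.srt.layers.quantization.unquant. A minimal sketch of what downstream imports look like after this commit (assuming an environment with sglang at this revision; the old paths are shown only as comments):

# Old locations (before this commit):
#   from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
#   from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
# New locations (after this commit):
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.unquant import (
    UnquantizedFusedMoEMethod,
    UnquantizedLinearMethod,
)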
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
from __future__ import annotations
import itertools import itertools
import logging import logging
from abc import abstractmethod from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
from sglang.srt.distributed import ( from sglang.srt.distributed import (
...@@ -17,7 +17,6 @@ from sglang.srt.distributed import ( ...@@ -17,7 +17,6 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
BasevLLMParameter, BasevLLMParameter,
BlockQuantScaleParameter, BlockQuantScaleParameter,
...@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import ( ...@@ -27,17 +26,14 @@ from sglang.srt.layers.parameter import (
RowvLLMParameter, RowvLLMParameter,
_ColumnvLLMParameter, _ColumnvLLMParameter,
) )
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_npu,
set_weight_attrs,
use_intel_amx_backend,
)
logger = logging.getLogger(__name__)
@@ -59,7 +55,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"IPEXAWQLinearMethod",
]
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()
@@ -110,91 +105,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization."""
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu and _is_cpu_amx_available:
_amx_process_weight_after_loading(layer, ["weight"])
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.weight_packed_linear(
x, layer.weight, bias, True # is_vnni
)
return F.linear(x, layer.weight, bias)
class LinearBase(torch.nn.Module):
"""Base linear layer.
@@ -310,7 +220,7 @@ class ReplicatedLinear(LinearBase):
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
@@ -845,7 +755,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
quant_config: Optional["QuantizationConfig"] = None,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
......
@@ -27,22 +27,20 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_triton_kernel,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
scaled_fp8_quant,
sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -53,7 +51,6 @@ from sglang.srt.utils import (
get_bool_env_var,
is_hip,
is_npu,
set_weight_attrs,
)
_is_hip = is_hip()
@@ -904,324 +901,6 @@ class EPMoE(torch.nn.Module):
param_data[expert_id] = loaded_weight
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def create_weights(
self,
layer: torch.nn.Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# scale
layer.register_parameter("w13_input_scale", None)
layer.register_parameter("w13_weight_scale", None)
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
w2_input_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError
class Fp8EPMoEMethod(Fp8MoEMethod):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.block_quant:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
else:
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts_per_partition,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.num_experts_per_partition):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
layer.w13_weight_scale = torch.nn.Parameter(
torch.max(layer.w13_weight_scale, dim=1).values,
requires_grad=False,
)
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(
w13_weight, requires_grad=False
)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError
class DeepEPMoE(EPMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
......
@@ -9,7 +9,6 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
@@ -31,11 +30,9 @@ def get_config() -> Optional[Dict[str, Any]]:
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",
"fused_moe",
"fused_experts",
"get_config_file_name",
"moe_align_block_size",
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import importlib
import logging
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import (
from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_hip,
set_weight_attrs,
use_intel_amx_backend,
)
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
else:
fused_experts = None # type: ignore
import logging
_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _use_aiter:
from aiter import ActivationType
from aiter.fused_moe import fused_moe
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight
logger = logging.getLogger(__name__)
@@ -66,333 +34,6 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
) -> torch.Tensor:
raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, use_triton_kernels: bool = False):
super().__init__()
self.use_triton_kernels = use_triton_kernels
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Fused gate_up_proj (column parallel)
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
if self.use_triton_kernels:
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
w13_weight = torch.nn.Parameter(
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight_n, w2_weight_k = (
hidden_size,
intermediate_size,
)
if self.use_triton_kernels:
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
# Pack weight for get better performance on CPU
if _is_cpu and _is_cpu_amx_available:
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_cuda(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
if self.use_triton_kernels:
return triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if _use_aiter:
assert not no_combine, "unsupported"
if apply_router_weight_on_input:
assert (
topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
x = x * topk_weights.to(x.dtype)
topk_weights = torch.ones_like(
topk_weights, dtype=torch.float32
) # topk_weights must be FP32 (float32)
return fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
else:
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_cpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
False, # inplace # See [Note] inplace should be False in fused_experts.
False, # use_int8_w8a8
False, # use_fp8_w8a16
None, # w1_scale
None, # w2_scale
None, # block_size
None, # a1_scale
None, # a2_scale
True, # is_vnni
)
else:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)
def forward_npu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")
forward_native = forward_cpu
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
@@ -553,7 +194,7 @@ class FusedMoE(torch.nn.Module):
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int,
):
# Load grouped weight scales for group quantization
@@ -580,7 +221,7 @@ class FusedMoE(torch.nn.Module):
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int,
):
# for per channel weight quantization
@@ -600,7 +241,7 @@ class FusedMoE(torch.nn.Module):
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int,
):
@@ -645,7 +286,7 @@ class FusedMoE(torch.nn.Module):
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int,
):
"""Load w2 weights for down projection.
@@ -717,7 +358,7 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int,
):
......
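One small correctness fix folded into the FusedMoE hunks above: the weight-loader parameters were annotated as torch.tensor (the factory function) and are now annotated as torch.Tensor (the class), which is the only form that makes sense as a type hint. A trivial sketch of the distinction, with a hypothetical loader name:

import torch

def load_weight(loaded_weight: torch.Tensor) -> None:  # hypothetical signature
    # torch.tensor(...) is a factory call; torch.Tensor is the type it returns.
    assert isinstance(loaded_weight, torch.Tensor)

load_weight(torch.tensor([1.0, 2.0]))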
@@ -19,15 +19,11 @@ import torch
import torch.nn.functional as F
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import (
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
......
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
import builtins
import inspect
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union
import torch
@@ -45,7 +43,6 @@ except ImportError:
) = QQQConfig = Int8TpuConfig = DummyConfig
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -66,6 +63,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
from sglang.srt.layers.quantization.qoq import QoQConfig
from sglang.srt.layers.quantization.utils import (
get_dynamic_override,
get_linear_quant_method,
)
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -120,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
return QUANTIZATION_METHODS[quantization]
# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
if isinstance(weight_bits, int):
config.weight_bits = weight_bits
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
if isinstance(group_size, int):
config.group_size = group_size
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
if isinstance(desc_act, bool):
config.desc_act = desc_act
config.pack_factor = 32 // config.weight_bits # packed into int32
if config.get_name() == "gptq_marlin":
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
if isinstance(is_sym, bool):
config.is_sym = is_sym
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
raise ValueError(
"Unsupported quantization config: "
f"bits={config.weight_bits}, sym={config.is_sym}"
)
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
elif config.get_name() == "gptq":
if config.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {config.weight_bits} bits."
)
def get_dynamic_override(
config: QuantizationConfig,
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
for pattern, pattern_dict in config.dynamic.items():
# Negative match: matched modules are excluded from quantized init
if pattern.startswith("-:"):
if re.match(pattern.removeprefix("-:"), layer_name):
return False
# Positive match: matched modules have quant properties overrides
# base quant config
elif re.match(pattern.removeprefix("+:"), layer_name):
if key is None:
return pattern_dict
else:
return pattern_dict.get(key, default_value)
return default_value
def get_linear_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
prefix: str,
linear_method_cls: type,
):
# Move import here to avoid circular import. This is only used in monkey patching
# of vllm's QuantizationConfig.
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
UnquantizedEmbeddingMethod,
)
cloned_config = deepcopy(config)
parallel_lm_head_quantized = (
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
)
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
# False = skip module, None = no override, else = Positive match
if (
get_dynamic_override( # noqa: E712
cloned_config, layer_name=prefix # noqa: E712
)
== False
): # noqa: E712
if parallel_lm_head_quantized:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()
if prefix:
# Dynamic per module/layer rules may override base config
override_config(cloned_config, prefix=prefix)
return linear_method_cls(cloned_config)
return None
def gptq_get_quant_method(self, layer, prefix):
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
......
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
import torch
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
@@ -81,7 +82,7 @@ class AWQConfig(QuantizationConfig):
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
@@ -92,7 +93,8 @@ class AWQConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["LinearMethodBase"]:
) -> Optional[LinearMethodBase]:
from sglang.srt.layers.linear import LinearBase
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
......
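The AWQ diff above illustrates the pattern used throughout this refactor to keep imports acyclic: quantization configs no longer import LinearBase (or the unquantized methods) from sglang.srt.layers.linear at module level; instead LinearBase is imported lazily inside get_quant_method, since linear.py itself now imports from the quantization package. A hedged sketch of that pattern, with a hypothetical config class standing in for AWQConfig:

from typing import Optional

import torch


class ExampleQuantConfig:  # hypothetical stand-in, not part of sglang
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[object]:
        # Deferred import: sglang.srt.layers.linear imports from
        # sglang.srt.layers.quantization, so a module-level import here
        # would complete an import cycle.
        from sglang.srt.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return ...  # return the linear method appropriate for this config
        return None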
@@ -18,14 +18,14 @@ class QuantizeMethodBase(ABC):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
raise NotImplementedError()
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
raise NotImplementedError()
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
@@ -35,6 +35,74 @@ class QuantizeMethodBase(ABC):
return
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError()
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError()
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError()
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
) -> torch.Tensor:
raise NotImplementedError()
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
@@ -46,12 +114,12 @@ class QuantizationConfig(ABC):
@abstractmethod
def get_name(self) -> str:
"""Name of the quantization method."""
raise NotImplementedError
raise NotImplementedError()
@abstractmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
raise NotImplementedError()
@classmethod
@abstractmethod
@@ -62,19 +130,19 @@
This requirement is due to the custom CUDA kernels used by the
quantization method.
"""
raise NotImplementedError
raise NotImplementedError()
@staticmethod
@abstractmethod
def get_config_filenames() -> List[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
raise NotImplementedError()
@classmethod
@abstractmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
raise NotImplementedError()
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
@@ -117,7 +185,7 @@
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError
raise NotImplementedError()
@abstractmethod
def get_scaled_act_names(self) -> List[str]:
@@ -125,7 +193,7 @@
For now, this is only used by AWQ.
"""
raise NotImplementedError
raise NotImplementedError()
def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional
@@ -7,17 +9,15 @@ import torch
from torch.nn import Module
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
@@ -78,7 +78,7 @@ class BlockInt8Config(QuantizationConfig):
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
def from_config(cls, config: Dict[str, Any]) -> BlockInt8Config:
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_int8_serialized = "int8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
@@ -93,7 +93,8 @@ class BlockInt8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
@@ -230,7 +231,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
)
class BlockInt8MoEMethod:
class BlockInt8MoEMethod(FusedMoEMethodBase):
"""MoE method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic activation scale.
@@ -242,25 +243,7 @@
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
def __init__(self, quant_config: BlockInt8Config):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config
assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized
......
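With FusedMoEMethodBase now defined in sglang.srt.layers.quantization.base_config, the __new__ indirection removed above (which grafted the base class onto BlockInt8MoEMethod at instantiation time to dodge a circular import of fused_moe_triton) is no longer needed, and the MoE methods can subclass it directly. A simplified sketch under that assumption, with a hypothetical method name:

from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase


class ExampleMoEMethod(FusedMoEMethodBase):  # hypothetical, for illustration only
    def __init__(self, quant_config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError()

    def apply(self, layer, x, router_logits, top_k, renormalize, use_grouped_topk):
        raise NotImplementedError()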
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import logging
from contextlib import suppress
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
)
from pydantic import BaseModel
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
@@ -40,6 +37,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
is_activation_quantization_format,
should_ignore_layer,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
try:
import vllm
@@ -97,7 +95,7 @@ class CompressedTensorsConfig(QuantizationConfig):
self.config = config
self.packed_modules_mapping = packed_modules_mapping
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
def get_linear_method(self) -> CompressedTensorsLinearMethod:
return CompressedTensorsLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -117,7 +115,8 @@ class CompressedTensorsConfig(QuantizationConfig):
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
@@ -138,7 +137,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
ignore: List[str] = cast(List[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
@@ -357,7 +356,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def _get_scheme_from_parts(
self, weight_quant: BaseModel, input_quant: BaseModel
) -> "CompressedTensorsScheme":
) -> CompressedTensorsScheme:
# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -435,7 +434,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_scheme(
self, layer: torch.nn.Module, layer_name: Optional[str] = None
) -> Optional["CompressedTensorsScheme"]:
) -> Optional[CompressedTensorsScheme]:
"""
compressed-tensors supports non uniform in the following way:
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
from __future__ import annotations
import logging
from typing import Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -28,17 +30,14 @@ except ImportError:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
@@ -56,6 +55,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
@@ -77,6 +77,9 @@ from sglang.srt.utils import (
use_intel_amx_backend,
)
if TYPE_CHECKING:
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
@@ -152,7 +155,7 @@ class Fp8Config(QuantizationConfig):
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
@@ -167,7 +170,8 @@ class Fp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
@@ -200,7 +204,7 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""
def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
@@ -486,7 +490,7 @@ class Fp8LinearMethod(LinearMethodBase):
)
class Fp8MoEMethod:
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
@@ -499,25 +503,7 @@ class Fp8MoEMethod:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
def __init__(self, quant_config: Fp8Config):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
...@@ -1169,6 +1155,254 @@ class Fp8MoEMethod: ...@@ -1169,6 +1155,254 @@ class Fp8MoEMethod:
return None return None
class Fp8EPMoEMethod(Fp8MoEMethod):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
def create_weights(
self,
layer: Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}."
)
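        # Illustrative example (numbers are hypothetical, not from this diff): with
        # weight_block_size = [128, 128], hidden_size = 4096 and intermediate_size = 1536,
        # 1536 % 128 == 0 satisfies both checks above, and the block-wise scales created
        # below have shapes
        #   w13_weight_scale_inv: (E, 2 * ceil(1536 / 128), ceil(4096 / 128)) = (E, 24, 32)
        #   w2_weight_scale_inv:  (E, ceil(4096 / 128), ceil(1536 / 128))     = (E, 32, 12)
        # i.e. one float32 scale per (block_n x block_k) tile of each fp8 weight.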
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.block_quant:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts_per_partition,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
assert self.quant_config.activation_scheme == "dynamic"
else:
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
if self.block_quant
else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts_per_partition, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
layer.w13_weight_scale = torch.nn.Parameter(
torch.ones(
layer.num_experts_per_partition,
dtype=torch.float32,
device=w13_weight.device,
),
requires_grad=False,
)
for expert in range(layer.num_experts_per_partition):
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
)
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
)
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
if self.quant_config.activation_scheme == "static":
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
layer.w13_weight_scale = torch.nn.Parameter(
torch.max(layer.w13_weight_scale, dim=1).values,
requires_grad=False,
)
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(
w13_weight, requires_grad=False
)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
""" """
Supports loading kv-cache scaling factors from FP8 checkpoints. Supports loading kv-cache scaling factors from FP8 checkpoints.
......
from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from fractions import Fraction from fractions import Fraction
...@@ -5,7 +7,6 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -5,7 +7,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
BasevLLMParameter, BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
...@@ -16,6 +17,8 @@ from sglang.srt.layers.parameter import ( ...@@ -16,6 +17,8 @@ from sglang.srt.layers.parameter import (
permute_param_layout_, permute_param_layout_,
) )
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
...@@ -34,7 +37,11 @@ from sglang.srt.layers.quantization.marlin_utils import ( ...@@ -34,7 +37,11 @@ from sglang.srt.layers.quantization.marlin_utils import (
verify_marlin_supported, verify_marlin_supported,
) )
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols from sglang.srt.layers.quantization.utils import (
get_linear_quant_method,
replace_parameter,
unpack_cols,
)
try: try:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -49,8 +56,6 @@ if _is_cuda: ...@@ -49,8 +56,6 @@ if _is_cuda:
from sgl_kernel import fused_marlin_moe from sgl_kernel import fused_marlin_moe
FusedMoEMethodBase = QuantizeMethodBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -179,7 +184,7 @@ class GPTQConfig(QuantizationConfig): ...@@ -179,7 +184,7 @@ class GPTQConfig(QuantizationConfig):
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic dynamic = {} if dynamic is None else dynamic
...@@ -191,10 +196,10 @@ class GPTQConfig(QuantizationConfig): ...@@ -191,10 +196,10 @@ class GPTQConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["LinearMethodBase"]: ) -> Optional[LinearMethodBase]:
# Delay the import to avoid circular dependency # Delay the import to avoid circular dependency
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import get_linear_quant_method
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
...@@ -303,7 +308,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -303,7 +308,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return ["quantize_config.json"] return ["quantize_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig:
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
dynamic = {} if dynamic is None else dynamic dynamic = {} if dynamic is None else dynamic
...@@ -354,7 +359,6 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -354,7 +359,6 @@ class GPTQMarlinConfig(QuantizationConfig):
) -> Optional[QuantizeMethodBase]: ) -> Optional[QuantizeMethodBase]:
# Delay the import to avoid circular dependency # Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import get_linear_quant_method
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self) return GPTQMarlinMoEMethod(self)
...@@ -832,6 +836,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -832,6 +836,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
**extra_weight_attrs, **extra_weight_attrs,
): ):
# Delay the import to avoid circular dependency # Delay the import to avoid circular dependency
from sglang.srt.layers.linear import set_weight_attrs
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
intermediate_size = extra_weight_attrs.pop("intermediate_size") intermediate_size = extra_weight_attrs.pop("intermediate_size")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
from __future__ import annotations
import logging import logging
from typing import Any, Optional from typing import TYPE_CHECKING, Any, Optional
import numpy import numpy
import torch import torch
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
BasevLLMParameter, BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
PackedvLLMParameter, PackedvLLMParameter,
) )
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability from sglang.srt.utils import get_device_capability
if TYPE_CHECKING:
from sglang.srt.layers.linear import LinearBase
try: try:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
except ImportError: except ImportError:
...@@ -617,7 +623,10 @@ class MarlinConfig(QuantizationConfig): ...@@ -617,7 +623,10 @@ class MarlinConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["MarlinLinearMethod"]: ) -> Optional[MarlinLinearMethod]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
if isinstance(layer, LinearBase) or ( if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized isinstance(layer, ParallelLMHead) and self.lm_head_quantized
): ):
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from __future__ import annotations
import logging import logging
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
...@@ -6,14 +7,11 @@ from typing import Any, Callable, Dict, List, Optional ...@@ -6,14 +7,11 @@ from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
...@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( ...@@ -23,6 +21,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
is_sm100_supported, is_sm100_supported,
) )
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import ( from sglang.srt.layers.quantization.utils import (
convert_to_channelwise, convert_to_channelwise,
is_layer_skipped, is_layer_skipped,
...@@ -86,7 +85,7 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -86,7 +85,7 @@ class ModelOptFp8Config(QuantizationConfig):
return ["hf_quant_config.json"] return ["hf_quant_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo") quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get( kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
"kv_cache_quant_algo" "kv_cache_quant_algo"
...@@ -109,7 +108,11 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -109,7 +108,11 @@ class ModelOptFp8Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if self.exclude_modules and any( if self.exclude_modules and any(
module in prefix module in prefix
or ( or (
...@@ -125,9 +128,6 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -125,9 +128,6 @@ class ModelOptFp8Config(QuantizationConfig):
if self.kv_cache_quant_method and isinstance(layer, RadixAttention): if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self) return ModelOptFp8KVCacheMethod(self)
# Add MoE support
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self) return ModelOptFp8MoEMethod(self)
...@@ -246,7 +246,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): ...@@ -246,7 +246,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
super().__init__(quant_config) super().__init__(quant_config)
class ModelOptFp8MoEMethod: class ModelOptFp8MoEMethod(FusedMoEMethodBase):
"""MoE method for ModelOpt FP8. """MoE method for ModelOpt FP8.
Supports loading FP8 checkpoints with static weight scale and activation scale. Supports loading FP8 checkpoints with static weight scale and activation scale.
...@@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod: ...@@ -254,30 +254,6 @@ class ModelOptFp8MoEMethod:
quant_config: The ModelOpt quantization config. quant_config: The ModelOpt quantization config.
""" """
def __new__(cls, *args, **kwargs):
"""
Dynamic class composition pattern.
This allows us to effectively "inject" FusedMoEMethodBase as a parent class
at runtime while avoiding circular import issues.
"""
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config: ModelOptFp8Config): def __init__(self, quant_config: ModelOptFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
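The __new__ hook deleted above built a replacement class with type() at call time so that FusedMoEMethodBase, then imported from fused_moe_triton, could be grafted in as a parent without a module-level import cycle. With the base class now exposed from base_config, plain subclassing is enough; a minimal sketch of the replacement pattern (the class name is illustrative):

from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase


class SomeFp8MoEMethod(FusedMoEMethodBase):  # hypothetical name
    # The real subclasses in this diff also implement create_weights() and
    # apply(); they are elided here. Plain inheritance keeps isinstance()
    # checks and the MRO straightforward, which the runtime type() trick
    # did not.
    def __init__(self, quant_config):
        self.quant_config = quant_config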
...@@ -514,7 +490,7 @@ class ModelOptFp4Config(QuantizationConfig): ...@@ -514,7 +490,7 @@ class ModelOptFp4Config(QuantizationConfig):
return ["hf_quant_config.json"] return ["hf_quant_config.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config": def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
quant_config = cls.get_from_keys(config, ["quantization"]) quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"] quant_method = quant_config["quant_algo"]
if not quant_method in ["FP8", "NVFP4"]: if not quant_method in ["FP8", "NVFP4"]:
...@@ -559,7 +535,8 @@ class ModelOptFp4Config(QuantizationConfig): ...@@ -559,7 +535,8 @@ class ModelOptFp4Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
...@@ -740,31 +717,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase): ...@@ -740,31 +717,13 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
return out.view(*output_shape) return out.view(*output_shape)
class ModelOptNvFp4FusedMoEMethod: class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
""" """
MoE Method for FP4 Quantization with Blockscales and PerTensorScales MoE Method for FP4 Quantization with Blockscales and PerTensorScales
Args: Args:
quant_config: NVFP4 Quant Config quant_config: NVFP4 Quant Config
""" """
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config: ModelOptFp4Config): def __init__(self, quant_config: ModelOptFp4Config):
self.quant_config = quant_config self.quant_config = quant_config
if not is_sm100_supported(): if not is_sm100_supported():
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
from __future__ import annotations
import logging import logging
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
...@@ -7,13 +8,14 @@ import torch ...@@ -7,13 +8,14 @@ import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.awq import AWQConfig from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import get_device_capability, set_weight_attrs from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -118,7 +120,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -118,7 +120,7 @@ class MoeWNA16Config(QuantizationConfig):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": def from_config(cls, config: Dict[str, Any]) -> MoeWNA16Config:
quant_method = cls.get_from_keys(config, ["quant_method"]) quant_method = cls.get_from_keys(config, ["quant_method"])
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
...@@ -177,8 +179,9 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -177,8 +179,9 @@ class MoeWNA16Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional[QuantizeMethodBase]:
# avoid circular import # avoid circular import
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if is_layer_skipped_quant(prefix, self.modules_to_not_convert): if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
...@@ -209,32 +212,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): ...@@ -209,32 +212,13 @@ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
return any(module_name in prefix for module_name in modules_to_not_convert) return any(module_name in prefix for module_name in modules_to_not_convert)
class MoeWNA16Method: class MoeWNA16Method(FusedMoEMethodBase):
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization. """Linear method for MOE WNA16 (W8A16/W4A16) quantization.
Args: Args:
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
""" """
def __new__(cls, *args, **kwargs):
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config: MoeWNA16Config): def __init__(self, quant_config: MoeWNA16Config):
self.quant_config = quant_config self.quant_config = quant_config
......
from typing import Any, Callable, Dict, List, Optional from __future__ import annotations
from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
) )
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
...@@ -71,7 +72,7 @@ class QoQConfig(QuantizationConfig): ...@@ -71,7 +72,7 @@ class QoQConfig(QuantizationConfig):
return 80 return 80
@classmethod @classmethod
def get_name(self) -> str: def get_name(cls) -> str:
return "qoq" return "qoq"
@classmethod @classmethod
...@@ -83,7 +84,7 @@ class QoQConfig(QuantizationConfig): ...@@ -83,7 +84,7 @@ class QoQConfig(QuantizationConfig):
] ]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "QoQConfig": def from_config(cls, config: Dict[str, Any]) -> QoQConfig:
weight_bits = cls.get_from_keys(config, ["wbits"]) weight_bits = cls.get_from_keys(config, ["wbits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size) return cls(weight_bits, group_size)
...@@ -92,7 +93,7 @@ class QoQConfig(QuantizationConfig): ...@@ -92,7 +93,7 @@ class QoQConfig(QuantizationConfig):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
) -> Optional["QuantizeMethodBase"]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
......
import importlib
from typing import Callable, List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from sglang.srt.custom_op import CustomOp
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizeMethodBase,
)
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
is_cpu,
is_hip,
set_weight_attrs,
use_intel_amx_backend,
)
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
_is_cpu_amx_available = cpu_has_amx_support()
_is_hip = is_hip()
_is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _use_aiter:
from aiter import ActivationType
from aiter.fused_moe import fused_moe
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Create weights for embedding layer."""
weight = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight)
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization."""
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_cpu and _is_cpu_amx_available:
_amx_process_weight_after_loading(layer, ["weight"])
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if use_intel_amx_backend(layer):
return torch.ops.sgl_kernel.weight_packed_linear(
x, layer.weight, bias, True # is_vnni
)
return F.linear(x, layer.weight, bias)
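# Minimal usage sketch (hypothetical helper, shapes illustrative, not part of
# this diff): the contract every linear method follows - create_weights()
# attaches parameters to the layer, process_weights_after_loading() optionally
# repacks them, and apply() runs the matmul.
def _unquantized_linear_demo() -> torch.Tensor:
    layer = torch.nn.Module()
    method = UnquantizedLinearMethod()
    method.create_weights(
        layer,
        input_size_per_partition=128,
        output_partition_sizes=[256],
        input_size=128,
        output_size=256,
        params_dtype=torch.float32,
    )
    method.process_weights_after_loading(layer)  # AMX repack only on CPU builds
    x = torch.randn(4, 128)
    # Output shape is (4, 256); the values are meaningless here because the
    # weight was allocated with torch.empty and never loaded.
    return method.apply(layer, x)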
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, use_triton_kernels: bool = False):
super().__init__()
self.use_triton_kernels = use_triton_kernels
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
else:
triton_kernel_moe_forward = None
else:
fused_experts = None # type: ignore
triton_kernel_moe_forward = None
self.moe_forward_native = moe_forward_native
self.fused_experts = fused_experts
self.triton_kernel_moe_forward = triton_kernel_moe_forward
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Fused gate_up_proj (column parallel)
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
if self.use_triton_kernels:
w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
w13_weight = torch.nn.Parameter(
torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight_n, w2_weight_k = (
hidden_size,
intermediate_size,
)
if self.use_triton_kernels:
w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
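        # Shape example (illustrative numbers): with num_experts = 8, hidden_size = 4096,
        # intermediate_size = 1536 and use_triton_kernels = False, the parameters are
        #   w13_weight: (8, 2 * 1536, 4096) = (8, 3072, 4096)  # fused gate_up_proj
        #   w2_weight:  (8, 4096, 1536)                         # down_proj
        # With use_triton_kernels = True the last two dimensions of each tensor are swapped.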
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
shuffle_weight(layer.w2_weight.data, (16, 16)),
requires_grad=False,
)
torch.cuda.empty_cache()
# Pack weight to get better performance on CPU
if _is_cpu and _is_cpu_amx_available:
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_cuda(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
if self.use_triton_kernels:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
else:
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
if _use_aiter:
assert not no_combine, "unsupported"
if apply_router_weight_on_input:
assert (
topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
x = x * topk_weights.to(x.dtype)
topk_weights = torch.ones_like(
topk_weights, dtype=torch.float32
) # topk_weights must be FP32 (float32)
return fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=(
ActivationType.Silu
if activation == "silu"
else ActivationType.Gelu
),
)
else:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_cpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
# TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
False, # inplace # See [Note] inplace should be False in fused_experts.
False, # use_int8_w8a8
False, # use_fp8_w8a16
None, # w1_scale
None, # w2_scale
None, # block_size
None, # a1_scale
None, # a2_scale
True, # is_vnni
)
else:
return self.moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)
def forward_npu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("The TPU backend currently does not support MoE.")
forward_native = forward_cpu
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def create_weights(
self,
layer: torch.nn.Module,
num_experts_per_partition: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
2 * intermediate_size,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts_per_partition,
hidden_size,
intermediate_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# scale
layer.register_parameter("w13_input_scale", None)
layer.register_parameter("w13_weight_scale", None)
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
w2_input_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from __future__ import annotations
import re
from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
import numpy import numpy
import torch import torch
...@@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant ...@@ -10,6 +14,9 @@ from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizationConfig
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_npu = is_npu() _is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support() _is_cpu_amx_available = cpu_has_amx_support()
...@@ -147,6 +154,94 @@ def replace_parameter( ...@@ -147,6 +154,94 @@ def replace_parameter(
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
if isinstance(weight_bits, int):
config.weight_bits = weight_bits
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
if isinstance(group_size, int):
config.group_size = group_size
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
if isinstance(desc_act, bool):
config.desc_act = desc_act
config.pack_factor = 32 // config.weight_bits # packed into int32
if config.get_name() == "gptq_marlin":
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
if isinstance(is_sym, bool):
config.is_sym = is_sym
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
raise ValueError(
"Unsupported quantization config: "
f"bits={config.weight_bits}, sym={config.is_sym}"
)
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
elif config.get_name() == "gptq":
if config.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {config.weight_bits} bits."
)
def get_dynamic_override(
config: QuantizationConfig,
layer_name: str,
key: Optional[str] = None,
default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
for pattern, pattern_dict in config.dynamic.items():
# Negative match: matched modules are excluded from quantized init
if pattern.startswith("-:"):
if re.match(pattern.removeprefix("-:"), layer_name):
return False
# Positive match: matched modules have quant properties that override
# the base quant config
elif re.match(pattern.removeprefix("+:"), layer_name):
if key is None:
return pattern_dict
else:
return pattern_dict.get(key, default_value)
return default_value
def get_linear_quant_method(
config: QuantizationConfig,
layer: torch.nn.Module,
prefix: str,
linear_method_cls: type,
):
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.unquant import (
UnquantizedEmbeddingMethod,
UnquantizedLinearMethod,
)
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
cloned_config = deepcopy(config)
parallel_lm_head_quantized = (
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
)
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
# False = skip module, None = no override, else = Positive match
if get_dynamic_override(cloned_config, layer_name=prefix) is False:
if parallel_lm_head_quantized:
return UnquantizedEmbeddingMethod()
return UnquantizedLinearMethod()
if prefix:
# Dynamic per module/layer rules may override base config
override_config(cloned_config, prefix=prefix)
return linear_method_cls(cloned_config)
return None
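# Illustrative dynamic-rule example (config values are hypothetical): with
#   config.dynamic = {
#       "-:.*lm_head.*": {},             # negative rule: leave the module unquantized
#       "+:.*down_proj.*": {"bits": 8},  # positive rule: per-module override
#   }
# get_dynamic_override(config, "lm_head") returns False, so
# get_linear_quant_method() skips quantization for that module
# (UnquantizedEmbeddingMethod for a ParallelLMHead, UnquantizedLinearMethod
# otherwise), while a prefix such as "model.layers.0.mlp.down_proj" matches the
# positive rule and override_config() rewrites cloned_config.weight_bits to 8
# (and pack_factor to 32 // 8 = 4) before linear_method_cls(cloned_config) is
# constructed.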
def get_pack_factor(num_bits): def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits return 32 // num_bits
......
from __future__ import annotations
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
...@@ -5,12 +7,13 @@ import torch ...@@ -5,12 +7,13 @@ import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
...@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig): ...@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig):
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config": def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
quant_method = cls.get_from_keys(config, ["quant_method"]) quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method is_checkpoint_fp8_serialized = "fp8" in quant_method
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
...@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig): ...@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
...@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig): ...@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig):
return [] return []
class W4AFp8MoEMethod: class W4AFp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: W4AFp8Config): def __init__(self, quant_config: W4AFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
......
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
LinearMethodBase,
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
...@@ -64,7 +67,7 @@ class W8A8Fp8Config(QuantizationConfig): ...@@ -64,7 +67,7 @@ class W8A8Fp8Config(QuantizationConfig):
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config": def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config:
quant_method = cls.get_from_keys(config, ["quant_method"]) quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ( is_checkpoint_fp8_serialized = (
"compressed-tensors" in quant_method or "w8a8_fp8" in quant_method "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
...@@ -75,7 +78,7 @@ class W8A8Fp8Config(QuantizationConfig): ...@@ -75,7 +78,7 @@ class W8A8Fp8Config(QuantizationConfig):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
) -> Optional["QuantizeMethodBase"]: ) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
...@@ -183,7 +186,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase): ...@@ -183,7 +186,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
) )
class W8A8FP8MoEMethod: class W8A8FP8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8. """MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale. dynamic/static activation scale.
...@@ -194,25 +197,7 @@ class W8A8FP8MoEMethod: ...@@ -194,25 +197,7 @@ class W8A8FP8MoEMethod:
quant_config: The quantization config. quant_config: The quantization config.
""" """
def __new__(cls, *args, **kwargs): def __init__(self, quant_config: W8A8Fp8Config):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config self.quant_config = quant_config
def create_weights( def create_weights(
......