Unverified Commit 49b87774 authored by Cheng Wan, committed by GitHub

Refactor: move all quantization-related code to `srt/layer/quantization` (#7989)

parent 02404a1e
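The hunks below all follow the same pattern: symbols that previously lived in `sglang.srt.layers.linear` now come from the `quantization` package. A minimal before/after sketch, using only import paths that appear in this diff:

```python
# Before this refactor (old import path):
#   from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
# After it, the same symbols are imported from srt/layers/quantization:
from sglang.srt.layers.quantization.base_config import LinearMethodBase
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
```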
+from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
@@ -11,21 +13,19 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import (
-    LinearMethodBase,
-    RowParallelLinear,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
     cpu_has_amx_support,
@@ -229,14 +229,14 @@ class W8A8Int8Config(QuantizationConfig):
         return []

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+    def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
         return cls(config)

     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
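The body of `get_quant_method` is collapsed in this hunk; below is a hedged sketch of the dispatch it typically performs (assumed shape, not the verbatim body). The local imports shown above exist to break an import cycle with the linear and MoE modules:

```python
# Assumed shape of the collapsed method body, for illustration only.
def get_quant_method(self, layer, prefix):
    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

    if isinstance(layer, LinearBase):
        return W8A8Int8LinearMethod(self)  # per-linear-layer INT8 method
    if isinstance(layer, FusedMoE):
        return W8A8Int8MoEMethod(self)     # MoE variant, defined below
    return None                            # leave the layer unquantized
```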
@@ -374,7 +374,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         )


-class W8A8Int8MoEMethod:
+class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
     Supports loading INT8 checkpoints with static weight scale and
     dynamic/static activation scale.
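For readers unfamiliar with the scheme named in the docstring: weights ship with static scales from the checkpoint, while activations can be quantized on the fly. A torch-level sketch of dynamic per-token INT8 activation quantization (only the math behind a fused kernel such as `per_token_quant_int8`; the real implementation is a fused Triton kernel, so this is illustrative, not the library's code):

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per token (row): map the row's max-abs value onto the int8 range.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return q, scale  # dequantize as q.float() * scale
```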
@@ -385,25 +385,7 @@
         quant_config: The quantization config.
     """

-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: W8A8Int8Config):
         self.quant_config = quant_config

     def create_weights(
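The deleted `__new__` hack grafted the base class onto the object at construction time, because importing `FusedMoEMethodBase` at module scope used to create a circular import; after the refactor the base is importable directly (see the new `base_config` import above), so plain inheritance suffices. A self-contained sketch of the two patterns (hypothetical `Base`/`OldStyle`/`NewStyle` names; this sketch also filters `__weakref__`, which the deleted code did not):

```python
class Base:
    def shared(self) -> str:
        return "from base"

class OldStyle:
    # The removed pattern: build a one-off subclass of Base that carries
    # this class's methods, dodging a module-scope import of Base.
    def __new__(cls, *args, **kwargs):
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items()
             if k not in ("__dict__", "__weakref__")},
        )
        obj = object.__new__(new_cls)
        obj.__init__(*args, **kwargs)  # manual: new_cls is not a subclass of cls
        return obj

    def __init__(self, x):
        self.x = x

class NewStyle(Base):
    # The refactor's replacement: ordinary inheritance.
    def __init__(self, x):
        self.x = x

assert isinstance(OldStyle(1), Base)
assert isinstance(NewStyle(1), Base)
```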
@@ -885,13 +867,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.linear import RowParallelLinear
+
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)


-class NPU_W8A8MoEMethod:
+class NPU_W8A8MoEMethod(FusedMoEMethodBase):
     """MoE method for NPU quantization.
     This class searches for specific quantization
......
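A note on the `tp_rank` argument threaded through `apply` for `RowParallelLinear`: in a row-parallel linear, each rank computes a partial matmul over its input shard and the partials are summed by an all-reduce, so the bias should be applied on exactly one rank. A hedged sketch of that logic (hypothetical helper, not the actual NPU kernel call):

```python
import torch

def row_parallel_forward(x_shard, weight_shard, bias, tp_rank, all_reduce):
    partial = torch.matmul(x_shard, weight_shard.t())  # this rank's partial output
    if bias is not None and tp_rank == 0:
        partial = partial + bias  # add bias once, not world_size times
    return all_reduce(partial)    # sum partial outputs across TP ranks
```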
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple

 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter

 from sglang.srt.distributed import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()

 logger = logging.getLogger(__name__)

-class UnquantizedEmbeddingMethod(QuantizeMethodBase):
-    """Unquantized method for embeddings."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for embedding layer."""
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
-
-    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-        return F.embedding(input_, layer.weight)
-

 def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
......
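`pad_vocab_size`, kept unchanged above, rounds the vocabulary up to the next multiple of `pad_to` (default 64) via integer ceiling division, which is commonly done so the embedding table shards evenly across tensor-parallel ranks. A quick check of the arithmetic:

```python
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

assert pad_vocab_size(32000) == 32000  # already a multiple of 64
assert pad_vocab_size(32001) == 32064  # rounded up to the next multiple
assert pad_vocab_size(100, pad_to=8) == 104
```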