"vscode:/vscode.git/clone" did not exist on "a822937abedeefb04d6f93ba7aef8dd8d2484848"
Unverified commit 49b87774, authored by Cheng Wan, committed by GitHub

Refactor: move all quantization-related code to `srt/layer/quantization` (#7989)

parent 02404a1e
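This commit consolidates the quantization helpers under sglang.srt.layers.quantization: LinearMethodBase (and FusedMoEMethodBase) become importable from base_config, the unquantized linear and embedding methods move into a new unquant module, and callers stop importing them from sglang.srt.layers.linear or the vocab-embedding module. Based only on the imports added in the diff below, downstream code would switch to something like the following sketch (not an exhaustive list of moved symbols):

# before: from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
)
from sglang.srt.layers.quantization.unquant import (
    UnquantizedEmbeddingMethod,
    UnquantizedLinearMethod,
)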
+from __future__ import annotations
+
 import importlib
 import sys
 from types import MappingProxyType
@@ -11,21 +13,19 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import (
-    LinearMethodBase,
-    RowParallelLinear,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
     PerTensorScaleParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
     cpu_has_amx_support,
@@ -229,14 +229,14 @@ class W8A8Int8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+    def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
         return cls(config)
 
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
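With from __future__ import annotations added at the top of the file, annotations are no longer evaluated at definition time, so the quoted forward references ("W8A8Int8Config", "QuantizeMethodBase") can be written unquoted, as the two signature changes above show. A minimal standalone illustration of the same idea:

from __future__ import annotations

from typing import Any, Dict

class Config:
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> Config:  # no quotes needed for the self-reference
        return cls()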
@@ -374,7 +374,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         )
 
-class W8A8Int8MoEMethod:
+class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
 
     Supports loading INT8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -385,25 +385,7 @@ class W8A8Int8MoEMethod:
         quant_config: The quantization config.
     """
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: W8A8Int8Config):
         self.quant_config = quant_config
 
     def create_weights(
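The block removed above was apparently a workaround for importing FusedMoEMethodBase only at instantiation time (most likely to dodge a circular import): W8A8Int8MoEMethod rebuilt itself as a FusedMoEMethodBase subclass inside __new__ via type(). With the base class now exposed from base_config, plain inheritance replaces the hack. For readers unfamiliar with that pattern, here is a self-contained toy with made-up names showing what the deleted __new__ was doing:

class Base:
    def greet(self):
        return "hello from Base"

class Widget:
    # Widget is declared without Base, then rebuilt as a Base subclass on first
    # instantiation; this defers whatever import Base would require.
    def __new__(cls, *args, **kwargs):
        namespace = {
            k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")
        }
        new_cls = type(cls.__name__, (Base,), namespace)
        return super(new_cls, new_cls).__new__(new_cls)

w = Widget()
print(isinstance(w, Base))  # True: the runtime-built class inherits from Base
print(w.greet())            # "hello from Base"

The cost of the trick is visible in the deleted code: the returned object is not an instance of the original class, so __init__ had to be invoked by hand, which is why direct inheritance is the cleaner fix.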
@@ -885,13 +867,15 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.linear import RowParallelLinear
+
         if isinstance(layer, RowParallelLinear):
             tp_rank = get_tensor_model_parallel_rank()
             return self.quant_method.apply(layer, x, bias, tp_rank)
         return self.quant_method.apply(layer, x, bias)
 
-class NPU_W8A8MoEMethod:
+class NPU_W8A8MoEMethod(FusedMoEMethodBase):
     """MoE method for NPU quantization.
 
     This class search for specific quantization
...
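RowParallelLinear is no longer imported at the top of the quantization module; the NPU linear method now pulls it in lazily inside apply(). The commit does not state the motivation, but this is the usual way to keep sglang.srt.layers.linear and the quantization package from importing each other at module-load time (the same local-import style already used in get_quant_method above). A trivial standard-library illustration of the deferred-import pattern:

def payload_size(obj):
    # The import is resolved on the first call rather than when this module is
    # imported; deferring one side of a module-level import cycle this way breaks it.
    import json
    return len(json.dumps(obj))

print(payload_size({"a": 1}))  # 8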
@@ -5,7 +5,6 @@ from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from sglang.srt.distributed import (
@@ -22,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -32,44 +32,6 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)
-class UnquantizedEmbeddingMethod(QuantizeMethodBase):
-    """Unquantized method for embeddings."""
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: List[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        """Create weights for embedding layer."""
-        weight = Parameter(
-            torch.empty(
-                sum(output_partition_sizes),
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        return F.linear(x, layer.weight, bias)
-
-    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
-        return F.embedding(input_, layer.weight)
-
 
 def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
...
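UnquantizedEmbeddingMethod is removed from this file but not deleted: per the import added at the top of the hunk, it now lives in sglang.srt.layers.quantization.unquant. Assuming the relocated class keeps the interface shown in the deleted block, a toy usage sketch (sizes are arbitrary) looks like this:

import torch

from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod

layer = torch.nn.Module()                 # container module that will hold the weight
method = UnquantizedEmbeddingMethod()
method.create_weights(
    layer,
    input_size_per_partition=16,          # embedding dimension
    output_partition_sizes=[128],         # vocab rows in this shard
    input_size=16,
    output_size=128,
    params_dtype=torch.float32,
)
token_ids = torch.tensor([3, 7, 42])
hidden = method.embedding(layer, token_ids)   # row lookup -> shape [3, 16]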