"docs/vscode:/vscode.git/clone" did not exist on "690f37bbe96a301cec8709a55a9d5e716f515683"
Unverified commit d7e834d6, authored by Hongbo Xu, committed by GitHub

[6/n]decouple quantization implementation from vLLM dependency (#10750)

parent 200a3c0b
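Most of the removals below follow one pattern: an optional-import guard around vLLM, where a module sets a VLLM_AVAILABLE flag in a try/except ImportError block and only raises when a vLLM-backed code path is actually requested. A minimal sketch of that guard is shown here; the get_wNa16_scheme wrapper is a hypothetical call site for illustration, not code from this repository.

try:
    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa: E501
        WNA16_SUPPORTED_BITS,
        CompressedTensorsWNA16,
    )

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False


def get_wNa16_scheme(weight_quant):
    # Illustrative call site: fail only when the vLLM-backed scheme is needed.
    if not VLLM_AVAILABLE:
        raise ImportError(
            "vllm is not installed; install vllm to use CompressedTensorsWNA16"
        )
    return CompressedTensorsWNA16(
        strategy=weight_quant.strategy,
        num_bits=weight_quant.num_bits,
        group_size=weight_quant.group_size,
    )

This commit replaces such guards with native implementations under sglang.srt.layers.quantization (for example compressed_tensors.schemes.compressed_tensors_wNa16 and marlin_utils), so the flags and the vLLM imports are deleted.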
@@ -10,10 +10,6 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@@ -175,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         return original_isinstance(obj, classinfo)

     builtins.isinstance = patched_isinstance
-
-
-def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
-    """
-    Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert sglang arguments to vllm arguments.
-    """
-    original_apply = class_obj.apply
-    sig = inspect.signature(original_apply)
-    param_names = list(sig.parameters.keys())
-    has_correction_bias = "e_score_correction_bias" in param_names
-
-    def new_apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        assert activation == "silu"
-        assert inplace and not no_combine
-
-        kwargs = {
-            "self": self,
-            "layer": layer,
-            "x": x,
-            "topk_output": topk_output,
-        }
-        return original_apply(**kwargs)
-
-    setattr(class_obj, "apply", new_apply)
-
-
-def monkey_patch_quant_configs():
-    """Apply all monkey patches in one place."""
-    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-
-
-# Only apply monkey patches if vllm is available
-if VLLM_AVAILABLE:
-    monkey_patch_quant_configs()
@@ -30,10 +30,12 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
     CompressedTensorsMoEMethod,
 )
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
@@ -43,23 +45,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

-try:
-    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
-        CompressedTensors24,
-    )
-    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
-        W4A16SPARSE24_SUPPORTED_BITS,
-        CompressedTensorsW4A16Sparse24,
-    )
-    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
-        WNA16_SUPPORTED_BITS,
-        CompressedTensorsWNA16,
-    )
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False

 logger = logging.getLogger(__name__)

 __all__ = ["CompressedTensorsLinearMethod"]
@@ -380,19 +365,6 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
-                )
-            if (
-                self.quant_format == CompressionFormat.marlin_24.value
-                and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
-            ):
-                return CompressedTensorsW4A16Sparse24(
-                    strategy=weight_quant.strategy,
-                    num_bits=weight_quant.num_bits,
-                    group_size=weight_quant.group_size,
-                )
             if (
                 self.quant_format == CompressionFormat.pack_quantized.value
                 and weight_quant.num_bits in WNA16_SUPPORTED_BITS
@@ -403,6 +375,10 @@ class CompressedTensorsConfig(QuantizationConfig):
                     group_size=weight_quant.group_size,
                     actorder=weight_quant.actorder,
                 )
+            else:
+                raise ImportError(
+                    "Other method (CompressedTensorsW4A16Sparse24) is not supported now"
+                )

         if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -426,10 +402,6 @@ class CompressedTensorsConfig(QuantizationConfig):
         # note: input_quant can be None
         if self._is_fp8_w8a16(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
-                )
             is_static_input_scheme = input_quant and not input_quant.dynamic
             return CompressedTensorsW8A16Fp8(
                 strategy=weight_quant.strategy,
@@ -470,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Find the "target" in the compressed-tensors config
         # that our layer conforms to.
-        # TODO (@robertgshaw): add compressed-tensors as dep
+        # TODO : add compressed-tensors as dep
         # so we do not have to re-write these functions
         # need to make accelerate optional in ct to do this
@@ -508,24 +480,7 @@ class CompressedTensorsConfig(QuantizationConfig):
                 input_quant=input_quant,
                 sparsity_scheme=sparsity_scheme,
             ):
-                if not VLLM_AVAILABLE:
-                    raise ImportError(
-                        "vllm is not installed, to use CompressedTensors24, please install vllm"
-                    )
-                # Have a valid sparsity scheme
-                # Validate layer is supported by Cutlass 2:4 Kernel
-                model_compression_config = (
-                    None
-                    if sparsity_scheme is None or sparsity_scheme.format == "dense"
-                    else self.config
-                )
-                scheme = CompressedTensors24(
-                    quantized=weight_quant is not None or input_quant is not None,
-                    weight_quant=weight_quant,
-                    input_quant=input_quant,
-                    model_compression_config=model_compression_config,
-                )
+                raise ImportError("CompressedTensors24 is not supported now")
             elif weight_quant is None:
                 logger.warning_once(
                     "Acceleration for non-quantized schemes is "
......
@@ -6,7 +6,7 @@ import enum
 import logging
 import re
 from enum import Enum
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING

 try:
     from sgl_kernel import fused_marlin_moe
@@ -31,9 +31,13 @@ from sglang.srt.environ import envs
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS,
+)
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
+from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
     per_tensor_dequantize,
@@ -42,6 +46,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     get_compiler_backend,
+    is_cuda,
     is_hip,
     set_weight_attrs,
 )
@@ -57,6 +62,8 @@ if TYPE_CHECKING:
 )

 _is_hip = is_hip()
+_is_cuda = is_cuda()

 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
@@ -64,12 +71,9 @@ if _use_aiter:
     from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1

-try:
-    import vllm  # noqa: F401
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
+if _is_cuda:
+    from sgl_kernel import fused_marlin_moe

 logger = logging.getLogger(__name__)
@@ -127,10 +131,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
         input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")

         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
-                )
+            logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -432,9 +434,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     ):
         if self.num_gpu_experts != -1:
             num_experts = self.num_gpu_experts
-        # assert (
-        #     params_dtype == torch.float16
-        # ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501

         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
@@ -573,44 +572,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             getattr(layer, name).copy_(new_t)
             del new_t

-        def get_scale_perms(num_bits: int):
-            scale_perm: List[int] = []
-            for i in range(8):
-                scale_perm.extend([i + 8 * j for j in range(8)])
-            scale_perm_single: List[int] = []
-            for i in range(4):
-                scale_perm_single.extend(
-                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
-                )
-            return scale_perm, scale_perm_single
-
-        def marlin_permute_scales(
-            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
-        ):
-            scale_perm, scale_perm_single = get_scale_perms(num_bits)
-            if group_size < size_k and group_size != -1:
-                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
-            else:
-                s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
-            s = s.reshape((-1, size_n)).contiguous()
-            return s
-
-        def marlin_moe_permute_scales(
-            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
-        ):
-            num_experts = s.shape[0]
-            output = torch.empty(
-                (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype
-            )
-            for e in range(num_experts):
-                output[e] = marlin_permute_scales(
-                    s[e], size_k, size_n, group_size, num_bits
-                )
-            return output
-
-        size_k2 = layer.w2_weight_packed.shape[2]
-        size_k13 = layer.w13_weight_packed.shape[2]

         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device
@@ -657,42 +618,39 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )

-        from vllm import _custom_ops as vllm_ops
-
-        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
             layer.w13_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
+        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
+        marlin_w2_qweight = gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
             layer.w2_weight_packed.shape[1] * self.packed_factor,
             layer.w2_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
             layer.w13_weight_scale,
-            size_k13,
+            layer.w13_weight_packed.shape[2],
             layer.w13_weight_scale.shape[2],
             self.group_size,
-            self.num_bits,
         )
-        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
         marlin_w2_scales = marlin_moe_permute_scales(
             layer.w2_weight_scale,
             layer.w2_weight_scale.shape[1]
             * (self.group_size if self.group_size != -1 else self.packed_factor),
-            size_k2,
+            layer.w2_weight_scale.shape[2],
             self.group_size,
-            self.num_bits,
         )
-        replace_tensor("w2_weight_scale", marlin_w2_scales)
+        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)

     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -716,7 +674,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

         topk_weights, topk_ids, router_logits = topk_output
-        output = torch.ops.vllm.fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
......
@@ -4,10 +4,13 @@ from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
+from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16

 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
     "CompressedTensorsW8A16Fp8",
     "CompressedTensorsW8A8Int8",
+    "CompressedTensorsWNA16",
+    "WNA16_SUPPORTED_BITS",
 ]
@@ -14,25 +14,12 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
 from sglang.srt.layers.quantization.utils import convert_to_channelwise

-try:
-    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-        apply_fp8_marlin_linear,
-        prepare_fp8_layer_for_marlin,
-    )
-
-    MARLIN_FP8_AVAILABLE = True
-except ImportError:
-    MARLIN_FP8_AVAILABLE = False
-
-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
-
-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")

 __all__ = ["CompressedTensorsW8A16Fp8"]

 SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]

@@ -43,11 +30,6 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme

-        if not MARLIN_FP8_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
-            )

     @classmethod
     def get_min_capability(cls) -> int:
         # ampere and up
......
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Callable, Optional

import torch
from compressed_tensors.quantization import ActivationOrdering

# yapf conflicts with isort for this block
# yapf: disable
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
    permute_param_layout_,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.marlin_utils import (
    MarlinLinearLayerConfig,
    apply_gptq_marlin_linear,
    check_marlin_supports_shape,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_permute_scales,
    marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx,
    marlin_zero_points,
)
from sglang.srt.layers.quantization.utils import (
    get_scalar_types,
    replace_parameter,
    unpack_cols,
)
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import gptq_marlin_repack

ScalarType, scalar_types = get_scalar_types()

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128
}
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsWNA16(CompressedTensorsScheme):
    _kernel_backends_being_used: set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 symmetric: Optional[bool] = True,
                 actorder: Optional[ActivationOrdering] = None):
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError("Marlin kernels require group quantization or "
                             "channelwise quantization, but found no group "
                             "size and strategy is not channelwise.")

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")

        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
                           if not self.symmetric else
                           WNA16_SUPPORTED_TYPES_MAP[num_bits])

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)

        self.kernel_config = MarlinLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx
        )

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        weight = PackedvLLMParameter(input_dim=1,
                                     output_dim=0,
                                     weight_loader=weight_loader,
                                     packed_factor=self.pack_factor,
                                     packed_dim=1,
                                     data=torch.empty(
                                         output_size_per_partition,
                                         input_size_per_partition //
                                         self.pack_factor,
                                         dtype=torch.int32,
                                     ))

        weight_scale_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            )
        }

        zeros_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.zeros(
                output_size_per_partition // self.pack_factor,
                scales_and_zp_size,
                dtype=torch.int32,
            )
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)

            if not self.symmetric:
                qzeros = PackedColumnParameter(output_dim=0,
                                               packed_dim=0,
                                               packed_factor=self.pack_factor,
                                               **zeros_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)
            if not self.symmetric:
                qzeros = PackedvLLMParameter(input_dim=1,
                                             output_dim=0,
                                             packed_dim=0,
                                             packed_factor=self.pack_factor,
                                             **zeros_args)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                          dtype=torch.int64),
                                         weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        if not self.symmetric:
            layer.register_parameter("weight_zero_point", qzeros)

        # group index (for activation reordering)
        if self.has_g_idx:
            weight_g_idx = RowvLLMParameter(data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
                                            input_dim=0,
                                            weight_loader=weight_loader)
            layer.register_parameter("weight_g_idx", weight_g_idx)

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        self.w_q_name = "weight_packed"
        self.w_s_name = "weight_scale"
        self.w_zp_name = "weight_zero_point"
        self.w_gidx_name = "weight_g_idx"

        device = getattr(layer, self.w_q_name).device
        c = self.kernel_config

        check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(device)

        def _transform_param(
            layer: torch.nn.Module, name: Optional[str], fn: Callable
        ) -> None:
            if name is not None and getattr(layer, name, None) is not None:
                old_param = getattr(layer, name)
                new_param = fn(old_param)
                # replace the parameter with torch.nn.Parameter for TorchDynamo
                # compatibility
                replace_parameter(
                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
                )

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
            )
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            _transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        _transform_param(layer, self.w_q_name, transform_w_q)
        _transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        c = self.kernel_config

        def _get_weight_params(
            layer: torch.nn.Module,
        ) -> tuple[
            torch.Tensor,  # w_q
            torch.Tensor,  # w_s
            Optional[torch.Tensor],  # w_zp,
            Optional[torch.Tensor],  # w_gidx
        ]:
            return (
                getattr(layer, self.w_q_name),
                getattr(layer, self.w_s_name),
                getattr(layer, self.w_zp_name or "", None),
                getattr(layer, self.w_gidx_name or "", None),
            )

        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias,
        )
@@ -4,6 +4,7 @@
 from __future__ import annotations

 import logging
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

 import numpy
@@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 USE_FP32_REDUCE_DEFAULT = True


+@dataclass
+class MarlinLinearLayerConfig:
+    full_weight_shape: tuple[int, int]  # [in, out]
+    partition_weight_shape: tuple[int, int]
+    weight_type: ScalarType
+    act_type: torch.dtype
+    group_size: int
+    zero_points: bool
+    has_g_idx: bool
+
+
 # For binary size and compile time, we don't support the same types for with and
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
......
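For orientation, here is a minimal sketch of how the MarlinLinearLayerConfig dataclass added above gets filled in. In the diff this happens inside CompressedTensorsWNA16.create_weights; the shapes, group size, and dtype below are made-up illustrative values, not taken from the commit.

import torch

from sglang.srt.layers.quantization.marlin_utils import MarlinLinearLayerConfig
from sglang.srt.layers.quantization.utils import get_scalar_types

ScalarType, scalar_types = get_scalar_types()

# Hypothetical 4-bit weight-only linear layer, group size 128, sharded across
# two tensor-parallel ranks along the input dimension.
config = MarlinLinearLayerConfig(
    full_weight_shape=(4096, 11008),       # (in_features, out_features) of the full weight
    partition_weight_shape=(2048, 11008),  # this rank's shard
    weight_type=scalar_types.uint4b8,      # symmetric 4-bit, as in WNA16_SUPPORTED_TYPES_MAP
    act_type=torch.bfloat16,
    group_size=128,
    zero_points=False,                     # True only for asymmetric (zero-point) checkpoints
    has_g_idx=False,                       # True when actorder == ActivationOrdering.GROUP
)

The kernel helpers in marlin_utils (check_marlin_supports_shape, marlin_is_k_full, and the repacking functions) consume this config during process_weights_after_loading, as shown in the new compressed_tensors_wNa16 scheme above.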