# SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py from __future__ import annotations import logging from typing import TYPE_CHECKING, Any, Optional import numpy import torch from sglang.srt.layers.parameter import ( BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedvLLMParameter, ) from sglang.srt.layers.quantization.base_config import ( LinearMethodBase, QuantizationConfig, ) from sglang.srt.layers.quantization.utils import ( get_scalar_types, pack_cols, unpack_cols, ) from sglang.srt.utils import get_device_capability, is_cuda if TYPE_CHECKING: from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE try: from vllm import _custom_ops as ops except ImportError: ops = None _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import gptq_marlin_gemm logger = logging.getLogger(__name__) ScalarType, scalar_types = get_scalar_types() GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # In case there is a performance issue with Marlin, the variable below can be # changed to False, which allows Marlin to perform global reductions in fp16 # precision (instead of fp32), and therefore, save on some memory movements. USE_FP32_REDUCE_DEFAULT = True # For binary size and compile time, we don't support the same types for with and # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( has_zp: Optional[bool] = None, include_fp_type: bool = True, device_capability: Optional[int] = None, ): if device_capability is None: major, minor = get_device_capability() capability = major * 10 + minor device_capability = -1 if capability is None else capability if device_capability < 80: return [] # - has_zp is True: return quant_types that has zero points # - has_zp is False: return quant_types that has not zero points # - has_zp is None: both if has_zp is None: types0 = query_marlin_supported_quant_types( False, include_fp_type, device_capability ) types1 = query_marlin_supported_quant_types( True, include_fp_type, device_capability ) return types0 + types1 if has_zp: # AWQ style, unsigned + runtime zero-point return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias res = [scalar_types.uint4b8, scalar_types.uint8b128] if include_fp_type: res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] return res def _check_marlin_supported( quant_type: ScalarType, group_size: Optional[int], has_zp: bool, device_capability: Optional[int] = None, ) -> tuple[bool, Optional[str]]: if device_capability is None: major, minor = get_device_capability() capability = major * 10 + minor device_capability = -1 if capability is None else capability supported_types = query_marlin_supported_quant_types( has_zp, True, device_capability ) if quant_type not in supported_types: return ( False, f"Marlin does not support weight_bits = {quant_type}. " f"Only types = {supported_types} " f"are supported (for group_size = {group_size}, " f"device_capability = {device_capability}, zp = {has_zp}).", ) if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: return ( False, f"Marlin does not support group_size = {group_size}. " f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " "are supported.", ) return True, None def check_marlin_supported( quant_type: ScalarType, group_size: int, has_zp: bool = False, device_capability: Optional[int] = None, ) -> bool: cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) return cond def verify_marlin_supported( quant_type: ScalarType, group_size: int, has_zp: bool = False ) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) def verify_marlin_supports_shape( output_size_per_partition: int, input_size_per_partition: int, input_size: int, group_size: int, ) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: raise ValueError( f"Weight output_size_per_partition = " f"{output_size_per_partition} is not divisible by " f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " "Consider reducing tensor_parallel_size or running " "with --quantization gptq." ) # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: raise ValueError( f"Weight input_size_per_partition = " f"{input_size_per_partition} is not divisible " f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " "Consider reducing tensor_parallel_size or running " "with --quantization gptq." ) if group_size < input_size and input_size_per_partition % group_size != 0: raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " "with --quantization gptq." ) def check_marlin_supports_shape( output_size_per_partition: int, input_size_per_partition: int, input_size: int, group_size: int, ) -> tuple[bool, Optional[str]]: try: verify_marlin_supports_shape( output_size_per_partition, input_size_per_partition, input_size, group_size ) except ValueError as e: return False, e.__str__() return True, None def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: output_size_per_partition = ( getattr(layer, "output_size_per_partition", None) or layer.output_size ) input_size_per_partition = ( getattr(layer, "input_size_per_partition", None) or layer.input_size ) return check_marlin_supports_shape( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=layer.input_size, group_size=group_size, )[0] def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool: hidden_size = layer.hidden_size intermediate_size_per_partition = layer.intermediate_size_per_partition # apply_router_weight_on_input is not supported for moe marlin supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input # moe marlin requires the activation to be silu supports_activation = layer.moe_runner_config.activation == "silu" # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # down: (n, k) = (hidden_size, intermediate_size_per_partition) # moe marlin requires n % 128 == 0 and k % 64 == 0 supports_shape = ( hidden_size % 128 == 0 and intermediate_size_per_partition % max(64, group_size) == 0 ) supports_group_size = group_size in [-1, 32, 64, 128] return ( supports_shape and supports_group_size and supports_router_weight and supports_activation ) def marlin_make_workspace( device: torch.device, max_blocks_per_sm: int = 1 ) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace # size. The num of threadblocks is is sms_count * max_blocks_per_sm. sms = torch.cuda.get_device_properties(device).multi_processor_count return torch.zeros( sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False ) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) def marlin_repeat_scales_on_all_ranks( act_order: bool, group_size: int, is_row_parallel: bool ) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 return act_order or (is_channelwise and is_row_parallel) def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: return torch.nn.Parameter( torch.empty(0, dtype=torch.int, device=device), requires_grad=False ) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: return torch.nn.Parameter( torch.empty(0, dtype=torch.int, device=device), requires_grad=False ) def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm_single: list[int] = [] for i in range(4): scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single def marlin_permute_scales( s: torch.Tensor, size_k: int, size_n: int, group_size: int ) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s def marlin_moe_permute_scales( s: torch.Tensor, size_k: int, size_n: int, group_size: int, ): num_experts = s.shape[0] output = torch.empty( (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype, ) for e in range(num_experts): output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) return output def marlin_zero_points( zp: torch.Tensor, size_k: int, size_n: int, num_bits: int ) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] # Interleave column dim (for the dequantize code) and pack it to int32 if num_bits == 4: interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) elif num_bits == 8: interleave = numpy.array([0, 2, 1, 3]) else: raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() zp = zp.reshape((-1, size_n)).contiguous() zp = pack_cols(zp, num_bits, size_k, size_n) return zp def awq_to_marlin_zero_points( q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int ) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation # and pack it back. q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) # Undo interleaving (use argsort(..) to get inverse perm) if num_bits == 4: undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) elif num_bits == 8: undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) else: raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() q_zp = q_zp.reshape((-1, size_n)).contiguous() marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) return marlin_zp def moe_awq_to_marlin_zero_points( q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int ): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), device=q_zp_packed.device, dtype=q_zp_packed.dtype, ) for e in range(num_experts): output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) return output def maybe_warn_marlin_atomic_add(device, dtype): if torch.compiler.is_dynamo_compiling(): return device_capability = torch.cuda.get_device_capability(device) if device_capability[0] < 9 and dtype == torch.bfloat16: logger.info_once( "You are running Marlin kernel with bf16 on GPUs before SM90. " "You can consider change to fp16 to achieve better performance " "if possible." ) def maybe_warn_marlin_atomic_add_env(): if torch.compiler.is_dynamo_compiling(): return # TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False if True: return # if envs.VLLM_MARLIN_USE_ATOMIC_ADD: # return logger.info_once( "Marlin kernel can achieve better performance for small size_n " "with experimental use_atomic_add feature. " "You can consider set environment variable " "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible." ) def should_use_atomic_add_reduce( m: int, n: int, k: int, device: torch.device, dtype: torch.dtype ) -> bool: # the performance of atomicAdd is better than global reduce # only when m*n is small and k is large if n >= 2048 or k < 2048 or device.type != "cuda": return False # disable atomicAdd reduce by default, # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 # TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False if not True: maybe_warn_marlin_atomic_add_env() return False # sm8x doesn't support atomicAdd + bfloat16 natively device_capability = torch.cuda.get_device_capability(device) if device_capability[0] < 9 and dtype == torch.bfloat16: maybe_warn_marlin_atomic_add(device, dtype) return False return True def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zp: torch.Tensor, g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, wtype: ScalarType, output_size_per_partition: int, input_size_per_partition: int, is_k_full: bool, bias: Optional[torch.Tensor] = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition,) use_atomic_add = should_use_atomic_add_reduce( m=reshaped_x.size(0), n=output_size_per_partition, k=reshaped_x.size(1), device=input.device, dtype=input.dtype, ) output = gptq_marlin_gemm( reshaped_x, None, weight, weight_scale, None, weight_zp, g_idx, g_idx_sort_indices, workspace, wtype, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False, ) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) def apply_awq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zp: torch.Tensor, g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, bias: Optional[torch.Tensor] = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition,) use_atomic_add = should_use_atomic_add_reduce( m=reshaped_x.size(0), n=output_size_per_partition, k=reshaped_x.size(1), device=input.device, dtype=input.dtype, ) output = gptq_marlin_gemm( reshaped_x, None, weight, weight_scale, None, weight_zp, g_idx, g_idx_sort_indices, workspace, quant_type, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False, ) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) class MarlinConfig(QuantizationConfig): """Config class for Marlin. Reference: https://github.com/IST-DASLab/marlin/tree/master """ def __init__( self, group_size: int, lm_head_quantized: bool, ) -> None: super().__init__() # Group size for the quantization. self.group_size = group_size self.lm_head_quantized = lm_head_quantized if self.group_size != 128 and self.group_size != -1: raise ValueError( "Currently, only group size 128 and -1 (channelwise) " "is supported for Marlin, but got group_size of " f"{self.group_size}" ) # 4 Bits packed into 32 bit datatype. self.pack_factor = 32 // 4 # Tile size used by marlin kernels. self.tile_size = 16 # Min out_features dim self.min_n_threads = 64 # Min in_features dim self.min_k_threads = 128 # Max parallel problems to solve at once (improves large # batch performance) self.max_parallel = 16 # Permutation length used by the marlin kernels. self.perm_len = 1024 def __repr__(self) -> str: return ( f"MarlinConfig(group_size={self.group_size}, " f"lm_head_quantized={self.lm_head_quantized})" ) @classmethod def get_name(cls) -> str: return "marlin" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.half] @classmethod # Need to figure it out def get_min_capability(cls) -> int: return 80 @classmethod def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) return cls(group_size, lm_head_quantized) @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: # compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq <=0.7.1 is_marlin_format: bool is_marlin_format = hf_quant_cfg.get( "checkpoint_format" ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False) is_valid_user_quant = ( user_quant is None or user_quant == "gptq" or user_quant == "marlin" ) if is_marlin_format and is_valid_user_quant: msg = "The model is serialized in {} format. Using {} kernel.".format( cls.get_name(), cls.get_name() ) logger.info(msg) return cls.get_name() return None def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[MarlinLinearMethod]: from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead if isinstance(layer, LinearBase) or ( isinstance(layer, ParallelLMHead) and self.lm_head_quantized ): return MarlinLinearMethod(self) return None class MarlinLinearMethod(LinearMethodBase): """Linear method for Marlin. Args: quant_config: The Marlin quantization config. """ def __init__(self, quant_config: MarlinConfig): self.quant_config = quant_config def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del output_size # Unused. weight_loader = extra_weight_attrs["weight_loader"] if params_dtype != torch.float16: raise ValueError( f"The params dtype must be float16, but got {params_dtype}" ) # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.min_n_threads != 0: raise ValueError( f"Weight output_size_per_partition = " f"{output_size_per_partition} is not divisible by " f"min_n_threads = {self.quant_config.min_n_threads}." ) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( f"Weight output_size_per_partition = " f"{output_size_per_partition} is not divisible by " f"pack_factor = {self.quant_config.pack_factor}." ) # Validate input_size_per_partition if input_size_per_partition % self.quant_config.min_k_threads != 0: raise ValueError( f"Weight input_size_per_partition = " f"{input_size_per_partition} is not divisible by " f"min_k_threads = {self.quant_config.min_k_threads}." ) if ( self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0 ): raise ValueError( f"Weight input_size_per_partition = " f"{input_size_per_partition} is not divisible by " f"group_size = {self.quant_config.group_size}." ) # Check that we have at least 4 tiles horizontally in the shard num_tiles_per_perm = self.quant_config.perm_len // ( self.quant_config.tile_size**2 ) if output_size_per_partition % num_tiles_per_perm != 0: raise ValueError("Each permutation group must reside on the same gpu") # Quantized 4Bit weights packed into Int32. qweight = PackedvLLMParameter( data=torch.empty( input_size_per_partition // self.quant_config.tile_size, output_size_per_partition * self.quant_config.tile_size // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, marlin_tile_size=self.quant_config.tile_size, weight_loader=weight_loader, ) # Determine if channelwise or not input_groups = ( 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size ) weight_scale_args = { "data": torch.empty( input_groups, output_size_per_partition, device="cuda", dtype=params_dtype, ), "weight_loader": weight_loader, } if input_groups == 1: scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) else: scales = GroupQuantScaleParameter( output_dim=1, input_dim=0, **weight_scale_args ) # Allocate workspace (Used for internal locking mechanism) max_workspace_size = ( output_size_per_partition // self.quant_config.min_n_threads ) * self.quant_config.max_parallel workspace = BasevLLMParameter( data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int), weight_loader=weight_loader, ) layer.register_parameter("B", qweight) layer.register_parameter("s", scales) layer.register_parameter("workspace", workspace) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # required by torch.compile layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False) layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False) layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = layer.B scales = layer.s workspace = layer.workspace x_2d = x.view(-1, x.shape[-1]) size_m = x_2d.shape[0] size_k = x_2d.shape[1] size_n = scales.shape[1] output_2d = ops.marlin_gemm( x_2d, qweight, scales, workspace, size_m, size_n, size_k ) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) if bias is not None: output.add_(bias) # In-place add return output