Unverified Commit 1656ad37 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)


Signed-off-by: default avatarJinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarMichael Goin <mgoin@redhat.com>
parent fa59fe41
...@@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
select_cutlass_fp8_gemm_impl, select_cutlass_fp8_gemm_impl,
swap_w13_to_w31, swap_w13_to_w31,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear, apply_fp4_marlin_linear,
is_fp4_marlin_supported, is_fp4_marlin_supported,
...@@ -170,9 +173,15 @@ class ModelOptQuantConfigBase(QuantizationConfig): ...@@ -170,9 +173,15 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# now, the layer is quantized, handle it here # now, the layer is quantized, handle it here
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return self.LinearMethodCls(self) quant_method = self.LinearMethodCls(self)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return self.FusedMoEMethodCls(quant_config=self, layer=layer) quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
return None return None
...@@ -898,6 +907,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -898,6 +907,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None: def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.marlin_input_dtype = None
self.backend = "none" self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None: if envs.VLLM_NVFP4_GEMM_BACKEND is None:
...@@ -1065,6 +1075,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1065,6 +1075,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
bias=bias, bias=bias,
input_dtype=self.marlin_input_dtype,
) )
output_dtype = x.dtype output_dtype = x.dtype
...@@ -1124,6 +1135,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1124,6 +1135,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.marlin_input_dtype = None
self.flashinfer_moe_backend = None self.flashinfer_moe_backend = None
if self.allow_flashinfer: if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend() self.flashinfer_moe_backend = get_flashinfer_moe_backend()
...@@ -1517,7 +1529,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ...@@ -1517,7 +1529,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
workspace=layer.workspace, input_dtype=self.marlin_input_dtype,
) )
elif self.allow_flashinfer: elif self.allow_flashinfer:
......
...@@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin,
) )
...@@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig): ...@@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig):
if current_platform.is_xpu(): if current_platform.is_xpu():
return IpexMxfp4MoEMethod(layer.moe_config) return IpexMxfp4MoEMethod(layer.moe_config)
else: else:
return Mxfp4MoEMethod(layer.moe_config) quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention. # TODO: Add support for MXFP4 Attention.
logger.debug_once( logger.debug_once(
...@@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = ( self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size get_current_vllm_config().compilation_config.max_cudagraph_capture_size
...@@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN: if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer) prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
elif ( elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
...@@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
activation=activation, activation=activation,
expert_map=expert_map, expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
) )
assert _can_support_mxfp4( assert _can_support_mxfp4(
......
...@@ -9,6 +9,11 @@ import vllm.envs as envs ...@@ -9,6 +9,11 @@ import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
...@@ -286,10 +291,10 @@ def get_scale_perms(): ...@@ -286,10 +291,10 @@ def get_scale_perms():
def marlin_permute_scales( def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms() scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1: if group_size < size_k and group_size != -1 and not is_a_8bit:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm] s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else: else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
...@@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor: ...@@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
return s.reshape(*origin_shape).contiguous() return s.reshape(*origin_shape).contiguous()
def marlin_act_int8_process_scales(s: torch.Tensor):
a_scales_scale_factor = 1 / 4096 * s.max().float()
s = s / s.max() * 4096
s = s.round().to(torch.int16).view(s.dtype)
return s, a_scales_scale_factor
def marlin_moe_permute_scales( def marlin_moe_permute_scales(
s: torch.Tensor, s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
size_k: int,
size_n: int,
group_size: int,
): ):
num_experts = s.shape[0] num_experts = s.shape[0]
output = torch.empty( output = torch.empty(
...@@ -319,12 +328,12 @@ def marlin_moe_permute_scales( ...@@ -319,12 +328,12 @@ def marlin_moe_permute_scales(
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit)
return output return output
def marlin_zero_points( def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the # Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA # "single" permutation, since zero-points are applied on every MMA
...@@ -339,7 +348,8 @@ def marlin_zero_points( ...@@ -339,7 +348,8 @@ def marlin_zero_points(
else: else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() if not is_a_8bit:
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous() zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n) zp = pack_cols(zp, num_bits, size_k, size_n)
...@@ -347,7 +357,11 @@ def marlin_zero_points( ...@@ -347,7 +357,11 @@ def marlin_zero_points(
def awq_to_marlin_zero_points( def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim. # AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer. # In addition, the values are permuted based on dequantizer.
...@@ -366,12 +380,16 @@ def awq_to_marlin_zero_points( ...@@ -366,12 +380,16 @@ def awq_to_marlin_zero_points(
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous() q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit)
return marlin_zp return marlin_zp
def moe_awq_to_marlin_zero_points( def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
): ):
num_experts = q_zp_packed.shape[0] num_experts = q_zp_packed.shape[0]
output = torch.empty( output = torch.empty(
...@@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points( ...@@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points(
dtype=q_zp_packed.dtype, dtype=q_zp_packed.dtype,
) )
for e in range(num_experts): for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) output[e] = awq_to_marlin_zero_points(
q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit
)
return output return output
...@@ -432,6 +452,48 @@ def should_use_atomic_add_reduce( ...@@ -432,6 +452,48 @@ def should_use_atomic_add_reduce(
return True return True
_quant_fp8_method: QuantFP8 | None = None
def get__quant_fp8_method() -> QuantFP8:
global _quant_fp8_method
if _quant_fp8_method is None:
_quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN)
return _quant_fp8_method
def get_marlin_input_dtype(prefix):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
return torch.int8
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "fp8":
if not current_platform.is_device_capability(
89
) and not current_platform.is_device_capability(120):
raise ValueError(
"Marlin W4A8-FP8 only support SM89 or SM120 device "
"(It is slower than Marlin W4A16 on other devices). "
"You can consider using W4A8-INT8 instead"
"(set VLLM_MARLIN_INPUT_DTYPE=int8)."
)
_ = get__quant_fp8_method()
return torch.float8_e4m3fn
else:
return
def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype):
x = x.reshape(-1, x.shape[-1])
if quant_dtype == torch.int8:
return per_token_quant_int8(x)
elif quant_dtype == torch.float8_e4m3fn:
return get__quant_fp8_method()(x)
else:
raise ValueError(f"unsupported quant_dtype {quant_dtype}")
def apply_gptq_marlin_linear( def apply_gptq_marlin_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
...@@ -444,8 +506,10 @@ def apply_gptq_marlin_linear( ...@@ -444,8 +506,10 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
is_k_full: bool, is_k_full: bool,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,) out_shape = input.shape[:-1] + (output_size_per_partition,)
...@@ -458,12 +522,27 @@ def apply_gptq_marlin_linear( ...@@ -458,12 +522,27 @@ def apply_gptq_marlin_linear(
dtype=input.dtype, dtype=input.dtype,
) )
a_scales = None
if input_dtype == torch.int8:
assert wtype == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert wtype == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
bias, bias,
weight_scale, weight_scale,
a_scales,
None, None,
weight_zp, weight_zp,
g_idx, g_idx,
...@@ -493,8 +572,10 @@ def apply_awq_marlin_linear( ...@@ -493,8 +572,10 @@ def apply_awq_marlin_linear(
quant_type: ScalarType, quant_type: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,) out_shape = input.shape[:-1] + (output_size_per_partition,)
...@@ -507,12 +588,20 @@ def apply_awq_marlin_linear( ...@@ -507,12 +588,20 @@ def apply_awq_marlin_linear(
dtype=input.dtype, dtype=input.dtype,
) )
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
bias, bias,
weight_scale, weight_scale,
a_scales,
None, None,
weight_zp, weight_zp,
g_idx, g_idx,
...@@ -538,8 +627,10 @@ def apply_rtn_marlin_linear( ...@@ -538,8 +627,10 @@ def apply_rtn_marlin_linear(
quant_type: ScalarType, quant_type: ScalarType,
output_size_per_partition: int, output_size_per_partition: int,
input_size_per_partition: int, input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,) out_shape = input.shape[:-1] + (output_size_per_partition,)
...@@ -552,12 +643,20 @@ def apply_rtn_marlin_linear( ...@@ -552,12 +643,20 @@ def apply_rtn_marlin_linear(
dtype=input.dtype, dtype=input.dtype,
) )
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
bias, bias,
weight_scale, weight_scale,
a_scales,
None, None,
None, None,
None, None,
......
...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales, marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce, should_use_atomic_add_reduce,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -37,12 +38,6 @@ def nvfp4_marlin_process_scales(marlin_scales): ...@@ -37,12 +38,6 @@ def nvfp4_marlin_process_scales(marlin_scales):
# convert to half first, we would convert to fp8 later # convert to half first, we would convert to fp8 later
marlin_scales = marlin_scales.to(torch.half) marlin_scales = marlin_scales.to(torch.half)
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
# fit the layout of fp8 dequantization # fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1 marlin_scales.size(0), -1
...@@ -62,18 +57,20 @@ def nvfp4_marlin_process_scales(marlin_scales): ...@@ -62,18 +57,20 @@ def nvfp4_marlin_process_scales(marlin_scales):
return marlin_scales return marlin_scales
def mxfp4_marlin_process_scales(marlin_scales): def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
# fit the layout of fp8 dequantization # fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( if input_dtype is None or input_dtype.itemsize == 2:
marlin_scales.size(0), -1 marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
) marlin_scales.size(0), -1
)
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
if input_dtype == torch.float8_e4m3fn:
marlin_scales = marlin_scales.view(torch.uint8)
assert marlin_scales.max() <= 249
# exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6
marlin_scales = marlin_scales + 6
marlin_scales = marlin_scales.view(torch.float8_e8m0fnu)
return marlin_scales return marlin_scales
...@@ -99,6 +96,7 @@ def apply_fp4_marlin_linear( ...@@ -99,6 +96,7 @@ def apply_fp4_marlin_linear(
size_n: int, size_n: int,
size_k: int, size_k: int,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor: ) -> torch.Tensor:
# For GPUs that lack FP4 hardware support, we can leverage the # For GPUs that lack FP4 hardware support, we can leverage the
...@@ -111,12 +109,24 @@ def apply_fp4_marlin_linear( ...@@ -111,12 +109,24 @@ def apply_fp4_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
) )
inputs = reshaped_x
a_scales = None
is_nvfp4 = weight_scale_2 is not None
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a=reshaped_x, a=inputs,
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
b_bias=bias, b_bias=bias,
b_scales=weight_scale, b_scales=weight_scale,
a_scales=a_scales,
global_scale=weight_scale_2, global_scale=weight_scale_2,
b_zeros=None, b_zeros=None,
g_idx=None, g_idx=None,
...@@ -133,7 +143,9 @@ def apply_fp4_marlin_linear( ...@@ -133,7 +143,9 @@ def apply_fp4_marlin_linear(
return output.reshape(out_shape) return output.reshape(out_shape)
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: def prepare_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP4 computation but " "Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will " "FP4 quantization is being used. Weight-only FP4 compression will "
...@@ -160,12 +172,14 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -160,12 +172,14 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
perm = torch.empty(0, dtype=torch.int, device=device) perm = torch.empty(0, dtype=torch.int, device=device)
qweight = layer.weight.view(torch.int32).T.contiguous() qweight = layer.weight.view(torch.int32).T.contiguous()
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, b_q_weight=qweight,
perm=perm, perm=perm,
size_k=part_size_k, size_k=part_size_k,
size_n=part_size_n, size_n=part_size_n,
num_bits=4, num_bits=4,
is_a_8bit=is_a_8bit,
) )
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
...@@ -178,7 +192,11 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -178,7 +192,11 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale = weight_scale.to(param_dtype) weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales( weight_scale = marlin_permute_scales(
s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
if is_nvfp4: if is_nvfp4:
...@@ -189,7 +207,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -189,7 +207,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False)
else: else:
weight_scale = mxfp4_marlin_process_scales(weight_scale) weight_scale = mxfp4_marlin_process_scales(
weight_scale, input_dtype=input_dtype
)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None: if hasattr(layer, "bias") and layer.bias is not None:
...@@ -200,7 +220,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -200,7 +220,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
return return
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP4 computation but " "Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will " "FP4 quantization is being used. Weight-only FP4 compression will "
...@@ -220,6 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -220,6 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
param_dtype = layer.params_dtype param_dtype = layer.params_dtype
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device) perm = torch.empty(0, dtype=torch.int, device=device)
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
# WEIGHT # WEIGHT
# Repack weights to marlin format # Repack weights to marlin format
...@@ -237,7 +260,12 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -237,7 +260,12 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
qweight = weight[i].view(torch.int32).T.contiguous() qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
) )
tensor_list.append(marlin_qweight) tensor_list.append(marlin_qweight)
...@@ -266,12 +294,18 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -266,12 +294,18 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
scale = scales[i].T scale = scales[i].T
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scale, size_k=size_k, size_n=size_n, group_size=group_size s=scale,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
if is_nvfp4: if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales) marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else: else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales) marlin_scales = mxfp4_marlin_process_scales(
marlin_scales, input_dtype=input_dtype
)
tensor_list.append(marlin_scales) tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
...@@ -301,7 +335,10 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -301,7 +335,10 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
setattr(layer, name, bias) setattr(layer, name, bias)
def rand_marlin_weight_nvfp4_like(weight, group_size): def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
assert not is_a_8bit, "NVFP4 weight + INT8/FP8 activation is not supported."
assert group_size > 0 assert group_size > 0
size_n, size_k = weight.shape size_n, size_k = weight.shape
device = weight.device device = weight.device
...@@ -337,10 +374,15 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): ...@@ -337,10 +374,15 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
size_k=size_k, size_k=size_k,
size_n=size_n, size_n=size_n,
num_bits=4, num_bits=4,
is_a_8bit=is_a_8bit,
) )
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
marlin_scales = nvfp4_marlin_process_scales(marlin_scales) marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
...@@ -349,14 +391,20 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): ...@@ -349,14 +391,20 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
return weight_ref.T, marlin_qweight, marlin_scales, global_scale return weight_ref.T, marlin_qweight, marlin_scales, global_scale
def rand_marlin_weight_mxfp4_like(weight, group_size): def rand_marlin_weight_mxfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn, (
"MXFP4 weight + INT8 activation is not supported."
)
assert group_size > 0 assert group_size > 0
size_n, size_k = weight.shape size_n, size_k = weight.shape
device = weight.device device = weight.device
scales = torch.randint( scales = torch.randint(
100, 110,
125, 120,
(size_n, size_k // group_size), (size_n, size_k // group_size),
dtype=torch.uint8, dtype=torch.uint8,
device=weight.device, device=weight.device,
...@@ -380,18 +428,25 @@ def rand_marlin_weight_mxfp4_like(weight, group_size): ...@@ -380,18 +428,25 @@ def rand_marlin_weight_mxfp4_like(weight, group_size):
).view(size_n, size_k) ).view(size_n, size_k)
weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype)
perm = torch.empty(0, dtype=torch.int, device=device)
fp4_weight = fp4_weight.view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), b_q_weight=fp4_weight,
perm=torch.empty(0, dtype=torch.int, device=device), perm=perm,
size_k=size_k, size_k=size_k,
size_n=size_n, size_n=size_n,
num_bits=4, num_bits=4,
is_a_8bit=is_a_8bit,
) )
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
marlin_scales = mxfp4_marlin_process_scales(marlin_scales) marlin_scales = mxfp4_marlin_process_scales(marlin_scales, input_dtype=input_dtype)
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu) return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)
...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_permute_bias, marlin_permute_bias,
marlin_permute_scales, marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce, should_use_atomic_add_reduce,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -45,6 +46,7 @@ def apply_fp8_marlin_linear( ...@@ -45,6 +46,7 @@ def apply_fp8_marlin_linear(
size_n: int, size_n: int,
size_k: int, size_k: int,
bias: torch.Tensor | None, bias: torch.Tensor | None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor: ) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the # For GPUs that lack FP8 hardware support, we can leverage the
...@@ -57,12 +59,21 @@ def apply_fp8_marlin_linear( ...@@ -57,12 +59,21 @@ def apply_fp8_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
) )
inputs = reshaped_x
a_scales = None
if input_dtype is not None and input_dtype.itemsize == 1:
if input_dtype != torch.float8_e4m3fn:
raise RuntimeError("FP8 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a=reshaped_x, a=reshaped_x,
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
b_bias=bias, b_bias=bias,
b_scales=weight_scale, b_scales=weight_scale,
a_scales=a_scales,
global_scale=None, global_scale=None,
b_zeros=None, b_zeros=None,
g_idx=None, g_idx=None,
...@@ -80,7 +91,9 @@ def apply_fp8_marlin_linear( ...@@ -80,7 +91,9 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin( def prepare_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None: ) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP8 computation but " "Your GPU does not have native support for FP8 computation but "
...@@ -162,7 +175,8 @@ def prepare_fp8_layer_for_marlin( ...@@ -162,7 +175,8 @@ def prepare_fp8_layer_for_marlin(
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
) )
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None: if hasattr(layer, "bias") and layer.bias is not None:
...@@ -172,7 +186,9 @@ def prepare_fp8_layer_for_marlin( ...@@ -172,7 +186,9 @@ def prepare_fp8_layer_for_marlin(
def prepare_moe_fp8_layer_for_marlin( def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None: ) -> None:
logger.warning_once( logger.warning_once(
"Your GPU does not have native support for FP8 computation but " "Your GPU does not have native support for FP8 computation but "
...@@ -278,7 +294,8 @@ def prepare_moe_fp8_layer_for_marlin( ...@@ -278,7 +294,8 @@ def prepare_moe_fp8_layer_for_marlin(
tensor_list.append(marlin_scales) tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = fp8_fused_exponent_bias_into_scales(scales) if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales)
scales = torch.nn.Parameter(scales, requires_grad=False) scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales) setattr(layer, name + "_weight_scale", scales)
...@@ -318,7 +335,11 @@ def pack_fp8_to_int32( ...@@ -318,7 +335,11 @@ def pack_fp8_to_int32(
return int32_tensor.T.contiguous() if size_k_first else int32_tensor return int32_tensor.T.contiguous() if size_k_first else int32_tensor
def marlin_quant_fp8_torch(weight, group_size): def marlin_quant_fp8_torch(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn
size_n, size_k = weight.shape size_n, size_k = weight.shape
device = weight.device device = weight.device
...@@ -334,16 +355,22 @@ def marlin_quant_fp8_torch(weight, group_size): ...@@ -334,16 +355,22 @@ def marlin_quant_fp8_torch(weight, group_size):
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
perm = torch.empty(0, dtype=torch.int, device=device)
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_weight, b_q_weight=packed_weight,
perm=torch.empty(0, dtype=torch.int, device=device), perm=perm,
size_k=size_k, size_k=size_k,
size_n=size_n, size_n=size_n,
num_bits=8, num_bits=8,
is_a_8bit=is_a_8bit,
) )
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size s=scales.T,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
) )
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
import numpy as np import numpy as np
import torch import torch
from vllm.scalar_type import ScalarType from vllm import _custom_ops as ops
from vllm.scalar_type import ScalarType, scalar_types
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
from .quant_utils import ( from .quant_utils import (
...@@ -29,13 +30,19 @@ class MarlinWorkspace: ...@@ -29,13 +30,19 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): def marlin_permute_weights(
q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False
):
assert q_w.shape == (size_k, size_n) assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles if is_a_8bit:
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) # Permute weights to 32x32 marlin tiles
q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile))
else:
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile)) q_w = q_w.reshape((size_k // tile, size_n * tile))
...@@ -44,9 +51,9 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): ...@@ -44,9 +51,9 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
return q_w return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm): def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False):
# Permute # Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm) q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit)
# Pack # Pack
pack_factor = get_pack_factor(num_bits) pack_factor = get_pack_factor(num_bits)
...@@ -63,28 +70,53 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm): ...@@ -63,28 +70,53 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
return q_packed return q_packed
def get_weight_perm(num_bits: int): def get_weight_perm(num_bits: int, is_a_8bit: bool = False):
perm_list: list[int] = [] perm_list: list[int] = []
for i in range(32): if is_a_8bit:
perm1: list[int] = [] for i in range(32):
col = i // 4 perm1 = []
for block in [0, 1]: col = i // 4
for row in [ for block in [0, 1]:
2 * (i % 4), for row in [
2 * (i % 4) + 1, 4 * (i % 4),
2 * (i % 4 + 4), 4 * (i % 4) + 1,
2 * (i % 4 + 4) + 1, 4 * (i % 4) + 2,
]: 4 * (i % 4) + 3,
perm1.append(16 * row + col + 8 * block) 4 * (i % 4 + 4),
for j in range(4): 4 * (i % 4 + 4) + 1,
perm_list.extend([p + 256 * j for p in perm1]) 4 * (i % 4 + 4) + 2,
4 * (i % 4 + 4) + 3,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(2):
perm_list.extend([p + 512 * j for p in perm1])
else:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list) perm = np.array(perm_list)
if num_bits == 4: if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
else:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8: elif num_bits == 8:
interleave = np.array([0, 2, 1, 3]) if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 1, 2, 3])
else:
interleave = np.array([0, 2, 1, 3])
else: else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
...@@ -99,7 +131,10 @@ def marlin_quantize( ...@@ -99,7 +131,10 @@ def marlin_quantize(
group_size: int, group_size: int,
act_order: bool, act_order: bool,
test_perm: torch.Tensor | None = None, test_perm: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
): ):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape size_k, size_n = w.shape
num_bits = quant_type.size_bits num_bits = quant_type.size_bits
...@@ -120,9 +155,15 @@ def marlin_quantize( ...@@ -120,9 +155,15 @@ def marlin_quantize(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Reformat to marlin # Reformat to marlin
weight_perm = get_weight_perm(num_bits) weight_perm = get_weight_perm(num_bits, is_a_8bit)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) marlin_q_w = marlin_weights(
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) q_w, size_k, size_n, num_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4b8:
ops.marlin_int4_fp8_preprocess(marlin_q_w, inplace=True)
marlin_s = marlin_s * 512
# Create result # Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
...@@ -132,7 +173,13 @@ def marlin_quantize( ...@@ -132,7 +173,13 @@ def marlin_quantize(
return res_list return res_list
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): def awq_marlin_quantize(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape size_k, size_n = w.shape
# Normalize group_size # Normalize group_size
...@@ -147,11 +194,22 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int ...@@ -147,11 +194,22 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int
# Quantize with zp # Quantize with zp
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4:
repeated_zp = zp.repeat_interleave(group_size, 0)
q_w_old = q_w
q_w = q_w_old - repeated_zp
q_w[q_w < 0] = 15 - q_w_old[q_w < 0]
s = s * 512
# Reformat to marlin # Reformat to marlin
weight_perm = get_weight_perm(quant_type.size_bits) weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) marlin_q_w = marlin_weights(
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit=is_a_8bit
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) )
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
marlin_zp = marlin_zero_points(
zp, num_groups, size_n, quant_type.size_bits, is_a_8bit=is_a_8bit
)
# Create result # Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment