Unverified commit 1656ad37, authored by Jinzhen Lin and committed by GitHub
Browse files

[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)


Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
parent fa59fe41
......@@ -55,6 +55,9 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
select_cutlass_fp8_gemm_impl,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
......@@ -170,9 +173,15 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# now, the layer is quantized, handle it here
if isinstance(layer, LinearBase):
return self.LinearMethodCls(self)
quant_method = self.LinearMethodCls(self)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
return self.FusedMoEMethodCls(quant_config=self, layer=layer)
quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer)
if getattr(quant_method, "backend", "") == "marlin":
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
return None
......@@ -898,6 +907,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.marlin_input_dtype = None
self.backend = "none"
if envs.VLLM_NVFP4_GEMM_BACKEND is None:
......@@ -1065,6 +1075,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
input_dtype=self.marlin_input_dtype,
)
output_dtype = x.dtype
......@@ -1124,6 +1135,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.marlin_input_dtype = None
self.flashinfer_moe_backend = None
if self.allow_flashinfer:
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
......@@ -1517,7 +1529,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
)
elif self.allow_flashinfer:
......
......@@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
......@@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig):
if current_platform.is_xpu():
return IpexMxfp4MoEMethod(layer.moe_config)
else:
return Mxfp4MoEMethod(layer.moe_config)
quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
logger.debug_once(
......@@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
......@@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer)
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
......@@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
)
assert _can_support_mxfp4(
......
......@@ -9,6 +9,11 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
......@@ -286,10 +291,10 @@ def get_scale_perms():
def marlin_permute_scales(
s: torch.Tensor, size_k: int, size_n: int, group_size: int
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1:
if group_size < size_k and group_size != -1 and not is_a_8bit:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
......@@ -305,11 +310,15 @@ def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
return s.reshape(*origin_shape).contiguous()
def marlin_act_int8_process_scales(s: torch.Tensor):
a_scales_scale_factor = 1 / 4096 * s.max().float()
s = s / s.max() * 4096
s = s.round().to(torch.int16).view(s.dtype)
return s, a_scales_scale_factor
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
size_n: int,
group_size: int,
s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
):
num_experts = s.shape[0]
output = torch.empty(
......@@ -319,12 +328,12 @@ def marlin_moe_permute_scales(
)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, is_a_8bit)
return output
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int, is_a_8bit: bool = False
) -> torch.Tensor:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
......@@ -339,7 +348,8 @@ def marlin_zero_points(
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
if not is_a_8bit:
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
......@@ -347,7 +357,11 @@ def marlin_zero_points(
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
......@@ -366,12 +380,16 @@ def awq_to_marlin_zero_points(
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits, is_a_8bit)
return marlin_zp
def moe_awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
q_zp_packed: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
is_a_8bit: bool = False,
):
num_experts = q_zp_packed.shape[0]
output = torch.empty(
......@@ -380,7 +398,9 @@ def moe_awq_to_marlin_zero_points(
dtype=q_zp_packed.dtype,
)
for e in range(num_experts):
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
output[e] = awq_to_marlin_zero_points(
q_zp_packed[e], size_k, size_n, num_bits, is_a_8bit
)
return output
......@@ -432,6 +452,48 @@ def should_use_atomic_add_reduce(
return True
# Module-wide QuantFP8 instance, created on first request and then reused.
_quant_fp8_method: QuantFP8 | None = None


def get__quant_fp8_method() -> QuantFP8:
    """Return the shared ``QuantFP8`` instance, creating it on first use.

    The instance is built as ``QuantFP8(False, GroupShape.PER_TOKEN)`` and
    cached in the module-level ``_quant_fp8_method`` so later calls reuse it.
    """
    global _quant_fp8_method
    if _quant_fp8_method is not None:
        return _quant_fp8_method
    _quant_fp8_method = QuantFP8(False, GroupShape.PER_TOKEN)
    return _quant_fp8_method
def get_marlin_input_dtype(prefix):
    """Map ``VLLM_MARLIN_INPUT_DTYPE`` to the activation dtype for Marlin.

    Args:
        prefix: Layer name prefix supplied by the caller. It is accepted for
            interface compatibility; this function does not read it.

    Returns:
        ``None`` when the environment variable is unset or holds a value
        other than ``int8``/``fp8`` (case-insensitive), ``torch.int8`` for
        ``int8`` and ``torch.float8_e4m3fn`` for ``fp8``.

    Raises:
        ValueError: If ``fp8`` is requested and the current device capability
            is neither 89 nor 120.
    """
    requested = envs.VLLM_MARLIN_INPUT_DTYPE
    if requested is None:
        return None

    # Lower-case once instead of re-reading the env value per comparison.
    requested = requested.lower()
    if requested == "int8":
        return torch.int8

    if requested == "fp8":
        supported = current_platform.is_device_capability(
            89
        ) or current_platform.is_device_capability(120)
        if not supported:
            raise ValueError(
                "Marlin W4A8-FP8 only supports SM89 or SM120 devices "
                "(It is slower than Marlin W4A16 on other devices). "
                "You can consider using W4A8-INT8 instead "
                "(set VLLM_MARLIN_INPUT_DTYPE=int8)."
            )
        # Build the shared FP8 quantizer now so it exists before first use.
        _ = get__quant_fp8_method()
        return torch.float8_e4m3fn

    # Unrecognized values fall back to "no activation quantization".
    return None
def marlin_quant_input(x: torch.Tensor, quant_dtype: torch.dtype):
x = x.reshape(-1, x.shape[-1])
if quant_dtype == torch.int8:
return per_token_quant_int8(x)
elif quant_dtype == torch.float8_e4m3fn:
return get__quant_fp8_method()(x)
else:
raise ValueError(f"unsupported quant_dtype {quant_dtype}")
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
......@@ -444,8 +506,10 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
......@@ -458,12 +522,27 @@ def apply_gptq_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
assert wtype == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert wtype == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,
......@@ -493,8 +572,10 @@ def apply_awq_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
......@@ -507,12 +588,20 @@ def apply_awq_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
weight_zp,
g_idx,
......@@ -538,8 +627,10 @@ def apply_rtn_marlin_linear(
quant_type: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
input_global_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
input_dtype: torch.dtype | None = None,
) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition,)
......@@ -552,12 +643,20 @@ def apply_rtn_marlin_linear(
dtype=input.dtype,
)
a_scales = None
if input_dtype == torch.int8:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
reshaped_x,
None,
weight,
bias,
weight_scale,
a_scales,
None,
None,
None,
......
......@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.platforms import current_platform
......@@ -37,12 +38,6 @@ def nvfp4_marlin_process_scales(marlin_scales):
# convert to half first, we would convert to fp8 later
marlin_scales = marlin_scales.to(torch.half)
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
......@@ -62,18 +57,20 @@ def nvfp4_marlin_process_scales(marlin_scales):
return marlin_scales
def mxfp4_marlin_process_scales(marlin_scales):
# 8 is the number of scale number using by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1
)
def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
)
if input_dtype is None or input_dtype.itemsize == 2:
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1
)
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
if input_dtype == torch.float8_e4m3fn:
marlin_scales = marlin_scales.view(torch.uint8)
assert marlin_scales.max() <= 249
# exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6
marlin_scales = marlin_scales + 6
marlin_scales = marlin_scales.view(torch.float8_e8m0fnu)
return marlin_scales
......@@ -99,6 +96,7 @@ def apply_fp4_marlin_linear(
size_n: int,
size_k: int,
bias: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
# For GPUs that lack FP4 hardware support, we can leverage the
......@@ -111,12 +109,24 @@ def apply_fp4_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
)
inputs = reshaped_x
a_scales = None
is_nvfp4 = weight_scale_2 is not None
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm(
a=reshaped_x,
a=inputs,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
a_scales=a_scales,
global_scale=weight_scale_2,
b_zeros=None,
g_idx=None,
......@@ -133,7 +143,9 @@ def apply_fp4_marlin_linear(
return output.reshape(out_shape)
def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "
......@@ -160,12 +172,14 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
perm = torch.empty(0, dtype=torch.int, device=device)
qweight = layer.weight.view(torch.int32).T.contiguous()
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight,
perm=perm,
size_k=part_size_k,
size_n=part_size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
......@@ -178,7 +192,11 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales(
s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size
s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
if is_nvfp4:
......@@ -189,7 +207,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False)
else:
weight_scale = mxfp4_marlin_process_scales(weight_scale)
weight_scale = mxfp4_marlin_process_scales(
weight_scale, input_dtype=input_dtype
)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None:
......@@ -200,7 +220,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
return
def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_moe_fp4_layer_for_marlin(
layer: torch.nn.Module, input_dtype: torch.dtype | None = None
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP4 computation but "
"FP4 quantization is being used. Weight-only FP4 compression will "
......@@ -220,6 +242,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
param_dtype = layer.params_dtype
layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
# WEIGHT
# Repack weights to marlin format
......@@ -237,7 +260,12 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
qweight = weight[i].view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4
b_q_weight=qweight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
tensor_list.append(marlin_qweight)
......@@ -266,12 +294,18 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
scale = scales[i].T
marlin_scales = marlin_permute_scales(
s=scale, size_k=size_k, size_n=size_n, group_size=group_size
s=scale,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
marlin_scales = mxfp4_marlin_process_scales(
marlin_scales, input_dtype=input_dtype
)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
......@@ -301,7 +335,10 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
setattr(layer, name, bias)
def rand_marlin_weight_nvfp4_like(weight, group_size):
def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
assert not is_a_8bit, "NVFP4 weight + INT8/FP8 activation is not supported."
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
......@@ -337,10 +374,15 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
......@@ -349,14 +391,20 @@ def rand_marlin_weight_nvfp4_like(weight, group_size):
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
def rand_marlin_weight_mxfp4_like(weight, group_size):
def rand_marlin_weight_mxfp4_like(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn, (
"MXFP4 weight + INT8 activation is not supported."
)
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
scales = torch.randint(
100,
125,
110,
120,
(size_n, size_k // group_size),
dtype=torch.uint8,
device=weight.device,
......@@ -380,18 +428,25 @@ def rand_marlin_weight_mxfp4_like(weight, group_size):
).view(size_n, size_k)
weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype)
perm = torch.empty(0, dtype=torch.int, device=device)
fp4_weight = fp4_weight.view(torch.int32).T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
perm=torch.empty(0, dtype=torch.int, device=device),
b_q_weight=fp4_weight,
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=4,
is_a_8bit=is_a_8bit,
)
marlin_scales = marlin_permute_scales(
s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
marlin_scales = mxfp4_marlin_process_scales(marlin_scales, input_dtype=input_dtype)
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)
......@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_quant_input,
should_use_atomic_add_reduce,
)
from vllm.platforms import current_platform
......@@ -45,6 +46,7 @@ def apply_fp8_marlin_linear(
size_n: int,
size_k: int,
bias: torch.Tensor | None,
input_dtype: torch.dtype | None = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
# For GPUs that lack FP8 hardware support, we can leverage the
......@@ -57,12 +59,21 @@ def apply_fp8_marlin_linear(
m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
)
inputs = reshaped_x
a_scales = None
if input_dtype is not None and input_dtype.itemsize == 1:
if input_dtype != torch.float8_e4m3fn:
raise RuntimeError("FP8 weight + INT8 activation is not supported.")
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm(
a=reshaped_x,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
a_scales=a_scales,
global_scale=None,
b_zeros=None,
g_idx=None,
......@@ -80,7 +91,9 @@ def apply_fp8_marlin_linear(
def prepare_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True
layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
......@@ -162,7 +175,8 @@ def prepare_fp8_layer_for_marlin(
marlin_scales = marlin_permute_scales(
s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
if input_dtype != torch.float8_e4m3fn:
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None:
......@@ -172,7 +186,9 @@ def prepare_fp8_layer_for_marlin(
def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module, size_k_first: bool = True
layer: torch.nn.Module,
size_k_first: bool = True,
input_dtype: torch.dtype | None = None,
) -> None:
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
......@@ -278,7 +294,8 @@ def prepare_moe_fp8_layer_for_marlin(
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = fp8_fused_exponent_bias_into_scales(scales)
if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
......@@ -318,7 +335,11 @@ def pack_fp8_to_int32(
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
def marlin_quant_fp8_torch(weight, group_size):
def marlin_quant_fp8_torch(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:
assert input_dtype == torch.float8_e4m3fn
size_n, size_k = weight.shape
device = weight.device
......@@ -334,16 +355,22 @@ def marlin_quant_fp8_torch(weight, group_size):
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
perm = torch.empty(0, dtype=torch.int, device=device)
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_weight,
perm=torch.empty(0, dtype=torch.int, device=device),
perm=perm,
size_k=size_k,
size_n=size_n,
num_bits=8,
is_a_8bit=is_a_8bit,
)
marlin_scales = marlin_permute_scales(
s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
s=scales.T,
size_k=size_k,
size_n=size_n,
group_size=group_size,
is_a_8bit=is_a_8bit,
)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
......
......@@ -5,7 +5,8 @@
import numpy as np
import torch
from vllm.scalar_type import ScalarType
from vllm import _custom_ops as ops
from vllm.scalar_type import ScalarType, scalar_types
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
from .quant_utils import (
......@@ -29,13 +30,19 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
def marlin_permute_weights(
q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False
):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
if is_a_8bit:
# Permute weights to 32x32 marlin tiles
q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile))
else:
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))
......@@ -44,9 +51,9 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit)
# Pack
pack_factor = get_pack_factor(num_bits)
......@@ -63,28 +70,53 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm):
return q_packed
def get_weight_perm(num_bits: int):
def get_weight_perm(num_bits: int, is_a_8bit: bool = False):
perm_list: list[int] = []
for i in range(32):
perm1: list[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
if is_a_8bit:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
4 * (i % 4),
4 * (i % 4) + 1,
4 * (i % 4) + 2,
4 * (i % 4) + 3,
4 * (i % 4 + 4),
4 * (i % 4 + 4) + 1,
4 * (i % 4 + 4) + 2,
4 * (i % 4 + 4) + 3,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(2):
perm_list.extend([p + 512 * j for p in perm1])
else:
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
else:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
if is_a_8bit: # noqa: SIM108
interleave = np.array([0, 1, 2, 3])
else:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
......@@ -99,7 +131,10 @@ def marlin_quantize(
group_size: int,
act_order: bool,
test_perm: torch.Tensor | None = None,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape
num_bits = quant_type.size_bits
......@@ -120,9 +155,15 @@ def marlin_quantize(
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Reformat to marlin
weight_perm = get_weight_perm(num_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
weight_perm = get_weight_perm(num_bits, is_a_8bit)
marlin_q_w = marlin_weights(
q_w, size_k, size_n, num_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4b8:
ops.marlin_int4_fp8_preprocess(marlin_q_w, inplace=True)
marlin_s = marlin_s * 512
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
......@@ -132,7 +173,13 @@ def marlin_quantize(
return res_list
def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
def awq_marlin_quantize(
w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
input_dtype: torch.dtype | None = None,
):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
size_k, size_n = w.shape
# Normalize group_size
......@@ -147,11 +194,22 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int
# Quantize with zp
w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)
if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4:
repeated_zp = zp.repeat_interleave(group_size, 0)
q_w_old = q_w
q_w = q_w_old - repeated_zp
q_w[q_w < 0] = 15 - q_w_old[q_w < 0]
s = s * 512
# Reformat to marlin
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
marlin_q_w = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit=is_a_8bit
)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
marlin_zp = marlin_zero_points(
zp, num_groups, size_n, quant_type.size_bits, is_a_8bit=is_a_8bit
)
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment