Unverified Commit d74e5f37 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel] fp4 marlin kernel (#17687)


Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
parent ca66a167
......@@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
return current_platform.has_device_capability(80)
def fp8_fused_exponent_bias_into_scales(scales):
fp8_exponent = 4
if scales.dtype == torch.half:
target_exponent = 5
elif scales.dtype == torch.bfloat16:
target_exponent = 8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
# exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
s = torch.ones_like(scales) * 2
s = s**exponent_bias
return scales * s
def apply_fp8_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
......@@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
c=None,
b_q_weight=weight,
b_scales=weight_scale,
global_scale=None,
b_zeros=None,
g_idx=None,
perm=None,
......@@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.T.contiguous()
block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1)
scales = scales.repeat_interleave(block_n, 1)
# size_n may not divisible by block_size[0]
scales = scales[:, :part_size_n]
......@@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k=part_size_k,
size_n=part_size_n,
group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
......@@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1)
block_n = layer.weight_block_size[0]
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2)
scales = scales.repeat_interleave(block_n, 2)
# size_n may not divisible by block_size[0]
scales = scales[..., :size_n].contiguous()
......@@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
size_n=size_n,
group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
return weight_ref.T, marlin_qweight, marlin_scales
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment