Unverified Commit d74e5f37 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

[Kernel] fp4 marlin kernel (#17687)


Signed-off-by: default avatarJinzhen Lin <linjinzhen@hotmail.com>
parent ca66a167
...@@ -19,6 +19,20 @@ def is_fp8_marlin_supported(): ...@@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
return current_platform.has_device_capability(80) return current_platform.has_device_capability(80)
def fp8_fused_exponent_bias_into_scales(scales):
fp8_exponent = 4
if scales.dtype == torch.half:
target_exponent = 5
elif scales.dtype == torch.bfloat16:
target_exponent = 8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
# exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1)
s = torch.ones_like(scales) * 2
s = s**exponent_bias
return scales * s
def apply_fp8_marlin_linear( def apply_fp8_marlin_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
...@@ -44,6 +58,7 @@ def apply_fp8_marlin_linear( ...@@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
b_scales=weight_scale, b_scales=weight_scale,
global_scale=None,
b_zeros=None, b_zeros=None,
g_idx=None, g_idx=None,
perm=None, perm=None,
...@@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, ...@@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization # block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0])) # (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n) # =>(repeat)=> (size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.T.contiguous()
block_n = layer.weight_block_size[0] block_n = layer.weight_block_size[0]
scales = scales.T.repeat_interleave(block_n, 1) scales = scales.repeat_interleave(block_n, 1)
# size_n may not divisible by block_size[0] # size_n may not divisible by block_size[0]
scales = scales[:, :part_size_n] scales = scales[:, :part_size_n]
...@@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, ...@@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k=part_size_k, size_k=part_size_k,
size_n=part_size_n, size_n=part_size_n,
group_size=group_size) group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
...@@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, ...@@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization # block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0])) # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n) # =>(repeat)=> (e, size_k // block_size[1], size_n)
if not size_k_first:
scales = scales.permute(0, 2, 1)
block_n = layer.weight_block_size[0] block_n = layer.weight_block_size[0]
scales = scales.permute(0, 2, 1).repeat_interleave(block_n, 2) scales = scales.repeat_interleave(block_n, 2)
# size_n may not divisible by block_size[0] # size_n may not divisible by block_size[0]
scales = scales[..., :size_n].contiguous() scales = scales[..., :size_n].contiguous()
...@@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size): ...@@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
size_n=size_n, size_n=size_n,
group_size=group_size) group_size=group_size)
marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
return weight_ref.T, marlin_qweight, marlin_scales return weight_ref.T, marlin_qweight, marlin_scales
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment