Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d74e5f37
Unverified
Commit
d74e5f37
authored
May 11, 2025
by
Jinzhen Lin
Committed by
GitHub
May 10, 2025
Browse files
[Kernel] fp4 marlin kernel (#17687)
Signed-off-by:
Jinzhen Lin
<
linjinzhen@hotmail.com
>
parent
ca66a167
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
2 deletions
+24
-2
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+24
-2
No files found.
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
d74e5f37
...
@@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
...
@@ -19,6 +19,20 @@ def is_fp8_marlin_supported():
return
current_platform
.
has_device_capability
(
80
)
return
current_platform
.
has_device_capability
(
80
)
def
fp8_fused_exponent_bias_into_scales
(
scales
):
fp8_exponent
=
4
if
scales
.
dtype
==
torch
.
half
:
target_exponent
=
5
elif
scales
.
dtype
==
torch
.
bfloat16
:
target_exponent
=
8
# exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
# exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
exponent_bias
=
2
**
(
target_exponent
-
1
)
-
2
**
(
fp8_exponent
-
1
)
s
=
torch
.
ones_like
(
scales
)
*
2
s
=
s
**
exponent_bias
return
scales
*
s
def
apply_fp8_marlin_linear
(
def
apply_fp8_marlin_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
@@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
...
@@ -44,6 +58,7 @@ def apply_fp8_marlin_linear(
c
=
None
,
c
=
None
,
b_q_weight
=
weight
,
b_q_weight
=
weight
,
b_scales
=
weight_scale
,
b_scales
=
weight_scale
,
global_scale
=
None
,
b_zeros
=
None
,
b_zeros
=
None
,
g_idx
=
None
,
g_idx
=
None
,
perm
=
None
,
perm
=
None
,
...
@@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -132,8 +147,10 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# block-wise quantization -> group-wise quantization
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# (size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (size_k // block_size[1], size_n)
# =>(repeat)=> (size_k // block_size[1], size_n)
if
not
size_k_first
:
scales
=
scales
.
T
.
contiguous
()
block_n
=
layer
.
weight_block_size
[
0
]
block_n
=
layer
.
weight_block_size
[
0
]
scales
=
scales
.
T
.
repeat_interleave
(
block_n
,
1
)
scales
=
scales
.
repeat_interleave
(
block_n
,
1
)
# size_n may not divisible by block_size[0]
# size_n may not divisible by block_size[0]
scales
=
scales
[:,
:
part_size_n
]
scales
=
scales
[:,
:
part_size_n
]
...
@@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -141,6 +158,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
size_k
=
part_size_k
,
size_k
=
part_size_k
,
size_n
=
part_size_n
,
size_n
=
part_size_n
,
group_size
=
group_size
)
group_size
=
group_size
)
marlin_scales
=
fp8_fused_exponent_bias_into_scales
(
marlin_scales
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
marlin_scales
,
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
marlin_scales
,
requires_grad
=
False
)
...
@@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -239,8 +257,10 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
# block-wise quantization -> group-wise quantization
# block-wise quantization -> group-wise quantization
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# (e, size_k // block_size[1], ceil(size_n / block_size[0]))
# =>(repeat)=> (e, size_k // block_size[1], size_n)
# =>(repeat)=> (e, size_k // block_size[1], size_n)
if
not
size_k_first
:
scales
=
scales
.
permute
(
0
,
2
,
1
)
block_n
=
layer
.
weight_block_size
[
0
]
block_n
=
layer
.
weight_block_size
[
0
]
scales
=
scales
.
permute
(
0
,
2
,
1
).
repeat_interleave
(
block_n
,
2
)
scales
=
scales
.
repeat_interleave
(
block_n
,
2
)
# size_n may not divisible by block_size[0]
# size_n may not divisible by block_size[0]
scales
=
scales
[...,
:
size_n
].
contiguous
()
scales
=
scales
[...,
:
size_n
].
contiguous
()
...
@@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
...
@@ -302,4 +322,6 @@ def marlin_quant_fp8_torch(weight, group_size):
size_n
=
size_n
,
size_n
=
size_n
,
group_size
=
group_size
)
group_size
=
group_size
)
marlin_scales
=
fp8_fused_exponent_bias_into_scales
(
marlin_scales
)
return
weight_ref
.
T
,
marlin_qweight
,
marlin_scales
return
weight_ref
.
T
,
marlin_qweight
,
marlin_scales
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment