Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66c98874
Unverified
Commit
66c98874
authored
Dec 24, 2025
by
Kevin McKay
Committed by
GitHub
Dec 24, 2025
Browse files
[Bugfix][Hardware][AMD] Fix FP8 dtype in silu_mul quantization (#31179)
Signed-off-by:
c0de128
<
kevin.mckay@outlook.com
>
parent
1ff67df1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
4 deletions
+8
-4
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+8
-4
No files found.
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
66c98874
...
@@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
...
@@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
M
,
N
=
input
.
size
()
M
,
N
=
input
.
size
()
N_2
=
N
//
2
N_2
=
N
//
2
fp8_dtype
=
current_platform
.
fp8_dtype
()
if
output
is
None
:
if
output
is
None
:
output
=
torch
.
empty
((
M
,
N_2
),
dtype
=
torch
.
float8_e4m3fn
,
device
=
input
.
device
)
output
=
torch
.
empty
((
M
,
N_2
),
dtype
=
fp8_dtype
,
device
=
input
.
device
)
output_scales
=
torch
.
empty
(
output_scales
=
torch
.
empty
(
((
N_2
//
GROUP_SIZE
),
M
),
dtype
=
torch
.
float32
,
device
=
input
.
device
((
N_2
//
GROUP_SIZE
),
M
),
dtype
=
torch
.
float32
,
device
=
input
.
device
...
@@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
...
@@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
assert
M
%
BLOCK_M
==
0
assert
M
%
BLOCK_M
==
0
assert
N_2
%
BLOCK_N
==
0
assert
N_2
%
BLOCK_N
==
0
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
# Using the default value (240.0) from pytorch will cause accuracy
fp8_min
=
finfo
.
min
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
fp8_max
=
finfo
.
max
# platforms that use the torch.float8_e4m3fnuz dtype.
finfo
=
torch
.
finfo
(
fp8_dtype
)
fp8_min
=
-
224.0
if
current_platform
.
is_fp8_fnuz
()
else
finfo
.
min
fp8_max
=
224.0
if
current_platform
.
is_fp8_fnuz
()
else
finfo
.
max
# Force even division so we can avoid edgecases within the kernel.
# Force even division so we can avoid edgecases within the kernel.
assert
M
%
BLOCK_M
==
0
assert
M
%
BLOCK_M
==
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment