Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8510c10c
Commit
8510c10c
authored
Feb 09, 2026
by
lixh
Browse files
feat: implement FP8 blockwise GEMM with hipblaslt
parent
45a060d6
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
75 additions
and
6 deletions
+75
-6
vllm/envs.py
vllm/envs.py
+0
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+1
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-0
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+72
-5
No files found.
vllm/envs.py
View file @
8510c10c
...
...
@@ -1837,7 +1837,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_W8A8_BACKEND"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_W8A8_BACKEND"
,
"3"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
8510c10c
...
...
@@ -1804,6 +1804,7 @@ def fused_experts(
expert_map
:
torch
.
Tensor
|
None
=
None
,
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
:
if
quant_config
is
None
:
quant_config
=
FUSED_MOE_UNQUANTIZED_CONFIG
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
8510c10c
...
...
@@ -1001,6 +1001,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
moe_mk
is
not
None
assert
not
self
.
is_monolithic
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
8510c10c
...
...
@@ -47,6 +47,8 @@ from vllm.utils.flashinfer import (
should_use_flashinfer_for_blockscale_fp8_gemm
,
)
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
lmslim
import
quant_ops
from
lmslim.quantize.quant_ops
import
BlockSize
logger
=
init_logger
(
__name__
)
...
...
@@ -357,6 +359,7 @@ class W8A8BlockFp8LinearOp:
act_quant_group_shape
:
GroupShape
,
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
use_aiter_and_is_supported
:
bool
=
False
,
use_blaslt
:
bool
=
False
,
):
self
.
weight_group_shape
=
weight_group_shape
self
.
act_quant_group_shape
=
act_quant_group_shape
...
...
@@ -364,14 +367,13 @@ class W8A8BlockFp8LinearOp:
self
.
is_hopper
=
current_platform
.
is_device_capability
(
90
)
self
.
use_deep_gemm_e8m0
=
is_deep_gemm_e8m0_used
()
self
.
is_flashinfer_supported
=
is_flashinfer_fp8_blockscale_gemm_supported
()
# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self
.
w8a8_blockscale_op
,
self
.
input_quant_op
=
(
self
.
_dispatch_w8a8_blockscale_op
(
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
,
use_blaslt
)
)
self
.
deepgemm_input_quant_op
=
(
...
...
@@ -397,8 +399,14 @@ class W8A8BlockFp8LinearOp:
assert
input_scale
is
None
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]
]
output_shape
=
[]
output_dtype
=
input
.
dtype
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
-
1
]]
out_features
=
int
(
weight
.
shape
[
-
1
])
else
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
out_features
=
int
(
weight
.
shape
[
0
])
if
should_use_flashinfer_for_blockscale_fp8_gemm
(
self
.
is_flashinfer_supported
,
output_dtype
,
input_2d
,
weight
...
...
@@ -413,7 +421,7 @@ class W8A8BlockFp8LinearOp:
output
=
self
.
_run_deepgemm
(
input_2d
,
weight
,
weight_scale
)
else
:
output
=
self
.
w8a8_blockscale_op
(
input_2d
,
weight
,
weight_scale
,
input_scale
out_features
,
input_2d
,
weight
,
weight_scale
,
input_scale
)
if
bias
is
not
None
:
...
...
@@ -535,6 +543,37 @@ class W8A8BlockFp8LinearOp:
input_2d
.
dtype
,
)
def _run_hipblaslt_blockwise(
    self,
    out_features: int,
    input_2d: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the block-scaled W8A8 matmul through ``hipblaslt_w8a8_block_fp8_matmul``.

    Args:
        out_features: Kept in the signature for the dispatch call site; this
            method does not read it.
        input_2d: Activation matrix. When ``input_scale`` is ``None`` it is
            quantized with ``self.input_quant_op``; otherwise it is forwarded
            unchanged (presumably already quantized by the caller -- TODO
            confirm).
        weight: Weight operand, forwarded as ``B``.
        weight_scale: Scales for ``weight``, forwarded as ``Bs``.
        input_scale: Optional scales for ``input_2d``, forwarded as ``As``.

    Returns:
        The tensor returned by ``hipblaslt_w8a8_block_fp8_matmul``.
    """
    if input_scale is None:
        # No activation scale supplied: quantize here and use the scales the
        # quant op produces.
        q_input, input_scale = self.input_quant_op(input_2d)
    else:
        q_input = input_2d

    # 128x128 is the default block layout; 64x64 is chosen only when the
    # instance carries a ``block_size`` attribute whose first entry is 64.
    # NOTE(review): no assignment of ``self.block_size`` is visible in this
    # view -- confirm the attribute is actually set somewhere, otherwise the
    # 64x64 branch never triggers.
    enum_block_size = BlockSize.block_128x128
    if hasattr(self, "block_size") and self.block_size[0] == 64:
        enum_block_size = BlockSize.block_64x64

    # NOTE(review): the output dtype is fixed to bfloat16 regardless of the
    # dtype of ``input_2d`` -- confirm this is intended.
    return hipblaslt_w8a8_block_fp8_matmul(
        A=q_input,
        B=weight,
        As=input_scale,
        Bs=weight_scale,
        block_size=enum_block_size,
        output_dtype=torch.bfloat16,
    )
def
_run_flashinfer
(
self
,
input_2d
:
torch
.
Tensor
,
...
...
@@ -562,6 +601,7 @@ class W8A8BlockFp8LinearOp:
self
,
use_cutlass
:
bool
,
use_aiter_and_is_supported
:
bool
,
use_blaslt
:
bool
,
)
->
tuple
[
Callable
[
[
...
...
@@ -585,6 +625,16 @@ class W8A8BlockFp8LinearOp:
)
if
use_aiter_and_is_supported
:
return
self
.
_run_aiter
,
None
if
envs
.
VLLM_W8A8_BACKEND
==
3
or
use_blaslt
:
return
(
self
.
_run_hipblaslt_blockwise
,
QuantFP8
(
False
,
self
.
act_quant_group_shape
,
column_major_scales
=
False
,
use_ue8m0
=
False
,
),
)
return
self
.
_run_triton
,
(
QuantFP8
(
False
,
...
...
@@ -1179,6 +1229,19 @@ def get_w8a8_block_fp8_configs(
)
return
None
def hipblaslt_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: BlockSize,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Multiply ``A`` by ``B`` via ``quant_ops.hipblaslt_w8a8_blockwise_gemm``.

    Args:
        A: Left operand; unpacked as a 2-D ``(m, k)`` matrix.
        B: Right operand; unpacked as a 2-D ``(k, n)`` matrix.
        As: Scales accompanying ``A``, forwarded untouched.
        Bs: Scales accompanying ``B``, forwarded untouched.
        block_size: ``BlockSize`` enum member forwarded to the GEMM call.
        output_dtype: Dtype requested from the GEMM call.

    Returns:
        The second element of the pair returned by the GEMM call.
    """
    # The inner dimensions must agree before the shapes are handed over.
    assert A.shape[1] == B.shape[0]

    num_rows, inner_dim = A.shape
    _, num_cols = B.shape

    # 'NN' presumably selects the non-transposed layout for both operands and
    # the trailing None is an unused optional argument -- TODO confirm against
    # the lmslim binding.
    gemm_result = quant_ops.hipblaslt_w8a8_blockwise_gemm(
        A,
        B,
        As,
        Bs,
        num_rows,
        num_cols,
        inner_dim,
        'NN',
        output_dtype,
        block_size,
        None,
    )
    _, product = gemm_result
    return product
def
w8a8_triton_block_scaled_mm
(
A
:
torch
.
Tensor
,
...
...
@@ -1597,6 +1660,10 @@ def process_fp8_weight_block_strategy(
weight
=
weight
,
weight_scale
=
weight_scale
)
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
weight
=
weight
.
T
.
contiguous
()
weight_scale
=
weight_scale
.
T
.
contiguous
()
else
:
weight
=
_maybe_pad_fp8_weight
(
weight
)
return
weight
,
weight_scale
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment