Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90b139cf
Unverified
Commit
90b139cf
authored
Sep 24, 2025
by
Saman A. Pour
Committed by
GitHub
Sep 24, 2025
Browse files
Enable Fbgemm NVFP4 on Dense models (#25609)
Signed-off-by:
Saman Keon
<
samanamp@outlook.com
>
parent
4492e3a5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
5 deletions
+89
-5
benchmarks/kernels/bench_nvfp4_gemm.py
benchmarks/kernels/bench_nvfp4_gemm.py
+61
-4
vllm/envs.py
vllm/envs.py
+4
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
...mpressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+24
-0
No files found.
benchmarks/kernels/bench_nvfp4_gemm.py
View file @
90b139cf
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
argparse
import
argparse
import
copy
import
copy
import
itertools
import
itertools
import
os
import
torch
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
weight_shapes
import
WEIGHT_SHAPES
...
@@ -23,21 +24,45 @@ PROVIDER_CFGS = {
...
@@ -23,21 +24,45 @@ PROVIDER_CFGS = {
"torch-bf16"
:
dict
(
enabled
=
True
),
"torch-bf16"
:
dict
(
enabled
=
True
),
"nvfp4"
:
dict
(
no_a_quant
=
False
,
enabled
=
True
),
"nvfp4"
:
dict
(
no_a_quant
=
False
,
enabled
=
True
),
"nvfp4-noquant"
:
dict
(
no_a_quant
=
True
,
enabled
=
True
),
"nvfp4-noquant"
:
dict
(
no_a_quant
=
True
,
enabled
=
True
),
"fbgemm-nvfp4"
:
dict
(
fbgemm
=
True
,
no_a_quant
=
False
,
enabled
=
True
),
"fbgemm-nvfp4-noquant"
:
dict
(
fbgemm
=
True
,
no_a_quant
=
True
,
enabled
=
True
),
}
}
_needs_fbgemm
=
any
(
v
.
get
(
"fbgemm"
,
False
)
for
v
in
PROVIDER_CFGS
.
values
()
if
v
.
get
(
"enabled"
,
False
)
)
if
_needs_fbgemm
:
try
:
from
fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize
import
(
triton_scale_nvfp4_quant
,
)
except
ImportError
:
print
(
"WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
"These providers will be skipped. Please install fbgemm_gpu with: "
"'pip install fbgemm-gpu-genai' to run them."
)
# Disable FBGEMM providers so the benchmark can run.
for
cfg
in
PROVIDER_CFGS
.
values
():
if
cfg
.
get
(
"fbgemm"
):
cfg
[
"enabled"
]
=
False
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
def
_quant_weight_nvfp4
(
b
:
torch
.
Tensor
,
device
:
str
):
def
_quant_weight_nvfp4
(
b
:
torch
.
Tensor
,
device
:
str
,
cfg
):
# Compute global scale for weight
# Compute global scale for weight
b_amax
=
torch
.
abs
(
b
).
max
().
to
(
torch
.
float32
)
b_amax
=
torch
.
abs
(
b
).
max
().
to
(
torch
.
float32
)
b_global_scale
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
b_amax
b_global_scale
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
b_amax
b_fp4
,
scale_b_fp4
=
ops
.
scaled_fp4_quant
(
b
,
b_global_scale
)
if
"fbgemm"
in
cfg
and
cfg
[
"fbgemm"
]:
b_fp4
,
scale_b_fp4
=
triton_scale_nvfp4_quant
(
b
,
b_global_scale
)
else
:
b_fp4
,
scale_b_fp4
=
ops
.
scaled_fp4_quant
(
b
,
b_global_scale
)
return
b_fp4
,
scale_b_fp4
,
b_global_scale
return
b_fp4
,
scale_b_fp4
,
b_global_scale
def
build_nvfp4_runner
(
cfg
,
a
,
b
,
dtype
,
device
):
def
build_nvfp4_runner
(
cfg
,
a
,
b
,
dtype
,
device
):
b_fp4
,
scale_b_fp4
,
b_global_scale
=
_quant_weight_nvfp4
(
b
,
device
)
b_fp4
,
scale_b_fp4
,
b_global_scale
=
_quant_weight_nvfp4
(
b
,
device
,
cfg
)
# Compute global scale for activation
# Compute global scale for activation
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
...
@@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device):
...
@@ -46,6 +71,35 @@ def build_nvfp4_runner(cfg, a, b, dtype, device):
# Alpha for the GEMM operation
# Alpha for the GEMM operation
alpha
=
1.0
/
(
a_global_scale
*
b_global_scale
)
alpha
=
1.0
/
(
a_global_scale
*
b_global_scale
)
if
"fbgemm"
in
cfg
and
cfg
[
"fbgemm"
]:
if
cfg
[
"no_a_quant"
]:
a_fp4
,
scale_a_fp4
=
triton_scale_nvfp4_quant
(
a
,
a_global_scale
)
def
run
():
return
torch
.
ops
.
fbgemm
.
f4f4bf16
(
a_fp4
,
b_fp4
,
scale_a_fp4
,
scale_b_fp4
,
global_scale
=
alpha
,
use_mx
=
False
,
)
return
run
else
:
def
run
():
a_fp4
,
scale_a_fp4
=
triton_scale_nvfp4_quant
(
a
,
a_global_scale
)
return
torch
.
ops
.
fbgemm
.
f4f4bf16
(
a_fp4
,
b_fp4
,
scale_a_fp4
,
scale_b_fp4
,
global_scale
=
alpha
,
use_mx
=
False
,
)
return
run
if
cfg
[
"no_a_quant"
]:
if
cfg
[
"no_a_quant"
]:
# Pre-quantize activation
# Pre-quantize activation
...
@@ -130,10 +184,13 @@ if __name__ == "__main__":
...
@@ -130,10 +184,13 @@ if __name__ == "__main__":
for
K
,
N
,
model
in
prepare_shapes
(
args
):
for
K
,
N
,
model
in
prepare_shapes
(
args
):
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, BF16 vs NVFP4 GEMMs TFLOP/s:"
)
print
(
f
"
{
model
}
, N=
{
N
}
K=
{
K
}
, BF16 vs NVFP4 GEMMs TFLOP/s:"
)
save_dir
=
f
"bench_nvfp4_res_n
{
N
}
_k
{
K
}
"
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
benchmark
.
run
(
benchmark
.
run
(
print_data
=
True
,
print_data
=
True
,
show_plots
=
True
,
show_plots
=
True
,
save_path
=
f
"bench_nvfp4_res_n
{
N
}
_k
{
K
}
"
,
save_path
=
save_dir
,
N
=
N
,
N
=
N
,
K
=
K
,
K
=
K
,
)
)
...
...
vllm/envs.py
View file @
90b139cf
...
@@ -201,6 +201,7 @@ if TYPE_CHECKING:
...
@@ -201,6 +201,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
:
bool
=
True
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
:
bool
=
True
VLLM_USE_NCCL_SYMM_MEM
:
bool
=
False
VLLM_USE_NCCL_SYMM_MEM
:
bool
=
False
VLLM_NCCL_INCLUDE_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_INCLUDE_PATH
:
Optional
[
str
]
=
None
VLLM_USE_FBGEMM
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -1452,7 +1453,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1452,7 +1453,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# NCCL header path
# NCCL header path
"VLLM_NCCL_INCLUDE_PATH"
:
"VLLM_NCCL_INCLUDE_PATH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_NCCL_INCLUDE_PATH"
,
None
),
lambda
:
os
.
environ
.
get
(
"VLLM_NCCL_INCLUDE_PATH"
,
None
),
# Flag to enable FBGemm kernels on model execution
"VLLM_USE_FBGEMM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FBGEMM"
,
"0"
))),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
@@ -1548,6 +1550,7 @@ def compute_hash() -> str:
...
@@ -1548,6 +1550,7 @@ def compute_hash() -> str:
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN"
,
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN"
,
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE"
,
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE"
,
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING"
,
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING"
,
"VLLM_USE_FBGEMM"
,
]
]
for
key
in
environment_variables_to_hash
:
for
key
in
environment_variables_to_hash
:
# if this goes out of sync with environment_variables,
# if this goes out of sync with environment_variables,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
View file @
90b139cf
...
@@ -30,8 +30,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -30,8 +30,20 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
if
envs
.
VLLM_USE_TRTLLM_FP4_GEMM
:
if
envs
.
VLLM_USE_TRTLLM_FP4_GEMM
:
assert
has_flashinfer
(),
"TRTLLM FP4 GEMM requires FlashInfer"
assert
has_flashinfer
(),
"TRTLLM FP4 GEMM requires FlashInfer"
self
.
backend
=
"flashinfer-trtllm"
self
.
backend
=
"flashinfer-trtllm"
logger
.
info_once
(
"Using flashinfer-trtllm for FP4"
)
elif
envs
.
VLLM_USE_FBGEMM
:
self
.
backend
=
"fbgemm"
try
:
import
fbgemm_gpu
# noqa: F401
except
ImportError
as
exc
:
raise
ImportError
(
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
)
from
exc
logger
.
info_once
(
"Using FBGEMM-GPU-GENAI for FP4"
)
elif
has_flashinfer
():
elif
has_flashinfer
():
self
.
backend
=
"flashinfer-cutlass"
self
.
backend
=
"flashinfer-cutlass"
logger
.
info_once
(
"Using flashinfer-cutlass for FP4"
)
else
:
else
:
self
.
backend
=
"cutlass"
self
.
backend
=
"cutlass"
self
.
group_size
=
16
self
.
group_size
=
16
...
@@ -116,6 +128,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -116,6 +128,9 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer
.
weight_packed
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_packed
=
Parameter
(
weight
,
requires_grad
=
False
)
else
:
else
:
swizzled_weight_scale
=
swizzle_blockscale
(
layer
.
weight_scale
)
swizzled_weight_scale
=
swizzle_blockscale
(
layer
.
weight_scale
)
if
self
.
backend
==
"fbgemm"
:
swizzled_weight_scale
=
swizzled_weight_scale
.
view
(
-
1
).
view
(
torch
.
uint8
)
layer
.
weight_scale
=
Parameter
(
swizzled_weight_scale
,
layer
.
weight_scale
=
Parameter
(
swizzled_weight_scale
,
requires_grad
=
False
)
requires_grad
=
False
)
layer
.
weight_packed
=
Parameter
(
layer
.
weight_packed
.
data
,
layer
.
weight_packed
=
Parameter
(
layer
.
weight_packed
.
data
,
...
@@ -153,6 +168,15 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -153,6 +168,15 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
"trtllm"
)
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
"trtllm"
)
elif
self
.
backend
==
"flashinfer-cutlass"
:
elif
self
.
backend
==
"flashinfer-cutlass"
:
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
"cutlass"
)
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
"cutlass"
)
elif
self
.
backend
==
"fbgemm"
:
out
=
torch
.
ops
.
fbgemm
.
f4f4bf16
(
x_fp4
,
layer
.
weight_packed
,
x_blockscale
.
view
(
-
1
).
view
(
torch
.
uint8
),
layer
.
weight_scale
,
layer
.
alpha
,
use_mx
=
False
,
).
to
(
output_dtype
)
else
:
else
:
out
=
cutlass_scaled_fp4_mm
(
*
mm_args
)
out
=
cutlass_scaled_fp4_mm
(
*
mm_args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment