Unverified commit a4bd661f, authored by Michael Goin and committed by GitHub
Browse files

[Perf] Enable FlashInfer DeepGEMM swapAB on SM90 by default (#34924)


Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 3ef9fd0f
...@@ -5,6 +5,7 @@ from collections.abc import Callable ...@@ -5,6 +5,7 @@ from collections.abc import Callable
import pytest import pytest
from vllm.config import PassConfig from vllm.config import PassConfig
from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported
from .common import ( from .common import (
INDUCTOR_GRAPH_PARTITION, INDUCTOR_GRAPH_PARTITION,
...@@ -50,6 +51,10 @@ def test_tp1_fp8_fusions( ...@@ -50,6 +51,10 @@ def test_tp1_fp8_fusions(
run_e2e_fusion_test, run_e2e_fusion_test,
monkeypatch, monkeypatch,
): ):
if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported():
# Flashinfer block FP8 GEMM has internal quantization, so it can't
# be fused with other ops.
pytest.skip("FlashInfer block FP8 GEMM not supported")
if use_deepgemm and is_blackwell(): if use_deepgemm and is_blackwell():
# TODO(luka) DeepGEMM uses different quants, matching not supported # TODO(luka) DeepGEMM uses different quants, matching not supported
# - on Blackwell, uses a special quant fp8, currently not supported # - on Blackwell, uses a special quant fp8, currently not supported
......
...@@ -159,7 +159,7 @@ if TYPE_CHECKING: ...@@ -159,7 +159,7 @@ if TYPE_CHECKING:
"relax", "relax",
] = "relax" ] = "relax"
VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = False VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER: bool = True
VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False
...@@ -1198,7 +1198,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1198,7 +1198,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Allow use of FlashInfer FP8 block-scale GEMM for linear layers. # Allow use of FlashInfer FP8 block-scale GEMM for linear layers.
# This uses TensorRT-LLM kernels and requires SM90+ (Hopper). # This uses TensorRT-LLM kernels and requires SM90+ (Hopper).
"VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool( "VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER": lambda: bool(
int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "0")) int(os.getenv("VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER", "1"))
), ),
# Allow use of FlashInfer BF16 MoE kernels for fused moe ops. # Allow use of FlashInfer BF16 MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment