Unverified Commit 90c20079 authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Bugfix] Disable tma_aligned_scales in test_fusions_e2e (#32916)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent d95d6507
......@@ -290,6 +290,9 @@ def test_rms_group_quant(
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
# TODO: remove this after fusion is fixed
monkeypatch.setenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "0")
model_kwargs["attention_config"] = {"backend": backend.name}
compilation_config = CompilationConfig(
......
......@@ -162,6 +162,7 @@ if TYPE_CHECKING:
VLLM_USE_DEEP_GEMM: bool = True
VLLM_MOE_USE_DEEP_GEMM: bool = True
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True
VLLM_DEEP_GEMM_WARMUP: Literal[
"skip",
"full",
......@@ -1201,6 +1202,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(
int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))
),
# Whether to create TMA-aligned scale tensor when DeepGEMM is used.
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool(
int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1"))
),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
......
......@@ -379,7 +379,7 @@ class W8A8BlockFp8LinearOp:
False,
self.act_quant_group_shape,
column_major_scales=True,
tma_aligned_scales=True,
tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
use_ue8m0=self.use_deep_gemm_e8m0,
)
if self.is_deep_gemm_supported
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment