Unverified commit bd7157a0, authored by Luka Govedič, committed by GitHub
Browse files

[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)


Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent be429d0c
...@@ -416,8 +416,8 @@ steps: ...@@ -416,8 +416,8 @@ steps:
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
- pytest -v -s compile/piecewise/ - pytest -v -s compile/piecewise/
- label: PyTorch Fullgraph Test # 20min - label: PyTorch Fullgraph Test # 22min
timeout_in_minutes: 30 timeout_in_minutes: 35
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
torch_nightly: true torch_nightly: true
source_file_dependencies: source_file_dependencies:
...@@ -425,6 +425,7 @@ steps: ...@@ -425,6 +425,7 @@ steps:
- tests/compile - tests/compile
commands: commands:
- pytest -v -s compile/test_full_graph.py - pytest -v -s compile/test_full_graph.py
- pytest -v -s compile/test_fusions_e2e.py
- label: Kernels Core Operation Test # 48min - label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75 timeout_in_minutes: 75
...@@ -807,8 +808,8 @@ steps: ...@@ -807,8 +808,8 @@ steps:
# Whisper needs spawn method to avoid deadlock # Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
- label: Blackwell Test # 38 min - label: Blackwell Test # 21 min
timeout_in_minutes: 60 timeout_in_minutes: 30
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
gpu: b200 gpu: b200
# optional: true # optional: true
...@@ -821,8 +822,6 @@ steps: ...@@ -821,8 +822,6 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands: commands:
- nvidia-smi - nvidia-smi
- python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/basic/chat.py
...@@ -839,15 +838,32 @@ steps: ...@@ -839,15 +838,32 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
- pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_flashinfer.py
- label: Blackwell Fusion Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
source_file_dependencies:
- csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/
# can affect pattern matching
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
commands:
- nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py # this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py
- label: Blackwell GPT-OSS Eval - label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60 timeout_in_minutes: 60
...@@ -1100,7 +1116,7 @@ steps: ...@@ -1100,7 +1116,7 @@ steps:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
##### H200 test ##### ##### H200 test #####
- label: Distrubted Tests (H200) # optional - label: Distributed Tests (H200) # optional
gpu: h200 gpu: h200
optional: true optional: true
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
...@@ -1108,6 +1124,8 @@ steps: ...@@ -1108,6 +1124,8 @@ steps:
commands: commands:
- pytest -v -s tests/compile/test_async_tp.py - pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py - pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
......
...@@ -392,6 +392,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] ...@@ -392,6 +392,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
double epsilon) { double epsilon) {
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
......
...@@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant( ...@@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant(
double epsilon) { double epsilon) {
TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int input_stride = input.stride(-2); int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
......
...@@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant( ...@@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
if (scale_ub.has_value()) { if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type); TORCH_CHECK(out.dtype() == kFp8Type);
} }
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32); TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
......
...@@ -3,16 +3,22 @@ ...@@ -3,16 +3,22 @@
import weakref import weakref
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
import depyf
from torch import fx from torch import fx
from torch._ops import OpOverload from torch._ops import OpOverload
from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.pass_manager import with_pattern_match_debug from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger("vllm.tests.compile.backend")
class LazyInitPass(InductorPass): class LazyInitPass(InductorPass):
...@@ -45,20 +51,32 @@ class TestBackend: ...@@ -45,20 +51,32 @@ class TestBackend:
def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]): def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
self.custom_passes = list(passes) self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config vllm_config = get_current_vllm_config()
self.inductor_config = compile_config.inductor_compile_config compile_config = vllm_config.compilation_config
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
self.inductor_config = deepcopy(compile_config.inductor_compile_config)
self.inductor_config["force_disable_caches"] = True self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
if debug_dump_path := vllm_config.compile_debug_dump_path():
logger.debug("Dumping depyf output to %s", debug_dump_path)
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
else:
self.debug_ctx = nullcontext()
def __call__(self, graph: fx.GraphModule, example_inputs): def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph) self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
return compile_fx(graph, example_inputs, config_patches=self.inductor_config) with self.debug_ctx:
return compile_fx(
graph, example_inputs, config_patches=self.inductor_config
)
@with_pattern_match_debug @with_pattern_match_debug
def post_pass(self, graph: fx.Graph): def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph) self.graph_pre_pass = deepcopy(graph)
lazy_format_graph_code("graph_pre_pass", graph.owning_module)
VllmInductorPass.dump_prefix = 0 VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes: for pass_ in self.custom_passes:
...@@ -68,6 +86,7 @@ class TestBackend: ...@@ -68,6 +86,7 @@ class TestBackend:
VllmInductorPass.dump_prefix = None VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph) self.graph_post_pass = deepcopy(graph)
lazy_format_graph_code("graph_post_pass", graph.owning_module)
# assign by reference, will reflect the final state of the graph # assign by reference, will reflect the final state of the graph
self.final_graph = graph self.final_graph = graph
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import tempfile import tempfile
from pathlib import Path
from typing import Any from typing import Any
import pytest import pytest
...@@ -10,8 +10,6 @@ import torch ...@@ -10,8 +10,6 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer from vllm.utils import is_torch_equal_or_newer
...@@ -22,23 +20,24 @@ from ..utils import create_new_process_for_each_test ...@@ -22,23 +20,24 @@ from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: list[str] | None = None): def models_list(*, all: bool = True, keywords: list[str] | None = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{
"dtype": torch.float16,
},
),
( (
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
{ {"dtype": torch.float16},
"dtype": torch.float16,
},
), ),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
("meta-llama/Llama-3.2-1B-Instruct", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}),
] ]
if all: if all:
TEST_MODELS.extend(
[
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{"dtype": torch.float16},
),
]
)
# TODO: figure out why this fails. # TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223 if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append( TEST_MODELS.append(
...@@ -83,31 +82,38 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): ...@@ -83,31 +82,38 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
"compilation_mode", "compilation_mode",
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE], [CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
) )
@pytest.mark.parametrize("model_info", models_list(all=True)) @pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_full_graph( def test_full_graph(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
model_info: tuple[str, dict[str, Any]], model: str,
model_kwargs: dict[str, Any],
compilation_mode: int, compilation_mode: int,
): ):
model, model_kwargs = model_info if (
"w8a8" in model
or "w8w8" in model
and current_platform.has_device_capability((10, 0))
):
# int8 removed on Blackwell:
pytest.skip("int8 support removed on Blackwell")
with monkeypatch.context(): with monkeypatch.context():
print(f"MODEL={model}") print(f"MODEL={model}")
run_model(compilation_mode, model, model_kwargs) run_model(compilation_mode, model, **model_kwargs)
# TODO(luka) add other supported compilation config scenarios here # TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize( @pytest.mark.parametrize(
"compilation_config, model_info", "compilation_config, model, model_kwargs",
[ [
# additional compile sizes, only some of the models # additional compile sizes, only some of the models
( (
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]), CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
model, *model_info,
) )
for model in models_list(all=False) for model_info in models_list(all=False)
] ]
+ [ + [
# RMSNorm + quant fusion, only 8-bit quant models # RMSNorm + quant fusion, only 8-bit quant models
...@@ -117,18 +123,19 @@ def test_full_graph( ...@@ -117,18 +123,19 @@ def test_full_graph(
custom_ops=["+rms_norm"], custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True), pass_config=PassConfig(enable_fusion=True, enable_noop=True),
), ),
model, *model_info,
) )
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
] ]
+ [ + [
# Test depyf integration works # Test depyf integration works
( (
CompilationConfig( CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
debug_dump_path=tempfile.gettempdir(), debug_dump_path=Path(tempfile.gettempdir()),
), ),
("facebook/opt-125m", {}), "facebook/opt-125m",
{},
), ),
] ]
+ [ + [
...@@ -142,9 +149,9 @@ def test_full_graph( ...@@ -142,9 +149,9 @@ def test_full_graph(
cudagraph_mode=CUDAGraphMode.PIECEWISE, cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2], compile_sizes=[1, 2],
), ),
model, *model_info,
) )
for model in models_list(all=False) for model_info in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev") if is_torch_equal_or_newer("2.9.0.dev")
], ],
) )
...@@ -152,16 +159,24 @@ def test_full_graph( ...@@ -152,16 +159,24 @@ def test_full_graph(
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_custom_compile_config( def test_custom_compile_config(
compilation_config: CompilationConfig, compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]], model: str,
model_kwargs: dict[str, Any],
): ):
# Skip int8-quantized checkpoints (model names containing "w8a8" or "w8w8")
# only on devices for which has_device_capability((10, 0)) is true, i.e. the
# Blackwell case named in the skip message.
# The two name tests are grouped first because `and` binds tighter than `or`:
# written unparenthesized, the condition parsed as
#   "w8a8" in model or ("w8w8" in model and <capability check>)
# which skipped every "w8a8" model on all platforms.
is_int8_model = "w8a8" in model or "w8w8" in model
if is_int8_model and current_platform.has_device_capability((10, 0)):
    pytest.skip("int8 support removed on Blackwell")
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer( if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev" "2.9.0.dev"
): ):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+") pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model, model_kwargs = model_info
print(f"MODEL={model}") print(f"MODEL={model}")
run_model(compilation_config, model, model_kwargs) run_model(compilation_config, model, **model_kwargs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -176,50 +191,16 @@ def test_fp8_kv_scale_compile(compilation_mode: int): ...@@ -176,50 +191,16 @@ def test_fp8_kv_scale_compile(compilation_mode: int):
"calculate_kv_scales": True, "calculate_kv_scales": True,
"max_model_len": 512, "max_model_len": 512,
} }
run_model(compilation_mode, model, model_kwargs) run_model(compilation_mode, model, **model_kwargs)
def test_inductor_graph_partition_attn_fusion(caplog_vllm): def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
if not is_torch_equal_or_newer("2.9.0.dev"): compilation_config = (
pytest.skip("inductor graph partition is only available in PyTorch 2.9+") compile_config
if isinstance(compile_config, CompilationConfig)
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" else CompilationConfig(level=compile_config)
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
) )
model_kwargs = {
"kv_cache_dtype": "fp8",
"max_model_len": 1024,
}
with (
caplog_vllm.at_level(logging.DEBUG),
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
):
run_model(compilation_config, model, model_kwargs)
try:
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
caplog_vllm.text
)
except AssertionError:
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
# PyTorch side, the compilation of a graph may be cached such
# that custom pass directly goes through cache. In this case,
# we go through this branch and assert that the pass is not
# triggered.
assert "Fused quantization" not in caplog_vllm.text
def run_model(
compile_config: int | CompilationConfig,
model: str,
model_kwargs: dict[str, Any],
):
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -227,12 +208,17 @@ def run_model( ...@@ -227,12 +208,17 @@ def run_model(
"The future of AI is", "The future of AI is",
] ]
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
# Allow override from model_kwargs
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
# No cudagraphs by default
if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM( llm = LLM(
model=model, model=model,
enforce_eager=True, compilation_config=compilation_config,
tensor_parallel_size=1,
disable_custom_all_reduce=True,
compilation_config=compile_config,
**model_kwargs, **model_kwargs,
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass ...@@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import (
CompilationConfig,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
...@@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module): ...@@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module):
return y return y
def example_inputs(self, num_tokens=32, hidden_size=128): def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32 return (torch.rand(num_tokens, hidden_size * 2),)
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
def ops_in_model(self, do_fusion): def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion: if TEST_FP8 and do_fusion:
...@@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module): ...@@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
dtype = torch.float16 if TEST_FP8 else torch.float32
self.gate_proj = torch.nn.Parameter( self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype) torch.empty((intermediate_size, hidden_size))
) )
self.norm = RMSNorm(intermediate_size, 1e-05) self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter( self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
torch.ones(intermediate_size, dtype=dtype)
)
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
...@@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module): ...@@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
return norm_output, residual_output return norm_output, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32 hidden_states = torch.randn((batch_size * seq_len, hidden_size))
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size))
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
return (hidden_states, residual) return (hidden_states, residual)
def ops_in_model(self, do_fusion): def ops_in_model(self, do_fusion):
...@@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module): ...@@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module):
return q_rotated, k_rotated return q_rotated, k_rotated
def example_inputs(self, num_tokens=32, head_dim=64): def example_inputs(self, num_tokens=32, head_dim=64):
dtype = torch.float16
positions = torch.arange(num_tokens, dtype=torch.long) positions = torch.arange(num_tokens, dtype=torch.long)
q = torch.randn(num_tokens, head_dim, dtype=dtype) q = torch.randn(num_tokens, head_dim)
k = torch.randn(num_tokens, head_dim, dtype=dtype) k = torch.randn(num_tokens, head_dim)
return (positions, q, k) return (positions, q, k)
def ops_in_model(self, do_fusion): def ops_in_model(self, do_fusion):
...@@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module): ...@@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.hidden_size = head_dim * num_heads self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear( self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16 self.hidden_size, self.hidden_size * 3, bias=False
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module): ...@@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
return qkv_updated return qkv_updated
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4): def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
dtype = torch.float16
hidden_size = head_dim * num_heads hidden_size = head_dim * num_heads
positions = torch.arange(num_tokens, dtype=torch.long) positions = torch.arange(num_tokens, dtype=torch.long)
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size)
return (positions, hidden_states) return (positions, hidden_states)
def ops_in_model(self, do_fusion): def ops_in_model(self, do_fusion):
...@@ -211,48 +209,58 @@ MODELS = [ ...@@ -211,48 +209,58 @@ MODELS = [
] ]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): def test_fix_functionalization(
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig( vllm_config = VllmConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
),
) )
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass) with set_current_vllm_config(vllm_config):
backend_no_func = TestBackend(*passes) assert RMSNorm.enabled()
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
model = model_class() backend_func = TestBackend(*passes, func_pass)
torch.compile(model, backend=backend_func)(*model.example_inputs()) backend_no_func = TestBackend(*passes)
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
# check if the functionalization pass is applied model = model_class()
for op in model.ops_in_model(do_fusion): torch.compile(model, backend=backend_func)(*model.example_inputs())
find_auto_fn(backend_no_func.graph_post_pass.nodes, op) torch.compile(model, backend=backend_no_func)(*model.example_inputs())
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized # check if the functionalization pass is applied
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model(do_fusion): for op in model.ops_in_model(do_fusion):
if is_func(node, op): find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
found[op] = True assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
for op in model.ops_not_in_model():
if is_func(node, op): # make sure the ops were all de-functionalized
found[op] = True found = dict()
assert all(found[op] for op in model.ops_in_model(do_fusion)) for node in backend_func.graph_post_pass.nodes:
assert all(not found.get(op) for op in model.ops_not_in_model()) for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())
...@@ -5,15 +5,18 @@ import pytest ...@@ -5,15 +5,18 @@ import pytest
import torch import torch
import vllm.plugins import vllm.plugins
from vllm.compilation.fusion import ( from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
FUSED_OPS, from vllm.compilation.fx_utils import find_op_nodes
QUANT_OPS, from vllm.compilation.matcher_utils import QUANT_OPS
FusedRMSQuantKey,
RMSNormQuantFusionPass,
)
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig from vllm.config import (
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
...@@ -32,6 +35,9 @@ from .backend import TestBackend ...@@ -32,6 +35,9 @@ from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
def __init__( def __init__(
...@@ -45,18 +51,18 @@ class TestModel(torch.nn.Module): ...@@ -45,18 +51,18 @@ class TestModel(torch.nn.Module):
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
quant_scale = ScaleDesc(torch.float32, static, group_shape) quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static: if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else: else:
self.scale = [None for _ in range(2)] self.scale = [None for _ in range(3)]
self.w = [ self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2) for _ in range(3)
] ]
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
...@@ -65,8 +71,12 @@ class TestModel(torch.nn.Module): ...@@ -65,8 +71,12 @@ class TestModel(torch.nn.Module):
act_quant_group_shape=group_shape, act_quant_group_shape=group_shape,
) )
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x): def forward(self, x):
resid = torch.sqrt(x) # avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = self.norm[0](x) y = self.norm[0](x)
x2 = self.fp8_linear.apply( x2 = self.fp8_linear.apply(
...@@ -78,24 +88,44 @@ class TestModel(torch.nn.Module): ...@@ -78,24 +88,44 @@ class TestModel(torch.nn.Module):
x3 = self.fp8_linear.apply( x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1] y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
) )
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
return y3
def ops_in_model_before(self): x4 = self.fp8_linear.apply(
return [QUANT_OPS[self.key]] y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self): def ops_in_model_after(self):
return [ return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)], FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)], FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
] ]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -105,19 +135,32 @@ class TestModel(torch.nn.Module): ...@@ -105,19 +135,32 @@ class TestModel(torch.nn.Module):
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
) )
def test_fusion_rmsnorm_quant( def test_fusion_rmsnorm_quant(
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch dtype,
hidden_size,
num_tokens,
eps,
static,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm", "+quant_fp8"], custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True), pass_config=PassConfig(enable_fusion=True, enable_noop=True),
) ),
) )
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
...@@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant( ...@@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant(
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch) model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic # First dimension dynamic
x = torch.rand(num_tokens, hidden_size) x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(x, 0)
result = model(x) model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model2 = torch.compile(model, backend=backend) model_unfused = torch.compile(model, backend=backend2)
result2 = model2(x) result_unfused = model_unfused(x)
# Higher tol for dynamic, even higher for bfloat16 if dtype == torch.float16:
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3) ATOL, RTOL = (2e-3, 2e-3)
else: else:
ATOL, RTOL = (1e-2, 1e-2) ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2 assert fusion_pass.matched_count == 3
# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before()) backend.check_before_ops(model.ops_in_model_before())
backend.check_before_ops(
# In post-nodes, fused kernels should be there and fp8 quant should not model.ops_in_model_before_partial(), fully_replaced=False
)
backend.check_after_ops(model.ops_in_model_after()) backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
...@@ -17,6 +18,7 @@ from vllm.config import ( ...@@ -17,6 +18,7 @@ from vllm.config import (
ModelConfig, ModelConfig,
PassConfig, PassConfig,
VllmConfig, VllmConfig,
set_current_vllm_config,
) )
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -25,8 +27,8 @@ from vllm.distributed.parallel_state import ( ...@@ -25,8 +27,8 @@ from vllm.distributed.parallel_state import (
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape, GroupShape,
QuantFP8,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils import update_environment_variables
...@@ -40,33 +42,30 @@ class TestAllReduceRMSNormModel(torch.nn.Module): ...@@ -40,33 +42,30 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
def forward(self, hidden_states, residual): def forward(self, x):
view = hidden_states.reshape(-1, self.hidden_size) # avoid having graph input be an arg to a pattern directly
all_reduce = tensor_model_parallel_all_reduce(view) z = torch.relu(x)
norm = self.norm(all_reduce) x = resid = tensor_model_parallel_all_reduce(z)
return norm y = self.norm[0](x)
def ops_in_model_before(self): z2 = torch.mm(y, self.w[0])
return [torch.ops.vllm.all_reduce.default] x2 = tensor_model_parallel_all_reduce(z2)
def ops_in_model_after(self): y2, resid = self.norm[1](x2, resid)
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): y3, resid = self.norm[2](x3, resid)
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() z4 = torch.mm(y3, self.w[2])
self.hidden_size = hidden_size x4 = tensor_model_parallel_all_reduce(z4)
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
def forward(self, hidden_states, residual): y4, resid = self.norm[3](x4, resid)
view = hidden_states.reshape(-1, self.hidden_size) return y4
all_reduce = tensor_model_parallel_all_reduce(view)
norm, _ = self.norm(all_reduce, residual)
return norm
def ops_in_model_before(self): def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default] return [torch.ops.vllm.all_reduce.default]
...@@ -75,24 +74,53 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): ...@@ -75,24 +74,53 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.scale = torch.rand(1, dtype=torch.float32) self.w = [
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
def forward(self, hidden_states, residual): .t()
view = hidden_states.reshape(-1, self.hidden_size) for _ in range(3)
all_reduce = tensor_model_parallel_all_reduce(view) ]
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant( self.fp8_linear = Fp8LinearOp(
self.output, norm_output.contiguous(), self.scale act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
) )
return self.output, residual_output
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self): def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
...@@ -100,7 +128,9 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): ...@@ -100,7 +128,9 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def ops_in_model_before(self): def ops_in_model_before(self):
return [ return [
torch.ops.vllm.all_reduce.default, torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default, torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
else torch.ops.aten.reciprocal.default,
] ]
...@@ -109,25 +139,48 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): ...@@ -109,25 +139,48 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
round_up = lambda x, y: (x + y - 1) // y * y wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
rounded_m = round_up(token_num, 128) self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4) wq_gen, wscale_gen = zip(
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
def forward(self, hidden_states, residual): self.wq, self.wscale = list(wq_gen), list(wscale_gen)
view = hidden_states.reshape(-1, self.hidden_size) print(f"{self.wq=}, {self.wscale=}")
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual) def forward(self, hidden_states):
norm_output = norm_output.reshape(-1, norm_output.shape[-1]) # avoid having graph input be an arg to a pattern directly
torch.ops._C.scaled_fp4_quant( z = torch.relu(hidden_states)
self.output, norm_output, self.output_scale, self.scale x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
z2 = cutlass_scaled_fp4_mm(
yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
z3 = cutlass_scaled_fp4_mm(
yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
) )
return self.output, residual_output, self.output_scale
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
z4 = cutlass_scaled_fp4_mm(
yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self): def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
...@@ -141,19 +194,19 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): ...@@ -141,19 +194,19 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_model", "test_model, enable_quant_fp8_custom_op",
[ [
TestAllReduceRMSNormModel, (TestAllReduceRMSNormModel, False),
TestAllReduceFusedAddRMSNormModel, (TestAllReduceRMSNormStaticQuantFP8Model, True),
TestAllReduceFusedAddRMSNormStaticQuantFP8Model, (TestAllReduceRMSNormStaticQuantFP8Model, False),
# TODO: Enable with torch==2.8.0 (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
], ],
) )
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif( @pytest.mark.skipif(
not find_spec("flashinfer") not find_spec("flashinfer")
...@@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace( ...@@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace(
seq_len: int, seq_len: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
): ):
num_processes = 2 num_processes = 2
if ( if (
...@@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace( ...@@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace(
def run_torch_spawn(fn, nprocs): def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn( torch.multiprocessing.spawn(
fn, fn,
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), args=(
num_processes,
test_model,
batch_size,
seq_len,
hidden_size,
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
),
nprocs=nprocs, nprocs=nprocs,
) )
...@@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model( ...@@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model(
seq_len: int, seq_len: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
): ):
current_platform.seed_everything(0) current_platform.seed_everything(0)
...@@ -217,15 +283,22 @@ def all_reduce_fusion_pass_on_test_model( ...@@ -217,15 +283,22 @@ def all_reduce_fusion_pass_on_test_model(
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig( vllm_config = VllmConfig(
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"] mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
) )
) )
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True enable_fi_allreduce_fusion=True, enable_noop=True
) )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
...@@ -233,24 +306,27 @@ def all_reduce_fusion_pass_on_test_model( ...@@ -233,24 +306,27 @@ def all_reduce_fusion_pass_on_test_model(
vllm_config.model_config = ModelConfig( vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42 model=model_name, trust_remote_code=True, dtype=dtype, seed=42
) )
with set_current_vllm_config(vllm_config):
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(
noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) token_num = batch_size * seq_len
noop_pass = NoOpEliminationPass(vllm_config) model = test_model_cls(hidden_size, token_num)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend) compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual) compiled_model(hidden_states)
assert all_reduce_fusion_pass.matched_count == 1 assert all_reduce_fusion_pass.matched_count == 4, (
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) f"{all_reduce_fusion_pass.matched_count=}"
backend.check_after_ops(model.ops_in_model_after()) )
del all_reduce_fusion_pass backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
...@@ -6,14 +6,15 @@ import pytest ...@@ -6,14 +6,15 @@ import pytest
import torch._dynamo import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import ( from vllm.config import (
...@@ -28,21 +29,18 @@ from vllm.config import ( ...@@ -28,21 +29,18 @@ from vllm.config import (
) )
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
# globals needed for string-import custom Dynamo backend field
backend: TestBackend | None = None
backend_unfused: TestBackend | None = None
class AttentionQuantPatternModel(torch.nn.Module): class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion.""" """Base model for AttentionQuantPattern fusion."""
...@@ -104,6 +102,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -104,6 +102,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_blocks = batch_size * max_blocks num_blocks = batch_size * max_blocks
backend = self.attn.backend backend = self.attn.backend
# TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend # Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN: if backend == _Backend.ROCM_ATTN:
# k/v as 1st dimention # k/v as 1st dimention
...@@ -241,26 +240,40 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): ...@@ -241,26 +240,40 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
) )
MODELS_FP8: list[tuple[str, type]] = []
MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[_Backend] = []
BACKENDS_FP4: list[_Backend] = []
if current_platform.is_cuda(): if current_platform.is_cuda():
MODELS = [ HEADS = [(64, 8), (40, 8)]
MODELS_FP8 = [
( (
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel, TestAttentionFp8StaticQuantPatternModel,
), )
]
MODELS_FP4 = [
( (
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel, TestAttentionNvfp4QuantPatternModel,
), )
] ]
HEADS = [(64, 8), (40, 8)] BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
BACKENDS_FP4 = [_Backend.FLASHINFER]
elif current_platform.is_rocm(): elif current_platform.is_rocm():
MODELS = [ HEADS = [(32, 8), (40, 8)]
MODELS_FP8 = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
] ]
HEADS = [(32, 8), (40, 8)] BACKENDS = [
else: _Backend.ROCM_AITER_UNIFIED_ATTN,
MODELS = [] _Backend.ROCM_ATTN,
HEADS = [] _Backend.TRITON_ATTN,
]
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
...@@ -269,46 +282,36 @@ else: ...@@ -269,46 +282,36 @@ else:
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8] "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
) )
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend", "backend, model_name, model_class, custom_ops",
[_Backend.FLASHINFER] # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
if current_platform.is_cuda() list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN], # quant_fp4 only has the custom impl
) + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True],
) )
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
) )
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)",
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
def test_attention_quant_pattern( def test_attention_quant_pattern(
num_qo_heads: int, num_qo_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
batch_size: int, batch_size: int,
dtype: torch.dtype, dtype: torch.dtype,
custom_ops: str,
model_name: str, model_name: str,
model_class: type[AttentionQuantPatternModel], model_class: type[AttentionQuantPatternModel],
backend: _Backend, backend: _Backend,
use_inductor_graph_partition: bool,
dist_init, dist_init,
caplog_vllm,
): ):
"""Test AttentionStaticQuantPattern fusion pass""" """Test AttentionStaticQuantPattern fusion pass"""
if backend == _Backend.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): custom_ops_list = custom_ops.split(",") if custom_ops else []
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.manual_seed(42) torch.manual_seed(42)
...@@ -322,8 +325,7 @@ def test_attention_quant_pattern( ...@@ -322,8 +325,7 @@ def test_attention_quant_pattern(
scheduler_config=SchedulerConfig(max_num_seqs=1024), scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+quant_fp8"], custom_ops=custom_ops_list,
use_inductor_graph_partition=use_inductor_graph_partition,
), ),
cache_config=CacheConfig(cache_dtype="fp8"), cache_config=CacheConfig(cache_dtype="fp8"),
) )
...@@ -358,8 +360,9 @@ def test_attention_quant_pattern( ...@@ -358,8 +360,9 @@ def test_attention_quant_pattern(
forward_ctx = get_forward_context() forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without compilation and fusion # Run model directly without fusion
result_unfused = model_unfused(q, k, v) # Still compile so query QuantFP8 has closer numerics
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v)
# Run model with attn fusion enabled # Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
...@@ -414,16 +417,25 @@ def test_attention_quant_pattern( ...@@ -414,16 +417,25 @@ def test_attention_quant_pattern(
) )
# Check attn fusion support # Check attn fusion support
quant_key = model_class.quant_key quant_key: QuantKey = model_class.quant_key
attn_fusion_supported = [ attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key) layer.impl.fused_output_quant_supported(quant_key)
for key, layer in vllm_config.compilation_config.static_forward_context.items() for key, layer in vllm_config.compilation_config.static_forward_context.items()
] ]
if any(attn_fusion_supported): assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
# Check quantization ops in the graph before and after fusion "All layers should support attention fusion"
# Note: fully_replaced=False because query quant ops remain in graph. )
# Only output quant ops are fused into attention.
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False) # Check quantization ops in the graph before and after fusion
quant_op = (
torch.ops.aten.reciprocal
if "-quant_fp8" in custom_ops_list
else QUANT_OPS[quant_key]
)
# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
# access the underlying `AttnFusionPass` on the `LazyInitPass` # access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import itertools
import logging
from collections.abc import Iterable
from typing import Any, NamedTuple
import pytest
import regex as re
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
from ..utils import flat_product, multi_gpu_test
class ModelBackendTestCase(NamedTuple):
    """One end-to-end fusion scenario: a model, a backend and expected counts.

    Instances are plain tuples, so ``flat_product`` can splice their fields
    directly into a ``pytest.mark.parametrize`` argument list.
    """

    # Model identifier passed on to ``LLM(model=...)``.
    model_name: str
    # Extra keyword arguments forwarded to the ``LLM`` constructor.
    model_kwargs: dict[str, Any]
    # Attention backend; its ``name`` is exported as VLLM_ATTENTION_BACKEND.
    backend: _Backend
    # Count expected in the "Fused quant onto N attention nodes" log line.
    attention_fusions: int
    # Count expected in the "Replaced N patterns" log line; None when the
    # scenario carries no all-reduce expectation.
    allreduce_fusions: int | None = None
# Scenario tables consumed by the parametrized tests below. They stay empty
# unless the current platform is CUDA or ROCm.
MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = []  # tp-only

if current_platform.is_cuda():
    MODELS_FP8 = [
        ModelBackendTestCase(
            # Use smaller model for L40s in CI
            model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
            model_kwargs={"max_model_len": 1024},
            backend=_Backend.TRITON_ATTN,
            attention_fusions=32,
            allreduce_fusions=65,
        ),
        ModelBackendTestCase(
            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            model_kwargs={"max_model_len": 1024, "kv_cache_dtype": "fp8"},
            backend=_Backend.FLASHINFER,
            attention_fusions=48,
            allreduce_fusions=96,
        ),
    ]

    MODELS_FP4 = [
        ModelBackendTestCase(
            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
            model_kwargs={"max_model_len": 1024, "kv_cache_dtype": "fp8"},
            backend=_Backend.FLASHINFER,
            attention_fusions=48,
            allreduce_fusions=96,
        ),
    ]

    # TP only
    MODELS = [
        ModelBackendTestCase(
            model_name="meta-llama/Llama-3.1-8B-Instruct",
            model_kwargs={"max_model_len": 1024},
            backend=_Backend.TRITON_ATTN,
            attention_fusions=0,
            allreduce_fusions=65,
        ),
    ]

elif current_platform.is_rocm():
    # Same model for every ROCm attention backend; only the backend differs.
    MODELS_FP8 = [
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs={"max_model_len": 1024},
            backend=rocm_backend,
            attention_fusions=32,
        )
        for rocm_backend in (
            _Backend.TRITON_ATTN,
            _Backend.ROCM_ATTN,
            _Backend.ROCM_AITER_UNIFIED_ATTN,
        )
    ]

# TODO(luka) test both in nightly
CUSTOM_OPS_FP8 = ["-quant_fp8"]  # , "+quant_fp8"]
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, "
    "attention_fusions, allreduce_fusions, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: _Backend,
    attention_fusions: int,
    allreduce_fusions: int | None,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """Check that attention+quant fusion is applied the expected number of times.

    The model is run through ``run_model`` with attention fusion enabled, and
    the fusion count is recovered from the log text captured by
    ``caplog_mp_spawn`` and compared against ``attention_fusions``.

    ``allreduce_fusions`` belongs to the shared scenario tuple and is not used
    by this test; it is None for scenarios that do not define it.
    ``custom_ops`` is a comma-separated list of custom-op toggles ("" for none).
    """
    if backend == _Backend.FLASHINFER and (
        not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable the compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        level=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    # log_holder.text is only populated once the context manager has exited.
    # The "." and "]" are escaped so the pattern matches the file name literally.
    matches = re.findall(
        r"fusion_attn\.py:\d+\] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(matches) == 1, log_holder.text
    assert int(matches[0]) == attention_fusions
# TODO(luka) test both in nightly
CUSTOM_OPS_RMS_NORM = ["-rms_norm"]  # , "+rms_norm"]


def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    """Yield each comma-joined combination that takes one entry from every list.

    Combinations come out in ``itertools.product`` order, e.g.
    ``(["-a", "+a"], ["-b"])`` gives ``"-a,-b"`` and then ``"+a,-b"``.
    """
    yield from map(",".join, itertools.product(*custom_ops_lists))
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, "
    "attention_fusions, allreduce_fusions, custom_ops",
    # Toggle RMSNorm and QuantFP8 for FP8 models
    list(
        flat_product(
            MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
        )
    )
    # Toggle RMSNorm for FP4 models and unquant models
    + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not has_flashinfer()
    or not current_platform.has_device_capability(90),
    reason="allreduce+rmsnorm fusion requires flashinfer",
)
def test_tp2_attn_quant_allreduce_rmsnorm(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: _Backend,
    attention_fusions: int,
    allreduce_fusions: int | None,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    """Check attention+quant and allreduce fusion counts with tensor parallel 2.

    The model is run through ``run_model`` with ``tensor_parallel_size=2`` and
    both fusions enabled. Each fusion log line is expected exactly twice in the
    captured text (presumably once per tensor-parallel worker), and every
    occurrence must report the expected count.

    ``custom_ops`` is a comma-separated list of custom-op toggles ("" for none).
    """
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable the compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)

    compilation_config = CompilationConfig(
        # Testing properties
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        custom_ops=custom_ops_list,
        splitting_ops=splitting_ops,
        # Common
        level=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(
            enable_attn_fusion=True,
            enable_noop=True,
            enable_fi_allreduce_fusion=True,
        ),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(
            compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
        )

    # log_holder.text is only populated once the context manager has exited.
    # The "." and "]" are escaped so the patterns match the file names literally.
    matches = re.findall(
        r"fusion_attn\.py:\d+\] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(matches) == 2, log_holder.text
    assert int(matches[0]) == attention_fusions
    assert int(matches[1]) == attention_fusions

    matches = re.findall(
        r"collective_fusion\.py:\d+\] Replaced (\d+) patterns",
        log_holder.text,
    )
    assert len(matches) == 2, log_holder.text
    assert int(matches[0]) == allreduce_fusions
    assert int(matches[1]) == allreduce_fusions
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
    """Build an ``LLM`` for ``model``, generate for fixed prompts, print the results.

    ``compile_config`` is used as-is when it is a ``CompilationConfig``;
    otherwise it is passed as ``level`` to a new one. ``model_kwargs`` are
    forwarded to ``LLM`` and override the defaults set here
    (``tensor_parallel_size=1``, ``disable_custom_all_reduce=True``).
    """
    if isinstance(compile_config, CompilationConfig):
        compilation_config = compile_config
    else:
        compilation_config = CompilationConfig(level=compile_config)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)

    # Caller-supplied kwargs are unpacked last so they win over both defaults.
    llm_kwargs = {
        "disable_custom_all_reduce": True,
        "tensor_parallel_size": 1,
        **model_kwargs,
    }

    # No cudagraphs by default
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(model=model, compilation_config=compilation_config, **llm_kwargs)

    # Print the outputs.
    for output in llm.generate(prompts, sampling_params):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import VllmConfig from vllm.config import ModelConfig, VllmConfig
# dummy custom pass that doesn't inherit # dummy custom pass that doesn't inherit
...@@ -42,7 +42,8 @@ class ProperPass(InductorPass): ...@@ -42,7 +42,8 @@ class ProperPass(InductorPass):
], ],
) )
def test_pass_manager_uuid(callable): def test_pass_manager_uuid(callable):
config = VllmConfig() # Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
pass_manager = PostGradPassManager() pass_manager = PostGradPassManager()
pass_manager.configure(config) pass_manager.configure(config)
......
...@@ -18,6 +18,8 @@ from vllm.config import ( ...@@ -18,6 +18,8 @@ from vllm.config import (
ModelConfig, ModelConfig,
PassConfig, PassConfig,
VllmConfig, VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
) )
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -42,9 +44,7 @@ prompts = [ ...@@ -42,9 +44,7 @@ prompts = [
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
def __init__( def __init__(self, hidden_size=16, intermediate_size=32):
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
...@@ -95,13 +95,11 @@ class TestModel(torch.nn.Module): ...@@ -95,13 +95,11 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module): class TestQuantModel(torch.nn.Module):
def __init__( def __init__(self, hidden_size=16, intermediate_size=32):
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.intermediate_size = intermediate_size
self.vllm_config = vllm_config self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter( self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False torch.empty((intermediate_size, hidden_size)), requires_grad=False
) )
...@@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model( ...@@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig() compilation_config = CompilationConfig(
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig( pass_config=PassConfig(
enable_sequence_parallelism=True, enable_sequence_parallelism=True,
enable_fusion=enable_fusion, enable_fusion=enable_fusion,
enable_noop=True, enable_noop=True,
) )
) # NoOp needed for fusion ) # NoOp needed for fusion
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8" model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig( model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42 model=model_name, trust_remote_code=True, dtype=dtype, seed=42
) )
noop_pass = NoOpEliminationPass(vllm_config) vllm_config = VllmConfig(
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) model_config=model_config,
assert ( device_config=device_config,
sequence_parallelism_pass.compilation_config.splitting_ops compilation_config=compilation_config,
== vllm_config.compilation_config.splitting_ops
) )
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
if enable_fusion: with set_current_vllm_config(vllm_config):
fusion_pass = RMSNormQuantFusionPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
passes_for_backend.append(fusion_pass) sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
passes_for_backend: list[VllmInductorPass] = [
noop_pass,
sequence_parallelism_pass,
]
passes_for_backend.append(cleanup_pass) if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
backend_no_func = TestBackend(*passes_for_backend) passes_for_backend.append(cleanup_pass)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) model = test_model_cls(hidden_size, hidden_size * 2)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func(hidden_states, residual) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1 compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
# In pre-nodes, all reduce should be there, assert sequence_parallelism_pass.matched_count == 1
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
# In post-nodes, reduce scatter and all gather should be there, # In pre-nodes, all reduce should be there,
# all reduce should not # reduce scatter and all gather should not
backend_no_func.check_after_ops(model.ops_in_model_after()) backend_no_func.check_before_ops(model.ops_in_model_before())
# check if the functionalization pass is applied # In post-nodes, reduce scatter and all gather should be there,
for op in model.ops_in_model(): # all reduce should not
find_auto_fn(backend_no_func.graph_post_pass.nodes, op) backend_no_func.check_after_ops(model.ops_in_model_after())
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized # check if the functionalization pass is applied
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model(): for op in model.ops_in_model():
if is_func(node, op): find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
found[op] = True assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
assert all(found[op] for op in model.ops_in_model())
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
# ruff: noqa import pathlib
from copy import deepcopy
from tblib import pickling_support from tblib import pickling_support
# ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate # Install support for pickling exceptions so that we can nicely propagate
# failures from tests running in a subprocess. # failures from tests running in a subprocess.
# This should be run before any custom exception subclasses are defined. # This should be run before any custom exception subclasses are defined.
...@@ -40,7 +43,7 @@ from transformers import ( ...@@ -40,7 +43,7 @@ from transformers import (
from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams, envs
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
...@@ -1070,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): ...@@ -1070,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
yield caplog yield caplog
@pytest.fixture()
def caplog_mp_fork():
    """
    This fixture enables capturing logs from a forked MP subprocess.
    It should be used in conjunction with caplog_vllm.

    By default, subprocess logs do not go through the parent process.
    We instead create a queue listener in the parent process which
    forwards logs to the logger's other handlers, and add a QueueHandler
    to the root logger. Forked subprocesses will inherit the root logger
    and pass their messages to the queue, which the listener will forward
    to the root logger, which can be captured by caplog.

    The fixture value is a context-manager factory: use it as
    ``with caplog_mp_fork(): ...``. On exit (including when the body raises)
    the QueueHandler is removed from the root logger again and the listener
    is stopped, so nothing leaks into later tests.

    Note that this workaround only works for fork; with spawn, the subprocess
    reinitializes logging and does not automatically inherit the queue.
    We'd have to manually pass the queue to the subprocess at the spawn point.
    See caplog_mp_spawn below.
    """

    @contextlib.contextmanager
    def ctx():
        import logging.handlers
        import multiprocessing as mp

        logger_queue: mp.Queue[logging.LogRecord] = mp.Queue()
        logger = logging.getLogger()
        # Snapshot the current handlers: the QueueHandler added below must not
        # become one of the listener's targets.
        handlers = list(logger.handlers)

        # The listener works on a background thread, not inherited by the child.
        queue_listener = logging.handlers.QueueListener(logger_queue, *handlers)
        queue_listener.start()

        # Add queue handler after creating the listener to avoid cycle
        queue_handler = logging.handlers.QueueHandler(logger_queue)
        logger.addHandler(queue_handler)
        try:
            yield
        finally:
            # Detach first so no new records are queued, then stop the
            # listener, which processes what is still queued before returning.
            logger.removeHandler(queue_handler)
            queue_listener.stop()

    return ctx
class LogHolder:
    """Mutable box for captured log text; ``text`` is None until it is assigned."""

    def __init__(self):
        self.text: str | None = None
@pytest.fixture()
def caplog_mp_spawn(tmp_path, monkeypatch):
    """
    This fixture enables capturing logs from a spawned MP subprocess.
    It does not require caplog_vllm (but it only contains logs from the child).

    By default, subprocess logs do not go through the parent process.
    We instead add a FileHandler to the config so the spawned child process
    writes its logs to a temp file.
    In the parent, we read the file and return the contents.

    The fixture value is a context-manager factory taking the level for the
    file handler: ``with caplog_mp_spawn(logging.DEBUG) as log_holder: ...``.
    ``log_holder.text`` is filled in when the ``with`` block exits, also when
    the body raised, so the log is available for debugging the failure.
    It is the empty string if no log file was written.

    Note: this method could be extended to fork by either reconfiguring logging
    in the parent or using a SocketHandler:
    https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501
    """

    @contextlib.contextmanager
    def ctx(level: int | str):
        from vllm.logger import DEFAULT_LOGGING_CONFIG

        config_path = tmp_path / "vllm_logging_config.json"
        log_path = tmp_path / "vllm.log"
        log_holder = LogHolder()

        # Start from a copy of the default config, or from the user-provided
        # one when VLLM_LOGGING_CONFIG_PATH is already set.
        config = deepcopy(DEFAULT_LOGGING_CONFIG)
        if envs.VLLM_LOGGING_CONFIG_PATH:
            path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH)
            assert path.exists()
            config = json.loads(path.read_text())

        config["loggers"]["vllm"]["handlers"] += ["vllm_file"]
        config["handlers"]["vllm_file"] = {
            "class": "logging.FileHandler",
            "formatter": "vllm",
            "level": level,
            "filename": log_path.as_posix(),
        }

        config_path.write_text(json.dumps(config))

        try:
            with monkeypatch.context() as monkeypatch_ctx:
                monkeypatch_ctx.setenv(
                    "VLLM_LOGGING_CONFIG_PATH", config_path.as_posix()
                )
                monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1")
                yield log_holder
        finally:
            # Read in ``finally`` so a failing body still exposes the child log;
            # guard on existence so a missing file cannot mask that failure.
            log_holder.text = log_path.read_text() if log_path.exists() else ""

    return ctx
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def num_gpus_available(): def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context """Get number of GPUs without initializing the CUDA context
......
...@@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant( ...@@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
.clamp(fp8_traits_min, fp8_traits_max) .clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE) .to(FP8_DTYPE)
) )
return ref_out, ref_scale.view((1,)) return ref_out, ref_scale.view((1, 1))
def native_w8a8_block_matmul( def native_w8a8_block_matmul(
......
...@@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content(): ...@@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content():
assert call_args[1] == "test-streaming-full-text" assert call_args[1] == "test-streaming-full-text"
assert call_args[2] == " (streaming complete)" assert call_args[2] == " (streaming complete)"
assert call_args[5] == "streaming_complete" assert call_args[5] == "streaming_complete"
# Add vllm prefix to make sure logs go through the vllm logger
test_logger = init_logger("vllm.test_logger")


def mp_function(**kwargs):
    """Log a warning, an error and a debug record; meant to run in a subprocess.

    The warning carries ``kwargs["a"]`` and the debug record ``kwargs["b"]``
    (None when absent) so the parent can search for them in the captured logs.
    """
    warning_payload = kwargs.get("a")
    debug_payload = kwargs.get("b")
    test_logger.warning("This is a subprocess: %s", warning_payload)
    test_logger.error("This is a subprocess error.")
    test_logger.debug("This is a subprocess debug message: %s.", debug_payload)
def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork):
    """Records logged by a forked child must appear in ``caplog_vllm``."""
    import multiprocessing

    with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork():
        fork_ctx = multiprocessing.get_context("fork")
        child = fork_ctx.Process(
            target=mp_function,
            name=f"SubProcess{1}",
            kwargs={"a": "AAAA", "b": "BBBBB"},
        )
        child.start()
        child.join()

    # Checked after leaving the context, i.e. once the queue listener stopped.
    captured = caplog_vllm.text
    assert "AAAA" in captured
    assert "BBBBB" in captured
def test_caplog_mp_spawn(caplog_mp_spawn):
    """Records logged by a spawned child must end up in the log holder's text."""
    import multiprocessing

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        spawn_ctx = multiprocessing.get_context("spawn")
        child = spawn_ctx.Process(
            target=mp_function,
            name=f"SubProcess{1}",
            kwargs={"a": "AAAA", "b": "BBBBB"},
        )
        child.start()
        child.join()

    # The holder's text is only filled in when the context manager exits.
    captured = log_holder.text
    assert "AAAA" in captured
    assert "BBBBB" in captured
...@@ -6,6 +6,7 @@ import contextlib ...@@ -6,6 +6,7 @@ import contextlib
import copy import copy
import functools import functools
import importlib import importlib
import itertools
import json import json
import os import os
import random import random
...@@ -15,7 +16,7 @@ import sys ...@@ -15,7 +16,7 @@ import sys
import tempfile import tempfile
import time import time
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable, Iterable
from contextlib import ExitStack, contextmanager, suppress from contextlib import ExitStack, contextmanager, suppress
from multiprocessing import Process from multiprocessing import Process
from pathlib import Path from pathlib import Path
...@@ -1261,3 +1262,23 @@ def check_answers( ...@@ -1261,3 +1262,23 @@ def check_answers(
frac_ok = numok / len(answer) frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}") print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok >= accept_rate assert frac_ok >= accept_rate
def flat_product(*iterables: Iterable[Any]):
    """
    Cartesian product whose tuple-valued elements are spliced into the result.

    Every combination from ``itertools.product`` is flattened one level:
    elements that are tuples contribute their items, anything else is kept
    as a single item. This lets test params be unpacked directly from the
    decorator instead of arriving as nested tuples.

    Example:
        flat_product([(1, 2), (3, 4)], ["a", "b"]) ->
        [
            (1, 2, "a"),
            (1, 2, "b"),
            (3, 4, "a"),
            (3, 4, "b"),
        ]
    """
    for combination in itertools.product(*iterables):
        row: list[Any] = []
        for item in combination:
            if isinstance(item, tuple):
                row.extend(item)
            else:
                row.append(item)
        yield tuple(row)
...@@ -40,7 +40,7 @@ from vllm.utils import ( ...@@ -40,7 +40,7 @@ from vllm.utils import (
unique_filepath, unique_filepath,
) )
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test, flat_product
def test_get_open_port(monkeypatch: pytest.MonkeyPatch): def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
...@@ -771,3 +771,25 @@ def test_unique_filepath(): ...@@ -771,3 +771,25 @@ def test_unique_filepath():
paths.add(path) paths.add(path)
assert len(paths) == 10 assert len(paths) == 10
assert len(list(Path(temp_dir).glob("*.txt"))) == 10 assert len(list(Path(temp_dir).glob("*.txt"))) == 10
def test_flat_product():
    """flat_product matches itertools.product and splices tuple elements."""
    # Check regular itertools.product behavior
    plain = list(flat_product([1, 2, 3], ["a", "b"]))
    expected_plain = [
        (1, "a"),
        (1, "b"),
        (2, "a"),
        (2, "b"),
        (3, "a"),
        (3, "b"),
    ]
    assert plain == expected_plain

    # check that the tuples get flattened
    flattened = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
    expected_flattened = [
        (1, 2, "a", 5, 6),
        (1, 2, "b", 5, 6),
        (3, 4, "a", 5, 6),
        (3, 4, "b", 5, 6),
    ]
    assert flattened == expected_flattened
...@@ -1507,7 +1507,7 @@ def scaled_fp8_quant( ...@@ -1507,7 +1507,7 @@ def scaled_fp8_quant(
output, input, scale, scale_ub output, input, scale, scale_ub
) )
else: else:
scale = torch.empty(1, device=input.device, dtype=torch.float32) scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else: else:
assert scale.numel() == 1, f"{scale.shape}" assert scale.numel() == 1, f"{scale.shape}"
......
...@@ -17,10 +17,14 @@ from vllm.distributed.parallel_state import ( ...@@ -17,10 +17,14 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -41,11 +45,8 @@ else: ...@@ -41,11 +45,8 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default if hasattr(torch.ops._C, "scaled_fp4_quant"):
RMS_OP = torch.ops._C.rms_norm.default STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
class BasePattern: class BasePattern:
...@@ -669,33 +670,24 @@ class AllReduceRMSNormPattern(BasePattern): ...@@ -669,33 +670,24 @@ class AllReduceRMSNormPattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self): def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) input, weight = self.rmsnorm_matcher.inputs()
rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
return [input, rms_result, weight] # input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern( def pattern(input: torch.Tensor, weight: torch.Tensor):
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
):
allreduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized( rms = self.rmsnorm_matcher(allreduce_output, weight)
RMS_OP,
result=rms_result,
input=allreduce_output,
weight=weight,
epsilon=self.epsilon,
)
# rms_result, allreduce_output
return rms[1], allreduce_output
def replacement( return rms, allreduce_output
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
): def replacement(input: torch.Tensor, weight: torch.Tensor):
residual = torch.zeros_like(input) residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -733,29 +725,19 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): ...@@ -733,29 +725,19 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self): def get_inputs(self):
input = torch.empty([4, 4], device=self.device, dtype=self.dtype) input, residual, weight = self.rmsnorm_matcher.inputs()
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) # input goes through allreduce first, always 16-bit
return [ return [residual, input.to(self.dtype), weight]
residual,
input,
weight,
]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
allreduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized( rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
RMS_ADD_OP, return rms, residual
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
# input, residual
return rms[1], rms[2]
def replacement( def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
...@@ -779,6 +761,18 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): ...@@ -779,6 +761,18 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
) )
# Same pattern, but only return the output and not residual
# (helpful for end of graph where residual is not used again)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern),
first_return_only(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
""" """
...@@ -799,60 +793,37 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -799,60 +793,37 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def get_inputs(): def get_inputs():
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) input, weight = self.rmsnorm_matcher.inputs()
rmsnorm_result = torch.empty( _, scale = self.quant_matcher.inputs()
[1, 8, 4], device=self.device, dtype=self.dtype
) # input goes through allreduce first, always 16-bit
quant_result = torch.empty( return [input.to(self.dtype), weight, scale]
[1, 8, 4], device=self.device, dtype=self.quant_dtype
)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
all_reduce = tensor_model_parallel_all_reduce(input) all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized( rms = self.rmsnorm_matcher(all_reduce, weight)
RMS_OP, quant, _ = self.quant_matcher(rms, scale)
result=rmsnorm_result, return quant, all_reduce
input=all_reduce,
weight=weight,
epsilon=self.epsilon,
)
quant_out_tuple = auto_functionalized(
STATIC_FP8_QUANT_OP,
result=quant_result,
input=rmsnorm_out_tuple[1],
scale=scale,
)
# quant_out, allreduce_output
return quant_out_tuple[1], all_reduce
def replacement( def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
input: torch.Tensor,
result_rms: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
residual = torch.zeros_like(input) residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
residual=residual, residual=residual,
norm_out=result_rms, norm_out=result_rms,
quant_out=quant_result, quant_out=result_quant,
scale_out=None, scale_out=None,
rms_gamma=weight, rms_gamma=weight,
rms_eps=self.epsilon, rms_eps=self.epsilon,
...@@ -892,64 +863,42 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -892,64 +863,42 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def get_inputs(): def get_inputs():
input = torch.empty([4, 4], device=self.device, dtype=self.dtype) input, residual, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) # input goes through allreduce first, always 16-bit
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [residual, input.to(self.dtype), weight, scale]
quant_result = torch.empty(
[4, 4], device=self.device, dtype=self.quant_dtype
)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
quant_result,
residual,
input,
weight,
scale,
]
def pattern( def pattern(
quant_result: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
allreduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant, _ = self.quant_matcher(rms, scale)
fused_add_rmsnorm_out_tuple = auto_functionalized( return quant, res
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
quant_out_tuple = auto_functionalized(
STATIC_FP8_QUANT_OP,
result=quant_result,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale,
)
# quant_out, allreduce_output
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
def replacement( def replacement(
quant_result: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, scale: torch.Tensor,
): ):
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
residual=residual, residual=residual,
norm_out=None, norm_out=None,
quant_out=quant_result, quant_out=result_quant,
scale_out=None, scale_out=None,
rms_gamma=weight, rms_gamma=weight,
rms_eps=self.epsilon, rms_eps=self.epsilon,
...@@ -986,14 +935,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -986,14 +935,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def get_inputs(): def get_inputs():
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty(
[1, 16, 16], device=self.device, dtype=self.dtype
)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty( input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32 [1, 1], device=self.device, dtype=torch.float32
...@@ -1001,36 +947,21 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -1001,36 +947,21 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight = torch.empty([16], device=self.device, dtype=self.dtype) weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [ return [input, quant_result, weight, input_global_scale, output_scale]
input,
rmsnorm_result,
quant_result,
weight,
input_global_scale,
output_scale,
]
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor, quant_result: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
input_global_scale: torch.Tensor, input_global_scale: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
): ):
all_reduce = tensor_model_parallel_all_reduce(input) all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized( rms = self.rmsnorm_matcher(all_reduce, weight)
RMS_OP,
result=rmsnorm_result,
input=all_reduce,
weight=weight,
epsilon=self.epsilon,
)
quant_out_tuple = auto_functionalized( quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP, STATIC_FP4_QUANT_OP,
output=quant_result, output=quant_result,
input=rmsnorm_out_tuple[1], input=rms,
output_scale=output_scale, output_scale=output_scale,
input_scale=input_global_scale, input_scale=input_global_scale,
) )
...@@ -1040,13 +971,13 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -1040,13 +971,13 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
def replacement( def replacement(
input: torch.Tensor, input: torch.Tensor,
result_rms: torch.Tensor,
quant_result: torch.Tensor, quant_result: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
input_global_scale: torch.Tensor, input_global_scale: torch.Tensor,
output_scale: torch.Tensor, output_scale: torch.Tensor,
): ):
residual = torch.zeros_like(input) residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -1090,6 +1021,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -1090,6 +1021,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device) super().__init__(dtype, device)
self.epsilon = epsilon self.epsilon = epsilon
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass):
def get_inputs(): def get_inputs():
...@@ -1121,28 +1053,17 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -1121,28 +1053,17 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
input_global_scale: torch.Tensor, input_global_scale: torch.Tensor,
): ):
allreduce_output = tensor_model_parallel_all_reduce(input) allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
fused_add_rmsnorm_out_tuple = auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
quant_out_tuple = auto_functionalized( quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP, STATIC_FP4_QUANT_OP,
output=quant_result, output=quant_result,
input=fused_add_rmsnorm_out_tuple[1], input=rms,
output_scale=output_scale, output_scale=output_scale,
input_scale=input_global_scale, input_scale=input_global_scale,
) )
# quant_out, allreduce_output, output_scale # quant_out, allreduce_output, output_scale
return ( return quant_out_tuple[1], residual, quant_out_tuple[2]
quant_out_tuple[1],
fused_add_rmsnorm_out_tuple[2],
quant_out_tuple[2],
)
def replacement( def replacement(
quant_result: torch.Tensor, quant_result: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment