Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import logging
import tempfile import tempfile
from typing import Any, Optional, Union from pathlib import Path
from typing import Any
import pytest import pytest
import torch import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
PassConfig)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): def models_list(*, all: bool = True, keywords: list[str] | None = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { (
"dtype": torch.float16, "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
}), {"dtype": torch.float16},
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { ),
"dtype": torch.float16,
}),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
("meta-llama/Llama-3.2-1B-Instruct", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}),
] ]
if all: if all:
TEST_MODELS.extend(
[
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{"dtype": torch.float16},
),
]
)
# TODO: figure out why this fails. # TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223 if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { TEST_MODELS.append(
"quantization": "gguf" ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
})) )
if is_quant_method_supported("gptq"): if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { TEST_MODELS.append(
"quantization": "gptq" ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
})) )
if is_quant_method_supported("gptq_marlin"): if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { TEST_MODELS.append(
"quantization": "gptq_marlin" (
})) "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
{"quantization": "gptq_marlin"},
)
)
if is_quant_method_supported("gptq_marlin_24"): if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { TEST_MODELS.append(
"quantization": "gptq_marlin_24" (
})) "alexm-nm/tinyllama-24-marlin24-4bit-g128",
{"quantization": "gptq_marlin_24"},
)
)
if not current_platform.is_rocm() and is_quant_method_supported("awq"): if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { TEST_MODELS.append(
"quantization": "AWQ" ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
})) )
if keywords is None: if keywords is None:
return TEST_MODELS return TEST_MODELS
...@@ -72,110 +80,145 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): ...@@ -72,110 +80,145 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"optimization_level", "compilation_mode",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], [CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
) )
@pytest.mark.parametrize("model_info", models_list(all=True)) @pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_full_graph( def test_full_graph(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
model_info: tuple[str, dict[str, Any]], model: str,
optimization_level: int, model_kwargs: dict[str, Any],
compilation_mode: int,
): ):
model, model_kwargs = model_info if (
"w8a8" in model
or "w8w8" in model
and current_platform.has_device_capability((10, 0))
):
# int8 removed on Blackwell:
pytest.skip("int8 support removed on Blackwell")
with monkeypatch.context(): with monkeypatch.context():
print(f"MODEL={model}") print(f"MODEL={model}")
run_model(optimization_level, model, model_kwargs) run_model(compilation_mode, model, **model_kwargs)
# TODO(luka) add other supported compilation config scenarios here # TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize( @pytest.mark.parametrize(
"compilation_config, model_info", "compilation_config, model, model_kwargs",
[ [
# additional compile sizes, only some of the models # additional compile sizes, only some of the models
(CompilationConfig(level=CompilationLevel.PIECEWISE, (
compile_sizes=[1, 2]), model) CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
for model in models_list(all=False) *model_info,
] + [ )
for model_info in models_list(all=False)
]
+ [
# RMSNorm + quant fusion, only 8-bit quant models # RMSNorm + quant fusion, only 8-bit quant models
(CompilationConfig(level=CompilationLevel.PIECEWISE, (
custom_ops=["+rms_norm"], CompilationConfig(
pass_config=PassConfig(enable_fusion=True, mode=CompilationMode.VLLM_COMPILE,
enable_noop=True)), model) custom_ops=["+rms_norm"],
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) pass_config=PassConfig(enable_fusion=True, enable_noop=True),
] + [ ),
*model_info,
)
for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
]
+ [
# Test depyf integration works # Test depyf integration works
(CompilationConfig(level=CompilationLevel.PIECEWISE, (
debug_dump_path=tempfile.gettempdir()), CompilationConfig(
("facebook/opt-125m", {})), mode=CompilationMode.VLLM_COMPILE,
] + [ debug_dump_path=Path(tempfile.gettempdir()),
),
"facebook/opt-125m",
{},
),
]
+ [
# graph inductor partition # graph inductor partition
( (
CompilationConfig( CompilationConfig(
level=CompilationLevel.PIECEWISE, mode=CompilationMode.VLLM_COMPILE,
# inductor graph partition uses # inductor graph partition uses
# torch._C.Tag.cudagraph_unsafe to specify splitting ops # torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True, use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE, cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2]), compile_sizes=[1, 2],
model) for model in models_list(all=False) ),
*model_info,
)
for model_info in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev") if is_torch_equal_or_newer("2.9.0.dev")
]) ],
)
# only test some of the models # only test some of the models
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_custom_compile_config( def test_custom_compile_config(
compilation_config: CompilationConfig, compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]], model: str,
model_kwargs: dict[str, Any],
): ):
if (compilation_config.use_inductor_graph_partition if (
and not is_torch_equal_or_newer("2.9.0.dev")): "w8a8" in model
pytest.skip("inductor graph partition is only available " or "w8w8" in model
"in PyTorch 2.9+") and current_platform.has_device_capability((10, 0))
):
# int8 removed on Blackwell:
pytest.skip("int8 support removed on Blackwell")
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model, model_kwargs = model_info
print(f"MODEL={model}") print(f"MODEL={model}")
run_model(compilation_config, model, model_kwargs) run_model(compilation_config, model, **model_kwargs)
def test_inductor_graph_partition_attn_fusion(caplog_vllm): @pytest.mark.parametrize(
if not is_torch_equal_or_newer("2.9.0.dev"): "compilation_mode",
pytest.skip("inductor graph partition is only available " [CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
"in PyTorch 2.9+") )
@pytest.mark.parametrize(
"model, backend",
[
("Qwen/Qwen2-0.5B", None), # Standard attention model
(
"deepseek-ai/DeepSeek-V2-Lite",
AttentionBackendEnum.FLASHINFER_MLA,
), # MLA (Multi-head Latent Attention) model
],
)
def test_fp8_kv_scale_compile(
monkeypatch: pytest.MonkeyPatch,
compilation_mode: int,
model: str,
backend: AttentionBackendEnum | None,
):
if backend:
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
)
model_kwargs = { model_kwargs = {
"kv_cache_dtype": "fp8", "quantization": "fp8",
"max_model_len": 1024, "kv_cache_dtype": "fp8_e4m3",
"calculate_kv_scales": True,
"max_model_len": 512,
} }
with caplog_vllm.at_level( run_model(compilation_mode, model, **model_kwargs)
logging.DEBUG), global_force_attn_backend_context_manager(
_Backend.FLASHINFER):
run_model(compilation_config, model, model_kwargs) def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
compilation_config = (
try: compile_config
assert ("Fused quantization onto 48 attention nodes" if isinstance(compile_config, CompilationConfig)
in caplog_vllm.text), caplog_vllm.text else CompilationConfig(mode=compile_config)
except AssertionError: )
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
# PyTorch side, the compilation of a graph may be cached such
# that custom pass directly goes through cache. In this case,
# we go through this branch and assert that the pass is not
# triggered.
assert "Fused quantization" not in caplog_vllm.text
def run_model(compile_config: Union[int, CompilationConfig], model: str,
model_kwargs: dict[str, Any]):
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -183,12 +226,17 @@ def run_model(compile_config: Union[int, CompilationConfig], model: str, ...@@ -183,12 +226,17 @@ def run_model(compile_config: Union[int, CompilationConfig], model: str,
"The future of AI is", "The future of AI is",
] ]
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
# Allow override from model_kwargs
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
# No cudagraphs by default
if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM( llm = LLM(
model=model, model=model,
enforce_eager=True, compilation_config=compilation_config,
tensor_parallel_size=1,
disable_custom_all_reduce=True,
compilation_config=compile_config,
**model_kwargs, **model_kwargs,
) )
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -5,114 +5,262 @@ import pytest ...@@ -5,114 +5,262 @@ import pytest
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.config import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( CompilationConfig,
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from .backend import TestBackend from .backend import TestBackend
OPS_IN_MODEL = [ TEST_FP8 = current_platform.supports_fp8()
torch.ops._C.rotary_embedding.default, FP8_DTYPE = current_platform.fp8_dtype()
torch.ops._C.fused_add_rms_norm.default,
]
class TestSiluMul(torch.nn.Module):
def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
else:
return y
def example_inputs(self, num_tokens=32, hidden_size=128):
return (torch.rand(num_tokens, hidden_size * 2),)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
return [torch.ops._C.silu_and_mul_quant.default]
else:
return [torch.ops._C.silu_and_mul.default]
def ops_not_in_model(self):
return []
class TestFusedAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
torch.nn.init.normal_(self.gate_proj, std=0.02)
if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# layer normalization
norm_output, residual_output = self.norm(mm, residual)
if TEST_FP8:
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
return fp8_linear_result, residual_output
else:
return norm_output, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
residual = torch.randn((batch_size * seq_len, hidden_size))
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
else:
return [torch.ops._C.fused_add_rms_norm.default]
def ops_not_in_model(self):
return []
RMS_OP = torch.ops._C.rms_norm.default
RMS_QUANT_OPS = { class TestRotaryEmbedding(torch.nn.Module):
"static_fp8": [ def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
torch.ops._C.rms_norm_static_fp8_quant.default, super().__init__()
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default self.head_dim = head_dim
], self.rotary_dim = rotary_dim or head_dim
}
SILU_MUL_OP = torch.ops._C.silu_and_mul.default self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
base=base,
)
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default def forward(self, positions, q, k):
prompts = [ q_rotated, k_rotated = self.rotary_emb(positions, q, k)
"Hello, my name is", return q_rotated, k_rotated
"The president of the United States is",
"The capital of France is", def example_inputs(self, num_tokens=32, head_dim=64):
"The future of AI is", positions = torch.arange(num_tokens, dtype=torch.long)
q = torch.randn(num_tokens, head_dim)
k = torch.randn(num_tokens, head_dim)
return (positions, q, k)
def ops_in_model(self, do_fusion):
return [torch.ops._C.rotary_embedding.default]
def ops_not_in_model(self):
return []
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.num_heads = num_heads
self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=base,
)
def forward(self, positions, hidden_states):
# Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
# -> slice_scatter -> split_with_sizes
qkv = self.qkv_proj(hidden_states)
split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
q, k, v = torch.split(qkv, split_sizes, dim=-1)
q_rotated, k_rotated = self.rotary_emb(positions, q, k)
qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
return qkv_updated
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
hidden_size = head_dim * num_heads
positions = torch.arange(num_tokens, dtype=torch.long)
hidden_states = torch.randn(num_tokens, hidden_size)
return (positions, hidden_states)
def ops_in_model(self, do_fusion):
return [torch.ops._C.rotary_embedding.default]
def ops_not_in_model(self):
return [torch.ops.aten.slice_scatter.default]
MODELS = [
TestSiluMul,
TestFusedAddRMSNorm,
TestRotaryEmbedding,
TestRotaryEmbeddingSliceScatter,
] ]
@pytest.mark.parametrize( @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
"model, quant_key", @pytest.mark.parametrize("model_class", MODELS)
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
kFp8DynamicTokenSym)])
@pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
reason="Only test on CUDA") def test_fix_functionalization(
def test_fix_functionalization(model: str, quant_key: QuantKey, model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
do_fusion: bool): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
),
)
with set_current_vllm_config(vllm_config):
assert RMSNorm.enabled()
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
vllm_config = VllmConfig() # make sure the ops were all de-functionalized
vllm_config.compilation_config = CompilationConfig( found = dict()
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) for node in backend_func.graph_post_pass.nodes:
noop_pass = NoOpEliminationPass(vllm_config) for op in model.ops_in_model(do_fusion):
fusion_pass = RMSNormQuantFusionPass(vllm_config) if is_func(node, op):
cleanup_pass = PostCleanupPass(vllm_config) found[op] = True
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) for op in model.ops_not_in_model():
if is_func(node, op):
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass found[op] = True
] if do_fusion else [noop_pass, cleanup_pass] assert all(found[op] for op in model.ops_in_model(do_fusion))
func_pass = FixFunctionalizationPass(vllm_config) assert all(not found.get(op) for op in model.ops_not_in_model())
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
# instantiate a full engine and manually compile the model 2x
# (with and without FixFunctionalizationPass)
llm = LLM(model=model, enforce_eager=True)
model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
orig_model = model_runner.model
# TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
# Can only do that by using the decorator but then we'd have to instantiate
# 2 LLM instances.
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_func)
gen_func = llm.generate(prompts, sampling_params)
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_no_func)
gen_no_func = llm.generate(prompts, sampling_params)
for output_func, output_no_func in zip(gen_func, gen_no_func):
assert output_func.outputs[0].text == output_no_func.outputs[0].text
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
quant_key == kFp8StaticTensorSym else [
SILU_MUL_OP
]
ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in ops:
if is_func(node, op):
found[op] = True
assert all(found[op] for op in ops)
...@@ -5,17 +5,29 @@ import pytest ...@@ -5,17 +5,29 @@ import pytest
import torch import torch
import vllm.plugins import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
RMSNormQuantFusionPass) from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, from vllm.config import (
VllmConfig) CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc) GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported from ..utils import override_cutlass_fp8_supported
...@@ -23,25 +35,34 @@ from .backend import TestBackend ...@@ -23,25 +35,34 @@ from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, static: bool, class TestModel(torch.nn.Module):
cuda_force_torch: bool, *args, **kwargs): def __init__(
self,
hidden_size: int,
eps: float,
static: bool,
cuda_force_torch: bool,
*args,
**kwargs,
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
quant_scale = ScaleDesc(torch.float32, static, group_shape) quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static: if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else: else:
self.scale = [None for _ in range(2)] self.scale = [None for _ in range(3)]
self.w = [ self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2) for _ in range(3)
] ]
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
...@@ -50,57 +71,97 @@ class TestModel(torch.nn.Module): ...@@ -50,57 +71,97 @@ class TestModel(torch.nn.Module):
act_quant_group_shape=group_shape, act_quant_group_shape=group_shape,
) )
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x): def forward(self, x):
resid = torch.sqrt(x) # avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = self.norm[0](x) y = self.norm[0](x)
x2 = self.fp8_linear.apply(y, x2 = self.fp8_linear.apply(
self.w[0], y, self.w[0], self.wscale[0], input_scale=self.scale[0]
self.wscale[0], )
input_scale=self.scale[0])
# make sure resid is used for replacement to work # make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(y2, x3 = self.fp8_linear.apply(
self.w[1], y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
self.wscale[1], )
input_scale=self.scale[1])
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
return y3
def ops_in_model_before(self): x4 = self.fp8_linear.apply(
return [QUANT_OPS[self.key]] y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self): def ops_in_model_after(self):
return [ return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)], FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)] FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
] ]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False]) @pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that # cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True. # cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch", @pytest.mark.parametrize(
[True, False] if cutlass_fp8_supported() else [True]) "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
@pytest.mark.skipif(not current_platform.is_cuda_alike(), )
reason="Only test on CUDA and ROCm") @pytest.mark.skipif(
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
cuda_force_torch): )
def test_fusion_rmsnorm_quant(
dtype,
hidden_size,
num_tokens,
eps,
static,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
vllm_config = VllmConfig(compilation_config=CompilationConfig( custom_ops = []
level=CompilationLevel.PIECEWISE, if enable_rms_norm_custom_op:
custom_ops=["+rms_norm", "+quant_fp8"], custom_ops.append("+rms_norm")
pass_config=PassConfig(enable_fusion=True, enable_noop=True), if enable_quant_fp8_custom_op:
)) custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
)
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
...@@ -108,31 +169,39 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, ...@@ -108,31 +169,39 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch) model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic # First dimension dynamic
x = torch.rand(num_tokens, hidden_size) x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(x, 0)
result = model(x) model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model2 = torch.compile(model, backend=backend) model_unfused = torch.compile(model, backend=backend2)
result2 = model2(x) result_unfused = model_unfused(x)
# Higher tol for dynamic, even higher for bfloat16 if dtype == torch.float16:
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3) ATOL, RTOL = (2e-3, 2e-3)
else: else:
ATOL, RTOL = (1e-2, 1e-2) ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2 assert fusion_pass.matched_count == 3
# In pre-nodes, fp8 quant should be there and fused kernels should not
backend.check_before_ops(model.ops_in_model_before()) backend.check_before_ops(model.ops_in_model_before())
backend.check_before_ops(
# In post-nodes, fused kernels should be there and fp8 quant should not model.ops_in_model_before_partial(), fully_replaced=False
)
backend.check_after_ops(model.ops_in_model_after()) backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2
...@@ -6,59 +6,66 @@ import pytest ...@@ -6,59 +6,66 @@ import pytest
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, from vllm.config import (
ModelConfig, PassConfig, VllmConfig) CompilationConfig,
CompilationMode,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment, from vllm.distributed.parallel_state import (
initialize_model_parallel) init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape, QuantFP8) Fp8LinearOp,
GroupShape,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from ..utils import has_module_attribute, multi_gpu_test from ..utils import has_module_attribute, multi_gpu_test
from .backend import TestBackend from .backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module): class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
def forward(self, hidden_states, residual): def forward(self, x):
view = hidden_states.reshape(-1, self.hidden_size) # avoid having graph input be an arg to a pattern directly
all_reduce = tensor_model_parallel_all_reduce(view) z = torch.relu(x)
norm = self.norm(all_reduce) x = resid = tensor_model_parallel_all_reduce(z)
return norm y = self.norm[0](x)
def ops_in_model_before(self): z2 = torch.mm(y, self.w[0])
return [torch.ops.vllm.all_reduce.default] x2 = tensor_model_parallel_all_reduce(z2)
def ops_in_model_after(self): y2, resid = self.norm[1](x2, resid)
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): y3, resid = self.norm[2](x3, resid)
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): z4 = torch.mm(y3, self.w[2])
super().__init__() x4 = tensor_model_parallel_all_reduce(z4)
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
def forward(self, hidden_states, residual): y4, resid = self.norm[3](x4, resid)
view = hidden_states.reshape(-1, self.hidden_size) return y4
all_reduce = tensor_model_parallel_all_reduce(view)
norm, _ = self.norm(all_reduce, residual)
return norm
def ops_in_model_before(self): def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default] return [torch.ops.vllm.all_reduce.default]
...@@ -67,27 +74,53 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): ...@@ -67,27 +74,53 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.quant_fp8 = QuantFP8(static=True, self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape=GroupShape.PER_TENSOR) self.w = [
self.scale = torch.rand(1, dtype=torch.float32) torch.rand(hidden_size, hidden_size)
self.output = torch.empty((token_num, hidden_size), .to(dtype=current_platform.fp8_dtype())
dtype=torch.float32) .t()
for _ in range(3)
def forward(self, hidden_states, residual): ]
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view) self.fp8_linear = Fp8LinearOp(
norm_output, residual_output = self.norm(all_reduce, residual) act_quant_static=True,
torch.ops._C.static_scaled_fp8_quant(self.output, act_quant_group_shape=GroupShape.PER_TENSOR,
norm_output.contiguous(), )
self.scale)
return self.output, residual_output self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self): def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
...@@ -96,35 +129,58 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): ...@@ -96,35 +129,58 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
return [ return [
torch.ops.vllm.all_reduce.default, torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
else torch.ops.aten.reciprocal.default,
] ]
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = RMSNorm(hidden_size, eps) self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
dtype=torch.float32) self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
round_up = lambda x, y: (x + y - 1) // y * y self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16 wq_gen, wscale_gen = zip(
rounded_n = round_up(scale_n, 4) *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
self.output_scale = torch.empty((rounded_m, rounded_n // 4), )
dtype=torch.int32) self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size) def forward(self, hidden_states):
all_reduce = tensor_model_parallel_all_reduce(view) # avoid having graph input be an arg to a pattern directly
norm_output, residual_output = self.norm(all_reduce, residual) z = torch.relu(hidden_states)
norm_output = norm_output.reshape(-1, norm_output.shape[-1]) x = resid = tensor_model_parallel_all_reduce(z)
torch.ops._C.scaled_fp4_quant(self.output, norm_output, y = self.norm[0](x)
self.output_scale, self.scale)
return self.output, residual_output, self.output_scale yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
z2 = cutlass_scaled_fp4_mm(
yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
z3 = cutlass_scaled_fp4_mm(
yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
z4 = cutlass_scaled_fp4_mm(
yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self): def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
...@@ -132,54 +188,81 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): ...@@ -132,54 +188,81 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def ops_in_model_before(self): def ops_in_model_before(self):
return [ return [
torch.ops.vllm.all_reduce.default, torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default torch.ops._C.scaled_fp4_quant.default,
] ]
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test_model", "test_model, enable_quant_fp8_custom_op",
[ [
TestAllReduceRMSNormModel, (TestAllReduceRMSNormModel, False),
TestAllReduceFusedAddRMSNormModel, (TestAllReduceRMSNormStaticQuantFP8Model, True),
TestAllReduceFusedAddRMSNormStaticQuantFP8Model, (TestAllReduceRMSNormStaticQuantFP8Model, False),
# TODO: Enable with torch==2.8.0 (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model, ],
]) )
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif( @pytest.mark.skipif(
not find_spec("flashinfer") not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
reason="flashinfer is not found or flashinfer " reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion") "is not compiled with trtllm_allreduce_fusion",
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, )
batch_size: int, seq_len: int, def test_all_reduce_fusion_pass_replace(
hidden_size: int, dtype: torch.dtype): test_model: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
num_processes = 2 num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model if (
and not current_platform.has_device_capability(100)): test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
pytest.skip("Skip as nvfp4 is only supported on " and not current_platform.has_device_capability(100)
"devices with compute capability 10.0 (Blackwell)") ):
pytest.skip(
"Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)"
)
def run_torch_spawn(fn, nprocs): def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn, torch.multiprocessing.spawn(
args=(num_processes, test_model, fn,
batch_size, seq_len, hidden_size, args=(
dtype), num_processes,
nprocs=nprocs) test_model,
batch_size,
seq_len,
hidden_size,
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
),
nprocs=nprocs,
)
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes) run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, def all_reduce_fusion_pass_on_test_model(
test_model_cls: torch.nn.Module, local_rank: int,
batch_size: int, seq_len: int, world_size: int,
hidden_size: int, dtype: torch.dtype): test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
current_platform.seed_everything(0) current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}") device = torch.device(f"cuda:{local_rank}")
...@@ -187,50 +270,63 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, ...@@ -187,50 +270,63 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
update_environment_variables({ update_environment_variables(
'RANK': str(local_rank), {
'LOCAL_RANK': str(local_rank), "RANK": str(local_rank),
'WORLD_SIZE': str(world_size), "LOCAL_RANK": str(local_rank),
'MASTER_ADDR': 'localhost', "WORLD_SIZE": str(world_size),
'MASTER_PORT': '12345', "MASTER_ADDR": "localhost",
}) "MASTER_PORT": "12345",
}
)
init_distributed_environment() init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
vllm_config = VllmConfig(compilation_config=CompilationConfig( custom_ops = []
level=CompilationLevel.PIECEWISE, if enable_rms_norm_custom_op:
custom_ops=["+rms_norm", "+quant_fp8"])) custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
)
)
vllm_config.compilation_config.pass_config = PassConfig( vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True) enable_fi_allreduce_fusion=True, enable_noop=True
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
# in the vllm_config, it's not really used. # in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig(model=model_name, vllm_config.model_config = ModelConfig(
trust_remote_code=True, model=model_name, trust_remote_code=True, dtype=dtype, seed=42
dtype=dtype, )
seed=42) with set_current_vllm_config(vllm_config):
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
cleanup_pass) )
token_num = batch_size * seq_len token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num) model = test_model_cls(hidden_size, token_num)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend)
compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states)
compiled_model(hidden_states, residual)
assert all_reduce_fusion_pass.matched_count == 4, (
assert all_reduce_fusion_pass.matched_count == 1 f"{all_reduce_fusion_pass.matched_count=}"
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) )
backend.check_after_ops(model.ops_in_model_after()) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
del all_reduce_fusion_pass backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment