Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import logging
import tempfile
from typing import Any, Optional, Union
from pathlib import Path
from typing import Any
import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
def models_list(*, all: bool = True, keywords: list[str] | None = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
}),
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
"dtype": torch.float16,
}),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
(
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
{"dtype": torch.float16},
),
("meta-llama/Llama-3.2-1B-Instruct", {}),
]
if all:
TEST_MODELS.extend(
[
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{"dtype": torch.float16},
),
]
)
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
}))
TEST_MODELS.append(
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"})
)
if is_quant_method_supported("gptq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {
"quantization": "gptq"
}))
TEST_MODELS.append(
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"})
)
if is_quant_method_supported("gptq_marlin"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", {
"quantization": "gptq_marlin"
}))
TEST_MODELS.append(
(
"TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
{"quantization": "gptq_marlin"},
)
)
if is_quant_method_supported("gptq_marlin_24"):
TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", {
"quantization": "gptq_marlin_24"
}))
TEST_MODELS.append(
(
"alexm-nm/tinyllama-24-marlin24-4bit-g128",
{"quantization": "gptq_marlin_24"},
)
)
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
"quantization": "AWQ"
}))
TEST_MODELS.append(
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"})
)
if keywords is None:
return TEST_MODELS
......@@ -72,110 +80,145 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
@pytest.mark.parametrize(
"optimization_level",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
"compilation_mode",
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
monkeypatch: pytest.MonkeyPatch,
model_info: tuple[str, dict[str, Any]],
optimization_level: int,
model: str,
model_kwargs: dict[str, Any],
compilation_mode: int,
):
model, model_kwargs = model_info
if (
"w8a8" in model
or "w8w8" in model
and current_platform.has_device_capability((10, 0))
):
# int8 removed on Blackwell:
pytest.skip("int8 support removed on Blackwell")
with monkeypatch.context():
print(f"MODEL={model}")
run_model(optimization_level, model, model_kwargs)
run_model(compilation_mode, model, **model_kwargs)
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
"compilation_config, model_info",
"compilation_config, model, model_kwargs",
[
# additional compile sizes, only some of the models
(CompilationConfig(level=CompilationLevel.PIECEWISE,
compile_sizes=[1, 2]), model)
for model in models_list(all=False)
] + [
(
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
*model_info,
)
for model_info in models_list(all=False)
]
+ [
# RMSNorm + quant fusion, only 8-bit quant models
(CompilationConfig(level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True,
enable_noop=True)), model)
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
] + [
(
CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
*model_info,
)
for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
]
+ [
# Test depyf integration works
(CompilationConfig(level=CompilationLevel.PIECEWISE,
debug_dump_path=tempfile.gettempdir()),
("facebook/opt-125m", {})),
] + [
(
CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
debug_dump_path=Path(tempfile.gettempdir()),
),
"facebook/opt-125m",
{},
),
]
+ [
# graph inductor partition
(
CompilationConfig(
level=CompilationLevel.PIECEWISE,
mode=CompilationMode.VLLM_COMPILE,
# inductor graph partition uses
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2]),
model) for model in models_list(all=False)
compile_sizes=[1, 2],
),
*model_info,
)
for model_info in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
])
],
)
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]],
model: str,
model_kwargs: dict[str, Any],
):
if (compilation_config.use_inductor_graph_partition
and not is_torch_equal_or_newer("2.9.0.dev")):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
if (
"w8a8" in model
or "w8w8" in model
and current_platform.has_device_capability((10, 0))
):
# int8 removed on Blackwell:
pytest.skip("int8 support removed on Blackwell")
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model, model_kwargs = model_info
print(f"MODEL={model}")
run_model(compilation_config, model, model_kwargs)
run_model(compilation_config, model, **model_kwargs)
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
@pytest.mark.parametrize(
"compilation_mode",
[CompilationMode.NONE, CompilationMode.VLLM_COMPILE],
)
@pytest.mark.parametrize(
"model, backend",
[
("Qwen/Qwen2-0.5B", None), # Standard attention model
(
"deepseek-ai/DeepSeek-V2-Lite",
AttentionBackendEnum.FLASHINFER_MLA,
), # MLA (Multi-head Latent Attention) model
],
)
def test_fp8_kv_scale_compile(
monkeypatch: pytest.MonkeyPatch,
compilation_mode: int,
model: str,
backend: AttentionBackendEnum | None,
):
if backend:
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
)
model_kwargs = {
"kv_cache_dtype": "fp8",
"max_model_len": 1024,
"quantization": "fp8",
"kv_cache_dtype": "fp8_e4m3",
"calculate_kv_scales": True,
"max_model_len": 512,
}
with caplog_vllm.at_level(
logging.DEBUG), global_force_attn_backend_context_manager(
_Backend.FLASHINFER):
run_model(compilation_config, model, model_kwargs)
try:
assert ("Fused quantization onto 48 attention nodes"
in caplog_vllm.text), caplog_vllm.text
except AssertionError:
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
# PyTorch side, the compilation of a graph may be cached such
# that custom pass directly goes through cache. In this case,
# we go through this branch and assert that the pass is not
# triggered.
assert "Fused quantization" not in caplog_vllm.text
def run_model(compile_config: Union[int, CompilationConfig], model: str,
model_kwargs: dict[str, Any]):
run_model(compilation_mode, model, **model_kwargs)
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
compilation_config = (
compile_config
if isinstance(compile_config, CompilationConfig)
else CompilationConfig(mode=compile_config)
)
prompts = [
"Hello, my name is",
"The president of the United States is",
......@@ -183,12 +226,17 @@ def run_model(compile_config: Union[int, CompilationConfig], model: str,
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
# Allow override from model_kwargs
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
# No cudagraphs by default
if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM(
model=model,
enforce_eager=True,
tensor_parallel_size=1,
disable_custom_all_reduce=True,
compilation_config=compile_config,
compilation_config=compilation_config,
**model_kwargs,
)
outputs = llm.generate(prompts, sampling_params)
......
......@@ -5,114 +5,262 @@ import pytest
import torch
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.config import (
CompilationConfig,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from .backend import TestBackend
OPS_IN_MODEL = [
torch.ops._C.rotary_embedding.default,
torch.ops._C.fused_add_rms_norm.default,
]
TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module):
def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
else:
return y
def example_inputs(self, num_tokens=32, hidden_size=128):
return (torch.rand(num_tokens, hidden_size * 2),)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
return [torch.ops._C.silu_and_mul_quant.default]
else:
return [torch.ops._C.silu_and_mul.default]
def ops_not_in_model(self):
return []
class TestFusedAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
torch.nn.init.normal_(self.gate_proj, std=0.02)
if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# layer normalization
norm_output, residual_output = self.norm(mm, residual)
if TEST_FP8:
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
return fp8_linear_result, residual_output
else:
return norm_output, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
residual = torch.randn((batch_size * seq_len, hidden_size))
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
else:
return [torch.ops._C.fused_add_rms_norm.default]
def ops_not_in_model(self):
return []
RMS_OP = torch.ops._C.rms_norm.default
RMS_QUANT_OPS = {
"static_fp8": [
torch.ops._C.rms_norm_static_fp8_quant.default,
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
],
}
class TestRotaryEmbedding(torch.nn.Module):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position,
base=base,
)
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
def forward(self, positions, q, k):
q_rotated, k_rotated = self.rotary_emb(positions, q, k)
return q_rotated, k_rotated
def example_inputs(self, num_tokens=32, head_dim=64):
positions = torch.arange(num_tokens, dtype=torch.long)
q = torch.randn(num_tokens, head_dim)
k = torch.randn(num_tokens, head_dim)
return (positions, q, k)
def ops_in_model(self, do_fusion):
return [torch.ops._C.rotary_embedding.default]
def ops_not_in_model(self):
return []
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.num_heads = num_heads
self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=base,
)
def forward(self, positions, hidden_states):
# Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
# -> slice_scatter -> split_with_sizes
qkv = self.qkv_proj(hidden_states)
split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
q, k, v = torch.split(qkv, split_sizes, dim=-1)
q_rotated, k_rotated = self.rotary_emb(positions, q, k)
qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
return qkv_updated
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
hidden_size = head_dim * num_heads
positions = torch.arange(num_tokens, dtype=torch.long)
hidden_states = torch.randn(num_tokens, hidden_size)
return (positions, hidden_states)
def ops_in_model(self, do_fusion):
return [torch.ops._C.rotary_embedding.default]
def ops_not_in_model(self):
return [torch.ops.aten.slice_scatter.default]
MODELS = [
TestSiluMul,
TestFusedAddRMSNorm,
TestRotaryEmbedding,
TestRotaryEmbeddingSliceScatter,
]
@pytest.mark.parametrize(
"model, quant_key",
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
kFp8DynamicTokenSym)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fix_functionalization(model: str, quant_key: QuantKey,
do_fusion: bool):
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
),
)
with set_current_vllm_config(vllm_config):
assert RMSNorm.enabled()
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
] if do_fusion else [noop_pass, cleanup_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
# instantiate a full engine and manually compile the model 2x
# (with and without FixFunctionalizationPass)
llm = LLM(model=model, enforce_eager=True)
model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
orig_model = model_runner.model
# TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
# Can only do that by using the decorator but then we'd have to instantiate
# 2 LLM instances.
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_func)
gen_func = llm.generate(prompts, sampling_params)
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_no_func)
gen_no_func = llm.generate(prompts, sampling_params)
for output_func, output_no_func in zip(gen_func, gen_no_func):
assert output_func.outputs[0].text == output_no_func.outputs[0].text
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
quant_key == kFp8StaticTensorSym else [
SILU_MUL_OP
]
ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in ops:
if is_func(node, op):
found[op] = True
assert all(found[op] for op in ops)
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())
......@@ -5,17 +5,29 @@ import pytest
import torch
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
RMSNormQuantFusionPass)
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.config import (
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc)
GroupShape,
QuantKey,
ScaleDesc,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported
......@@ -23,25 +35,34 @@ from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, static: bool,
cuda_force_torch: bool, *args, **kwargs):
class TestModel(torch.nn.Module):
def __init__(
self,
hidden_size: int,
eps: float,
static: bool,
cuda_force_torch: bool,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else:
self.scale = [None for _ in range(2)]
self.scale = [None for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
for _ in range(3)
]
with override_cutlass_fp8_supported(not cuda_force_torch):
......@@ -50,57 +71,97 @@ class TestModel(torch.nn.Module):
act_quant_group_shape=group_shape,
)
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x):
resid = torch.sqrt(x)
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = self.norm[0](x)
x2 = self.fp8_linear.apply(y,
self.w[0],
self.wscale[0],
input_scale=self.scale[0])
x2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply(y2,
self.w[1],
self.wscale[1],
input_scale=self.scale[1])
x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
def ops_in_model_before(self):
return [QUANT_OPS[self.key]]
x4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)]
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cuda_force_torch):
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
dtype,
hidden_size,
num_tokens,
eps,
static,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
))
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
......@@ -108,31 +169,39 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert fusion_pass.matched_count == 3
backend.check_before_ops(model.ops_in_model_before())
# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_before_ops(
model.ops_in_model_before_partial(), fully_replaced=False
)
backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2
......@@ -6,59 +6,66 @@ import pytest
import torch
import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.config import (
CompilationConfig,
CompilationMode,
DeviceConfig,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape, QuantFP8)
Fp8LinearOp,
GroupShape,
)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
from vllm.utils.system_utils import update_environment_variables
from ..utils import has_module_attribute, multi_gpu_test
from .backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm = self.norm(all_reduce)
return norm
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(x)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
z2 = torch.mm(y, self.w[0])
x2 = tensor_model_parallel_all_reduce(z2)
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
y2, resid = self.norm[1](x2, resid)
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
y3, resid = self.norm[2](x3, resid)
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
z4 = torch.mm(y3, self.w[2])
x4 = tensor_model_parallel_all_reduce(z4)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm, _ = self.norm(all_reduce, residual)
return norm
y4, resid = self.norm[3](x4, resid)
return y4
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
......@@ -67,27 +74,53 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output,
norm_output.contiguous(),
self.scale)
return self.output, residual_output
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
......@@ -96,35 +129,58 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
else torch.ops.aten.reciprocal.default,
]
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
dtype=torch.int32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
self.output_scale, self.scale)
return self.output, residual_output, self.output_scale
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]
wq_gen, wscale_gen = zip(
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
z2 = cutlass_scaled_fp4_mm(
yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
z3 = cutlass_scaled_fp4_mm(
yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
z4 = cutlass_scaled_fp4_mm(
yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
......@@ -132,54 +188,81 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default
torch.ops._C.scaled_fp4_quant.default,
]
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model",
"test_model, enable_quant_fp8_custom_op",
[
TestAllReduceRMSNormModel,
TestAllReduceFusedAddRMSNormModel,
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
# TODO: Enable with torch==2.8.0
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
(TestAllReduceRMSNormModel, False),
(TestAllReduceRMSNormStaticQuantFP8Model, True),
(TestAllReduceRMSNormStaticQuantFP8Model, False),
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
"is not compiled with trtllm_allreduce_fusion",
)
def test_all_reduce_fusion_pass_replace(
test_model: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)):
pytest.skip("Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)")
if (
test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)
):
pytest.skip(
"Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)"
)
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model,
batch_size, seq_len, hidden_size,
dtype),
nprocs=nprocs)
torch.multiprocessing.spawn(
fn,
args=(
num_processes,
test_model,
batch_size,
seq_len,
hidden_size,
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
),
nprocs=nprocs,
)
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
def all_reduce_fusion_pass_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
current_platform.seed_everything(0)
device = torch.device(f"cuda:{local_rank}")
......@@ -187,50 +270,63 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"]))
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
)
)
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True)
enable_fi_allreduce_fusion=True, enable_noop=True
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name,
trust_remote_code=True,
dtype=dtype,
seed=42)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
cleanup_pass)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)
assert all_reduce_fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
with set_current_vllm_config(vllm_config):
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(
noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
assert all_reduce_fusion_pass.matched_count == 4, (
f"{all_reduce_fusion_pass.matched_count=}"
)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment