Unverified commit f36292db, authored by Angela Yi and committed by GitHub
Browse files

[compile] Enable sequence parallelism matching w/o custom ops enabled (#27126)


Signed-off-by: angelayi <yiangela7@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ProExpertProg <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
parent 173b356a
...@@ -478,10 +478,11 @@ steps: ...@@ -478,10 +478,11 @@ steps:
- vllm/ - vllm/
- tests/compile - tests/compile
commands: commands:
# fp8 kv scales not supported on sm89, tested on Blackwell instead
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile' - pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
# Limit to no custom ops to reduce running time # Limit to no custom ops to reduce running time
# Wrap with quotes to escape yaml and avoid starting -k string with a - # Wrap with quotes to escape yaml and avoid starting -k string with a -
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'" - "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
- label: Cudagraph test - label: Cudagraph test
timeout_in_minutes: 20 timeout_in_minutes: 20
...@@ -925,7 +926,7 @@ steps: ...@@ -925,7 +926,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_flashinfer.py
- label: Blackwell Fusion Tests # 30 min - label: Blackwell Fusion & Compile Tests # 30 min
timeout_in_minutes: 40 timeout_in_minutes: 40
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
gpu: b200 gpu: b200
...@@ -946,7 +947,9 @@ steps: ...@@ -946,7 +947,9 @@ steps:
- pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_all_reduce.py
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
# Wrap with quotes to escape yaml # Wrap with quotes to escape yaml
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'" - "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell Fusion E2E Tests # 30 min - label: Blackwell Fusion E2E Tests # 30 min
timeout_in_minutes: 40 timeout_in_minutes: 40
...@@ -969,8 +972,6 @@ steps: ...@@ -969,8 +972,6 @@ steps:
- nvidia-smi - nvidia-smi
# Run all e2e fusion tests # Run all e2e fusion tests
- pytest -v -s tests/compile/test_fusions_e2e.py - pytest -v -s tests/compile/test_fusions_e2e.py
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
- label: Blackwell GPT-OSS Eval - label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60 timeout_in_minutes: 60
...@@ -1266,7 +1267,8 @@ steps: ...@@ -1266,7 +1267,8 @@ steps:
- pytest -v -s tests/compile/test_async_tp.py - pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py - pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
- pytest -v -s tests/distributed/test_sequence_parallel.py
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
- pytest -v -s tests/v1/distributed/test_dbo.py - pytest -v -s tests/v1/distributed/test_dbo.py
......
...@@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer ...@@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import flat_product, multi_gpu_test from ..utils import flat_product, multi_gpu_test
def is_blackwell() -> bool:
    """Return whether we are running on Blackwell (device capability 100).

    A lot of tests in this file skip or adjust their expectations based on
    this check. The platform is queried on every call; nothing is cached.
    """
    return current_platform.is_device_capability(100)
class Matches(NamedTuple):
    """Expected per-pass match counts that the tests compare against log output.

    Every field defaults to 0, so a test case only spells out the counts it
    cares about (e.g. ``Matches(attention_fusion=32)``).
    """

    # Expected N in the fusion_attn.py log "Fused quant onto N attention nodes".
    attention_fusion: int = 0
    # Expected N in the collective_fusion.py log "Replaced N patterns",
    # as asserted by test_tp2_attn_quant_allreduce_rmsnorm.
    allreduce_fusion: int = 0
    # Expected N in the sequence_parallelism.py log "Replaced N patterns".
    sequence_parallel: int = 0
    # Expected N in the collective_fusion.py log "Replaced N patterns",
    # as asserted by test_tp2_attn_quant_async_tp.
    async_tp: int = 0
class ModelBackendTestCase(NamedTuple): class ModelBackendTestCase(NamedTuple):
model_name: str model_name: str
model_kwargs: dict[str, Any] model_kwargs: dict[str, Any]
backend: AttentionBackendEnum backend: AttentionBackendEnum
attention_fusions: int matches: Matches
allreduce_fusions: int | None = None
MODELS_FP8: list[ModelBackendTestCase] = [] MODELS_FP8: list[ModelBackendTestCase] = []
...@@ -38,17 +47,33 @@ if current_platform.is_cuda(): ...@@ -38,17 +47,33 @@ if current_platform.is_cuda():
ModelBackendTestCase( ModelBackendTestCase(
# Use smaller model for L40s in CI # Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024), # TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
backend=AttentionBackendEnum.TRITON_ATTN, # so FI attention+fp8_quant is at least tested once
attention_fusions=32, model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
allreduce_fusions=65, backend=AttentionBackendEnum.FLASHINFER
if is_blackwell()
else AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER, # TODO FlashInfer attn broken on Hopper with kvcache=fp8:
attention_fusions=48, # https://github.com/vllm-project/vllm/issues/28568
allreduce_fusions=96, # TODO FlashInfer attn broken on Blackwell for llama4:
# https://github.com/vllm-project/vllm/issues/28604
backend=AttentionBackendEnum.TRITON_ATTN,
matches=Matches(
attention_fusion=48,
allreduce_fusion=96,
sequence_parallel=96,
async_tp=95, # mlp is moe, no fusion there
),
), ),
] ]
...@@ -57,8 +82,12 @@ if current_platform.is_cuda(): ...@@ -57,8 +82,12 @@ if current_platform.is_cuda():
model_name="nvidia/Llama-3.1-8B-Instruct-FP4", model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=AttentionBackendEnum.FLASHINFER, backend=AttentionBackendEnum.FLASHINFER,
attention_fusions=32, matches=Matches(
allreduce_fusions=65, attention_fusion=32,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
), ),
] ]
...@@ -68,15 +97,23 @@ if current_platform.is_cuda(): ...@@ -68,15 +97,23 @@ if current_platform.is_cuda():
model_name="meta-llama/Llama-3.1-8B-Instruct", model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0, matches=Matches(
allreduce_fusions=65, attention_fusion=0,
allreduce_fusion=65,
sequence_parallel=65,
async_tp=128,
),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="Qwen/Qwen3-30B-A3B", model_name="Qwen/Qwen3-30B-A3B",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=0, matches=Matches(
allreduce_fusions=97, attention_fusion=0,
allreduce_fusion=97,
sequence_parallel=97,
async_tp=96, # MLP is MoE, half the fusions of dense
),
), ),
] ]
...@@ -86,19 +123,19 @@ elif current_platform.is_rocm(): ...@@ -86,19 +123,19 @@ elif current_platform.is_rocm():
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.TRITON_ATTN, backend=AttentionBackendEnum.TRITON_ATTN,
attention_fusions=32, matches=Matches(attention_fusion=32),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_ATTN, backend=AttentionBackendEnum.ROCM_ATTN,
attention_fusions=32, matches=Matches(attention_fusion=32),
), ),
ModelBackendTestCase( ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024), model_kwargs=dict(max_model_len=1024),
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32, matches=Matches(attention_fusion=32),
), ),
] ]
...@@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"] ...@@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, model_kwargs, backend, " "model_name, model_kwargs, backend, matches, custom_ops",
"attention_fusions, allreduce_fusions, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl # quant_fp4 only has the custom impl
...@@ -118,15 +154,14 @@ def test_attn_quant( ...@@ -118,15 +154,14 @@ def test_attn_quant(
model_name: str, model_name: str,
model_kwargs: dict[str, Any], model_kwargs: dict[str, Any],
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
attention_fusions: int, matches: Matches,
allreduce_fusions: int,
custom_ops: str, custom_ops: str,
inductor_graph_partition: bool, inductor_graph_partition: bool,
caplog_mp_spawn, caplog_mp_spawn,
monkeypatch, monkeypatch,
): ):
if backend == AttentionBackendEnum.FLASHINFER and ( if backend == AttentionBackendEnum.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer() not is_blackwell() or not has_flashinfer()
): ):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
...@@ -169,12 +204,12 @@ def test_attn_quant( ...@@ -169,12 +204,12 @@ def test_attn_quant(
with caplog_mp_spawn(logging.DEBUG) as log_holder: with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs) run_model(compilation_config, model_name, **model_kwargs)
matches = re.findall( log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text, log_holder.text,
) )
assert len(matches) == 1, log_holder.text assert len(log_matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions assert int(log_matches[0]) == matches.attention_fusion
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"] CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
...@@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: ...@@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name, model_kwargs, backend, " "model_name, model_kwargs, backend, matches, custom_ops",
"attention_fusions, allreduce_fusions, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models # Toggle RMSNorm and QuantFP8 for FP8 models
list( list(
flat_product( flat_product(
...@@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ...@@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str, model_name: str,
model_kwargs: dict, model_kwargs: dict,
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
attention_fusions: int, matches: Matches,
allreduce_fusions: int,
custom_ops: str, custom_ops: str,
inductor_graph_partition: bool, inductor_graph_partition: bool,
caplog_mp_spawn, caplog_mp_spawn,
...@@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ...@@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9") pytest.skip("Inductor graph partition requires torch>=2.9")
if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")
if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)
custom_ops_list = custom_ops.split(",") if custom_ops else [] custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition: if inductor_graph_partition:
...@@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm( ...@@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
run_model( run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
) )
matches = re.findall( log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion
log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.allreduce_fusion
assert int(log_matches[1]) == matches.allreduce_fusion
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, matches, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
)
# Toggle RMSNorm for FP4 models and unquant models
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="sequence parallel only tested on CUDA",
)
def test_tp2_attn_quant_async_tp(
model_name: str,
model_kwargs: dict,
backend: AttentionBackendEnum,
matches: Matches,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if is_blackwell():
# TODO: https://github.com/vllm-project/vllm/issues/27893
pytest.skip("Blackwell is not supported for AsyncTP pass")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
if "fp4" in model_name.lower() and not is_blackwell():
pytest.skip("NVFP4 quant requires Blackwell")
if backend == AttentionBackendEnum.FLASHINFER:
if not has_flashinfer():
pytest.skip("FlashInfer backend requires flashinfer installed")
if not is_blackwell():
# FlashInfer attn fusion requires Blackwell
matches = matches._replace(attention_fusion=0)
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
custom_ops=custom_ops_list,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_sequence_parallelism=True,
enable_async_tp=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
log_matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text, log_holder.text,
) )
assert len(matches) == 2, log_holder.text assert len(log_matches) == 2, log_holder.text
assert int(log_matches[0]) == matches.attention_fusion
assert int(log_matches[1]) == matches.attention_fusion
log_matches = re.findall(
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == attention_fusions assert int(log_matches[0]) == matches.sequence_parallel
assert int(matches[1]) == attention_fusions assert int(log_matches[1]) == matches.sequence_parallel
matches = re.findall( log_matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns", r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text, log_holder.text,
) )
assert len(matches) == 2, log_holder.text assert len(log_matches) == 2, log_holder.text
assert int(matches[0]) == allreduce_fusions assert int(log_matches[0]) == matches.async_tp
assert int(matches[1]) == allreduce_fusions assert int(log_matches[1]) == matches.async_tp
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
......
...@@ -5,15 +5,15 @@ import pytest ...@@ -5,15 +5,15 @@ import pytest
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
CUDAGraphMode,
DeviceConfig, DeviceConfig,
ModelConfig, ModelConfig,
PassConfig, PassConfig,
...@@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import ( ...@@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
...@@ -43,172 +44,157 @@ prompts = [ ...@@ -43,172 +44,157 @@ prompts = [
] ]
class TestModel(torch.nn.Module): class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.intermediate_size = intermediate_size self.eps = eps
self.gate_proj = torch.nn.Parameter( self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
torch.empty((intermediate_size, hidden_size)) self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
)
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states, residual): def forward(self, x):
""" z = torch.relu(x)
Forward pass implementing the operations in the FX graph x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
Args: z2 = torch.mm(y, self.w[0])
hidden_states: Input tensor x2 = tensor_model_parallel_all_reduce(z2)
residual: Residual tensor from previous layer
Returns: y2, resid = self.norm[1](x2, resid)
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# matrix multiplication z3 = torch.mm(y2, self.w[1])
permute = self.gate_proj.permute(1, 0) x3 = tensor_model_parallel_all_reduce(z3)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce y3, resid = self.norm[2](x3, resid)
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization z4 = torch.mm(y3, self.w[2])
norm_output, residual_output = self.norm(all_reduce, residual) x4 = tensor_model_parallel_all_reduce(z4)
return norm_output, residual_output y4, resid = self.norm[3](x4, resid)
return y4
def ops_in_model_before(self): def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default] return [torch.ops.vllm.all_reduce.default]
def ops_in_model_after(self): def ops_in_model_after(self):
return [ return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default, torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
] ]
def ops_in_model(self): def ops_in_model(self):
return [torch.ops._C.fused_add_rms_norm.default] if RMSNorm.enabled():
return [
torch.ops._C.rms_norm.default,
torch.ops._C.fused_add_rms_norm.default,
]
else:
return []
class TestQuantModel(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = get_current_vllm_config() self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter( self.hidden_size = hidden_size
torch.empty((intermediate_size, hidden_size)), requires_grad=False self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
) )
self.norm = RMSNorm(intermediate_size, 1e-05)
# Initialize weights self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
torch.nn.init.normal_(self.gate_proj, std=0.02)
def forward(self, hidden_states):
self.fp8_linear = Fp8LinearOp(act_quant_static=True) # avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
self.scale = torch.rand(1, dtype=torch.float32) x = resid = tensor_model_parallel_all_reduce(z)
# Create a weight that is compatible with torch._scaled_mm, y = self.norm[0](x)
# which expects a column-major layout.
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() z2 = self.fp8_linear.apply(
self.wscale = torch.rand(1, dtype=torch.float32) y, self.w[0], self.wscale[0], input_scale=self.scale[0]
def forward(self, hidden_states, residual):
"""
Forward pass implementing the operations in the FX graph
Args:
hidden_states: Input tensor
residual: Residual tensor from previous layer
Returns:
Tuple containing the output tensor
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
# Tensor parallel all-reduce
all_reduce = tensor_model_parallel_all_reduce(mm)
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
) )
return fp8_linear_result, residual_output x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
def ops_in_model_before(self): z3 = self.fp8_linear.apply(
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
# The following are only removed if fusion happens )
if (
self.vllm_config x3 = tensor_model_parallel_all_reduce(z3)
and self.vllm_config.compilation_config.pass_config.enable_fusion y3, resid = self.norm[2](x3, resid) # use resid here
):
ops_to_remove.extend( z4 = self.fp8_linear.apply(
[ y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
torch.ops._C.fused_add_rms_norm.default, )
torch.ops._C.static_scaled_fp8_quant.default, x4 = tensor_model_parallel_all_reduce(z4)
] y4, resid = self.norm[3](x4, resid) # use resid here
) return y4
return ops_to_remove
def ops_in_model_after(self): def ops_in_model_after(self):
ops_to_add = [ return [
torch.ops.vllm.reduce_scatter.default,
torch.ops.vllm.all_gather.default, torch.ops.vllm.all_gather.default,
torch.ops.vllm.reduce_scatter.default,
]
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
] ]
# The following is only added if fusion happens
if (
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
return ops_to_add
def ops_in_model(self): def ops_in_model(self):
if ( if self.vllm_config.compilation_config.pass_config.enable_fusion:
self.vllm_config
and self.vllm_config.compilation_config.pass_config.enable_fusion
):
# If fusion happens, the fused op is the one
# we check for (de)functionalization
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
else: elif RMSNorm.enabled():
# If no fusion, the original ops are checked
return [ return [
torch.ops._C.fused_add_rms_norm.default, torch.ops._C.fused_add_rms_norm.default,
# TODO functionalization pass does not handle this yet
# torch.ops._C.static_scaled_fp8_quant.default,
] ]
elif self.fp8_linear.quant_fp8.enabled():
return [
torch.ops._C.static_scaled_fp8_quant.default,
]
else:
return []
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel]) @pytest.mark.parametrize(
"test_model_cls, custom_ops",
[
(TestAllReduceRMSNormModel, "+rms_norm"),
(TestAllReduceRMSNormModel, "-rms_norm"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
],
)
@pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("enable_fusion", [True, False]) @pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.parametrize("dynamic", [False, True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_sequence_parallelism_pass( def test_sequence_parallelism_pass(
test_model_cls: type[torch.nn.Module], test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
enable_fusion: bool, enable_fusion: bool,
dynamic: bool,
): ):
num_processes = 2 num_processes = 2
...@@ -220,11 +206,13 @@ def test_sequence_parallelism_pass( ...@@ -220,11 +206,13 @@ def test_sequence_parallelism_pass(
args=( args=(
num_processes, num_processes,
test_model_cls, test_model_cls,
custom_ops,
batch_size, batch_size,
seq_len, seq_len,
hidden_size, hidden_size,
dtype, dtype,
enable_fusion, enable_fusion,
dynamic,
), ),
nprocs=nprocs, nprocs=nprocs,
) )
...@@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model( ...@@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model(
local_rank: int, local_rank: int,
world_size: int, world_size: int,
test_model_cls: type[torch.nn.Module], test_model_cls: type[torch.nn.Module],
custom_ops: str,
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
hidden_size: int, hidden_size: int,
dtype: torch.dtype, dtype: torch.dtype,
enable_fusion: bool, enable_fusion: bool,
dynamic: bool,
): ):
current_platform.seed_everything(0) current_platform.seed_everything(0)
...@@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model( ...@@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size) initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
custom_ops_list = custom_ops.split(",") if custom_ops else []
compilation_config = CompilationConfig( compilation_config = CompilationConfig(
splitting_ops=[], # avoid automatic rms_norm enablement
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
custom_ops=custom_ops_list,
pass_config=PassConfig( pass_config=PassConfig(
enable_sequence_parallelism=True, enable_sequence_parallelism=True,
enable_fusion=enable_fusion, enable_fusion=enable_fusion,
enable_noop=True, enable_noop=True,
) ),
) # NoOp needed for fusion ) # NoOp needed for fusion
device_config = DeviceConfig(device=torch.device("cuda")) device_config = DeviceConfig(device=torch.device("cuda"))
...@@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model( ...@@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model(
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)
assert ( assert (
sequence_parallelism_pass.compilation_config.splitting_ops sequence_parallelism_pass.compilation_config.splitting_ops
...@@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model( ...@@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model(
passes_for_backend.append(cleanup_pass) passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend) backend = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size, hidden_size * 2) model = test_model_cls(hidden_size)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func) if dynamic:
compiled_model_no_func(hidden_states, residual) torch._dynamo.mark_dynamic(hidden_states, 0)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual) compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
assert sequence_parallelism_pass.matched_count == 1 assert sequence_parallelism_pass.matched_count == 4
# In pre-nodes, all reduce should be there, # In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not # reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before()) for op in model.ops_in_model_before():
assert backend.op_count(op, before=True) == 4
# In post-nodes, reduce scatter and all gather should be there, # In post-nodes, reduce scatter and all gather should be there,
# all reduce should not # all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after()) for op in model.ops_in_model_after():
assert backend.op_count(op, before=False) == 4
# check if the functionalization pass is applied
for op in model.ops_in_model(): for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op) find_auto_fn(backend.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())
...@@ -18,6 +18,7 @@ import pytest ...@@ -18,6 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption from vllm.config.model import RunnerOption
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
...@@ -161,6 +162,7 @@ def _compare_sp( ...@@ -161,6 +162,7 @@ def _compare_sp(
test_options: SPTestOptions, test_options: SPTestOptions,
num_gpus_available: int, num_gpus_available: int,
use_inductor_graph_partition: bool, use_inductor_graph_partition: bool,
enable_async_tp: bool,
*, *,
method: Literal["generate", "encode"], method: Literal["generate", "encode"],
is_multimodal: bool, is_multimodal: bool,
...@@ -244,10 +246,10 @@ def _compare_sp( ...@@ -244,10 +246,10 @@ def _compare_sp(
compilation_config = { compilation_config = {
"mode": CompilationMode.VLLM_COMPILE, "mode": CompilationMode.VLLM_COMPILE,
"custom_ops": ["+rms_norm"],
"compile_sizes": [4, 8], "compile_sizes": [4, 8],
"pass_config": { "pass_config": {
"enable_sequence_parallelism": True, "enable_sequence_parallelism": True,
"enable_async_tp": enable_async_tp,
"enable_fusion": enable_fusion, "enable_fusion": enable_fusion,
"enable_noop": True, "enable_noop": True,
}, },
...@@ -307,6 +309,7 @@ SP_TEST_MODELS = [ ...@@ -307,6 +309,7 @@ SP_TEST_MODELS = [
], ],
) )
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) @pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_tp_sp_generation( def test_tp_sp_generation(
model_id: str, model_id: str,
...@@ -316,10 +319,19 @@ def test_tp_sp_generation( ...@@ -316,10 +319,19 @@ def test_tp_sp_generation(
test_options: SPTestOptions, test_options: SPTestOptions,
num_gpus_available, num_gpus_available,
use_inductor_graph_partition: bool, use_inductor_graph_partition: bool,
enable_async_tp: bool,
): ):
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+") pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
# Skip FP8 SP-only test on sm89 (compute capability 8.9)
if (
"fp8" in model_id.lower()
and current_platform.get_device_capability() < (9, 0)
and (not enable_async_tp)
):
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
_compare_sp( _compare_sp(
model_id, model_id,
parallel_setup, parallel_setup,
...@@ -328,6 +340,7 @@ def test_tp_sp_generation( ...@@ -328,6 +340,7 @@ def test_tp_sp_generation(
test_options, test_options,
num_gpus_available, num_gpus_available,
use_inductor_graph_partition, use_inductor_graph_partition,
enable_async_tp=enable_async_tp,
method="generate", method="generate",
is_multimodal=False, is_multimodal=False,
) )
This diff is collapsed.
...@@ -445,8 +445,6 @@ class VllmConfig: ...@@ -445,8 +445,6 @@ class VllmConfig:
# and requires it to be enabled. # and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp: if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True self.compilation_config.pass_config.enable_sequence_parallelism = True
if self.compilation_config.pass_config.enable_sequence_parallelism:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.support_static_graph_mode(): if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default # if cudagraph_mode is not explicitly set by users, set default
...@@ -620,6 +618,32 @@ class VllmConfig: ...@@ -620,6 +618,32 @@ class VllmConfig:
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE: if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1() self.compilation_config.set_splitting_ops_for_v1()
if self.compilation_config.pass_config.enable_sequence_parallelism:
# With pipeline parallelism or dynamo partitioning,
# native rms norm tracing errors due to incorrect residual shape.
# Use custom rms norm to unblock. In the future,
# the pass will operate on higher-level IR to avoid the issue.
# TODO: https://github.com/vllm-project/vllm/issues/27894
is_fullgraph = (
self.compilation_config.use_inductor_graph_partition
or len(self.compilation_config.splitting_ops) == 0
)
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
if "-rms_norm" not in self.compilation_config.custom_ops:
self.compilation_config.custom_ops.append("+rms_norm")
else:
regime = (
"Dynamo partition"
if not is_fullgraph
else "pipeline parallelism"
)
logger.warning_once(
"Sequence parallelism not supported with "
"native rms_norm when using %s, "
"this will likely lead to an error.",
regime,
)
# final check of cudagraph mode after all possible updates # final check of cudagraph mode after all possible updates
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
if ( if (
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment