Unverified commit 32e45636, authored by Xinan Miao and committed by GitHub
Browse files

[torch.compile]: Disable Sequence Parallelism (SP) for piecewise compilation (#38373)


Signed-off-by: SouthWest7 <am1ao@qq.com>
Signed-off-by: Xinan Miao <1403572259@qq.com>
Co-authored-by: SouthWest7 <am1ao@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: Wang Xingran <72983099+wangxingran222@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent b39c266d
...@@ -261,6 +261,8 @@ def _compare_sp( ...@@ -261,6 +261,8 @@ def _compare_sp(
}, },
"use_inductor_graph_partition": use_inductor_graph_partition, "use_inductor_graph_partition": use_inductor_graph_partition,
} }
if not use_inductor_graph_partition:
compilation_config["splitting_ops"] = []
tp_sp_args = [ tp_sp_args = [
*common_args, *common_args,
......
...@@ -19,6 +19,7 @@ from vllm.config import ( ...@@ -19,6 +19,7 @@ from vllm.config import (
VllmConfig, VllmConfig,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.config.utils import Range
from vllm.distributed import ( from vllm.distributed import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter, tensor_model_parallel_reduce_scatter,
...@@ -288,6 +289,22 @@ def test_async_tp_pass_replace( ...@@ -288,6 +289,22 @@ def test_async_tp_pass_replace(
run_torch_spawn(async_tp_pass_on_test_model, num_processes) run_torch_spawn(async_tp_pass_on_test_model, num_processes)
def test_async_tp_pass_requires_full_graph_compilation():
    """AsyncTPPass.is_applicable_for_range must assert on a piecewise config.

    Inductor graph partition is off and a splitting op is configured, so the
    compilation config is not full-graph; the applicability check is expected
    to fail with an AssertionError carrying the pass-specific message.
    """
    piecewise_config = VllmConfig().compilation_config
    piecewise_config.use_inductor_graph_partition = False
    piecewise_config.splitting_ops = ["vllm::unified_attention_with_output"]

    # Allocate the pass without running __init__; only the compilation config
    # is attached before the check is invoked.
    tp_pass = object.__new__(AsyncTPPass)
    tp_pass.compilation_config = piecewise_config

    expected_error = "AsyncTPPass requires full-graph compilation"
    with pytest.raises(AssertionError, match=expected_error):
        tp_pass.is_applicable_for_range(Range(start=8, end=8))
def async_tp_pass_on_test_model( def async_tp_pass_on_test_model(
local_rank: int, local_rank: int,
world_size: int, world_size: int,
......
...@@ -22,6 +22,7 @@ from vllm.config import ( ...@@ -22,6 +22,7 @@ from vllm.config import (
get_current_vllm_config, get_current_vllm_config,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.config.utils import Range
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
init_distributed_environment, init_distributed_environment,
...@@ -216,6 +217,24 @@ def test_sequence_parallelism_pass( ...@@ -216,6 +217,24 @@ def test_sequence_parallelism_pass(
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
def test_sequence_parallelism_pass_requires_full_graph_compilation():
    """SequenceParallelismPass must assert when compilation is piecewise.

    Inductor graph partition is off and a splitting op is configured, so the
    compilation config is not full-graph; the applicability check is expected
    to fail with an AssertionError carrying the pass-specific message.
    """
    piecewise_config = VllmConfig().compilation_config
    piecewise_config.use_inductor_graph_partition = False
    piecewise_config.splitting_ops = ["vllm::unified_attention_with_output"]

    # Allocate the pass without running __init__ and attach only the two
    # attributes this test sets up by hand.
    sp_pass = object.__new__(SequenceParallelismPass)
    sp_pass.compilation_config = piecewise_config
    sp_pass.min_token_num = 1

    expected_error = "SequenceParallelismPass requires full-graph compilation"
    with pytest.raises(AssertionError, match=expected_error):
        sp_pass.is_applicable_for_range(Range(start=8, end=8))
def sequence_parallelism_pass_on_test_model( def sequence_parallelism_pass_on_test_model(
local_rank: int, local_rank: int,
world_size: int, world_size: int,
......
...@@ -407,7 +407,7 @@ def test_should_split(): ...@@ -407,7 +407,7 @@ def test_should_split():
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256), (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list # max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15), ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP # SP forces full-graph compilation, sizes are filtered by TP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4), ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens # limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4), ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
...@@ -465,6 +465,123 @@ def test_cudagraph_sizes_post_init( ...@@ -465,6 +465,123 @@ def test_cudagraph_sizes_post_init(
) )
# Config-level test: with sequence parallelism requested (enable_sp and
# fuse_gemm_comms both True) and tensor_parallel_size=2, check which
# splitting_ops, cudagraph mode, capture sizes and compile-range endpoints the
# config ends up with for each cudagraph mode / graph-partition combination.
@pytest.mark.skipif(
not current_platform.support_static_graph_mode(),
reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
(
"cudagraph_mode",
"use_inductor_graph_partition",
"expected_enable_sp",
"expected_cudagraph_mode",
"expected_piecewise_compile",
"expected_capture_sizes",
"expected_max_size",
),
[
# PIECEWISE without graph partition: expect FULL and empty splitting_ops.
(CUDAGraphMode.PIECEWISE, False, True, CUDAGraphMode.FULL, False, [2, 4], 4),
# FULL_DECODE_ONLY without graph partition: mode is expected to stay as
# requested, with empty splitting_ops.
(
CUDAGraphMode.FULL_DECODE_ONLY,
False,
True,
CUDAGraphMode.FULL_DECODE_ONLY,
False,
[2, 4],
4,
),
# FULL_AND_PIECEWISE without graph partition: expect FULL and empty
# splitting_ops.
(
CUDAGraphMode.FULL_AND_PIECEWISE,
False,
True,
CUDAGraphMode.FULL,
False,
[2, 4],
4,
),
# FULL_AND_PIECEWISE with graph partition: mode is expected to be kept and
# splitting_ops to remain non-empty.
(
CUDAGraphMode.FULL_AND_PIECEWISE,
True,
True,
CUDAGraphMode.FULL_AND_PIECEWISE,
True,
[2, 4],
4,
),
],
)
def test_sequence_parallelism_requires_full_graph_compilation(
cudagraph_mode: CUDAGraphMode,
use_inductor_graph_partition: bool,
expected_enable_sp: bool,
expected_cudagraph_mode: CUDAGraphMode,
expected_piecewise_compile: bool,
expected_capture_sizes: list[int],
expected_max_size: int,
):
# Report two devices from the platform while the config is built
# (presumably so tensor_parallel_size=2 passes validation - confirm).
with patch.object(current_platform, "device_count", return_value=2):
vllm_config = VllmConfig(
parallel_config=ParallelConfig(tensor_parallel_size=2),
scheduler_config=SchedulerConfig(
max_num_seqs=128,
max_num_batched_tokens=2048,
max_model_len=2048,
is_encoder_decoder=False,
),
)
# Stand-in model config: a MagicMock exposing only the fields set here.
vllm_config.model_config = MagicMock(
dtype=torch.float16,
enforce_eager=False,
is_moe=False,
disable_cascade_attn=False,
get_hidden_size=MagicMock(return_value=4096),
)
# SP and async TP are both requested; sp_min_token_num=512 is the value the
# final endpoint assertion (511) is paired with.
vllm_config.compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
cudagraph_capture_sizes=[1, 2, 4, 15],
max_cudagraph_capture_size=None,
compile_sizes=["cudagraph_capture_sizes"],
use_inductor_graph_partition=use_inductor_graph_partition,
pass_config=PassConfig(
enable_sp=True,
fuse_gemm_comms=True,
fuse_norm_quant=True,
fuse_act_quant=True,
eliminate_noops=True,
sp_min_token_num=512,
),
cudagraph_mode=cudagraph_mode,
)
# Run the config-resolution steps explicitly, in this order, before asserting.
vllm_config.compilation_config.set_splitting_ops_for_v1(
all2all_backend=vllm_config.parallel_config.all2all_backend,
data_parallel_size=1,
)
vllm_config._set_compile_ranges()
vllm_config._set_cudagraph_sizes()
# The graph-partition flag must come out exactly as requested.
assert (
vllm_config.compilation_config.use_inductor_graph_partition
== use_inductor_graph_partition
)
# Non-empty splitting_ops is what this test treats as piecewise compilation.
assert (
bool(vllm_config.compilation_config.splitting_ops) == expected_piecewise_compile
)
# enable_sp and fuse_gemm_comms are both compared against expected_enable_sp.
assert vllm_config.compilation_config.pass_config.enable_sp == expected_enable_sp
assert (
vllm_config.compilation_config.pass_config.fuse_gemm_comms == expected_enable_sp
)
assert vllm_config.compilation_config.cudagraph_mode == expected_cudagraph_mode
# Every row expects [2, 4] out of the configured [1, 2, 4, 15]
# (presumably the sizes divisible by tensor_parallel_size=2 - confirm).
assert (
vllm_config.compilation_config.cudagraph_capture_sizes == expected_capture_sizes
)
assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)
# 511 (one below sp_min_token_num=512) is expected among the compile range
# endpoints exactly when SP is expected to stay enabled.
assert (
511 in vllm_config.compilation_config.compile_ranges_endpoints
) == expected_enable_sp
def test_cached_compilation_config(default_vllm_config): def test_cached_compilation_config(default_vllm_config):
import torch import torch
from torch._inductor.utils import run_and_get_code from torch._inductor.utils import run_and_get_code
......
...@@ -406,16 +406,13 @@ class AsyncTPPass(VllmPatternMatcherPass): ...@@ -406,16 +406,13 @@ class AsyncTPPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool: def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass. # This pass is applied on top of the sequence parallelism pass,
# It inherits the same applicability condition as `SequenceParallelismPass`. # which is only supported in fullgraph compilation mode.
# See `SequenceParallelismPass.is_applicable` for more details. assert (
if ( self.compilation_config.use_inductor_graph_partition
not self.compilation_config.splitting_ops or not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition ), "AsyncTPPass requires full-graph compilation"
):
return True return True
tp_size = get_tensor_model_parallel_world_size()
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None: def __call__(self, graph: fx.Graph) -> None:
......
...@@ -341,22 +341,18 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -341,22 +341,18 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
significantly reduce communication overhead and improve overall model significantly reduce communication overhead and improve overall model
performance. performance.
This pass is only supported when compiling the whole graph (fullgraph
This pass splits up the residual tensor across TP ranks and hence divides its size. mode, i.e. using Inductor graph partition or empty splitting_ops).
Because the pattern matcher starts at the end of the graph, the replacement Piecewise compilation is not supported because the residual tensor
contains a slice that temporarily conforms the input residual to the correct size. gets split across TP ranks, causing size mismatches at subgraph
After all patterns have been matched, we use a NoOpEliminationPass to clean up boundaries.
what have now become no-op slices.
This pass splits up the residual tensor across TP ranks and hence
Note that an older version of the pass did not need this as it operated only on divides its size. Because the pattern matcher starts at the end of
custom rms_norm and fused_rms_norm_add custom ops which did not complain about the graph, the replacement contains a slice that temporarily conforms
mismatched shapes during replacement. So this approach has the same assumption that the input residual to the correct size. After all patterns have been
correctness is only maintained if all rms_norm operations are split across ranks. matched, we use a NoOpEliminationPass to clean up what have now
become no-op slices.
Correctness-wise, this is approach strictly better than before - before,
the graph was incorrect semantically and shape-wise during the pass.
With this approach there's only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
""" """
@enable_fake_mode @enable_fake_mode
...@@ -419,19 +415,13 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -419,19 +415,13 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
and gathering tensors across TP ranks outweighs the benefits. and gathering tensors across TP ranks outweighs the benefits.
Returns False (SP disabled) when: Returns False (SP disabled) when:
- Using piecewise compilation with non-concrete or TP-indivisible sizes
- min_token_num is None (SP disabled for this device/config) - min_token_num is None (SP disabled for this device/config)
- The compile range starts below the minimum token threshold - The compile range starts below the minimum token threshold
""" """
# For piecewise compilation (not using inductor graph partition), assert (
# we need concrete sizes that are divisible by TP for correct splitting self.compilation_config.use_inductor_graph_partition
if ( or not self.compilation_config.splitting_ops
not self.compilation_config.use_inductor_graph_partition ), "SequenceParallelismPass requires full-graph compilation"
and self.compilation_config.splitting_ops
):
tp_size = get_tensor_model_parallel_world_size()
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
return False
# min_token_num is None when SP is disabled for this device/config # min_token_num is None when SP is disabled for this device/config
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size) # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
......
...@@ -1148,6 +1148,25 @@ class CompilationConfig: ...@@ -1148,6 +1148,25 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = [] self.splitting_ops = []
if (
not self.use_inductor_graph_partition
and (self.pass_config.enable_sp or self.pass_config.fuse_gemm_comms)
and self.splitting_ops
):
logger.warning_once(
"Sequence parallelism requires full-graph compilation when "
"use_inductor_graph_partition is off. Setting splitting_ops "
"to an empty list to preserve SP and async TP."
)
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"Sequence parallelism is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"Setting cudagraph_mode to FULL."
)
self.cudagraph_mode = CUDAGraphMode.FULL
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible # Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
if ( if (
all2all_backend == "deepep_high_throughput" all2all_backend == "deepep_high_throughput"
......
...@@ -983,19 +983,16 @@ class VllmConfig: ...@@ -983,19 +983,16 @@ class VllmConfig:
) )
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# async tp is built on top of sequence parallelism # async tp is built on top of sequence parallelism and requires it.
# and requires it to be enabled. pass_config = self.compilation_config.pass_config
if self.compilation_config.pass_config.fuse_gemm_comms: if pass_config.fuse_gemm_comms:
self.compilation_config.pass_config.enable_sp = True pass_config.enable_sp = True
if self.compilation_config.pass_config.enable_sp: if pass_config.enable_sp:
if self.parallel_config.tensor_parallel_size == 1: if self.parallel_config.tensor_parallel_size == 1:
logger.warning("Sequence Parallelism requires TP>1, disabling") logger.warning("Sequence Parallelism requires TP>1, disabling")
self.compilation_config.pass_config.enable_sp = False pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False pass_config.fuse_gemm_comms = False
else: else:
# Compute SP threshold early; disable if None (model too
# small for SP to be beneficial).
pass_config = self.compilation_config.pass_config
if pass_config.sp_min_token_num is None: if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import ( from vllm.compilation.passes.fusion.sequence_parallelism import (
get_sequence_parallelism_threshold, get_sequence_parallelism_threshold,
...@@ -1015,8 +1012,8 @@ class VllmConfig: ...@@ -1015,8 +1012,8 @@ class VllmConfig:
"threshold heuristic, disabling. To force SP, " "threshold heuristic, disabling. To force SP, "
"set pass_config.sp_min_token_num manually." "set pass_config.sp_min_token_num manually."
) )
self.compilation_config.pass_config.enable_sp = False pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False pass_config.fuse_gemm_comms = False
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
...@@ -1098,6 +1095,7 @@ class VllmConfig: ...@@ -1098,6 +1095,7 @@ class VllmConfig:
self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.cudagraph_num_of_warmups = 1
self._set_cudagraph_sizes() self._set_cudagraph_sizes()
else: else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
...@@ -1171,8 +1169,8 @@ class VllmConfig: ...@@ -1171,8 +1169,8 @@ class VllmConfig:
) )
if self.compilation_config.pass_config.enable_sp: if self.compilation_config.pass_config.enable_sp:
# With pipeline parallelism or dynamo partitioning, # With pipeline parallelism, native rms norm tracing errors due to
# native rms norm tracing errors due to incorrect residual shape. # incorrect residual shape.
# Use custom rms norm to unblock. In the future, # Use custom rms norm to unblock. In the future,
# the pass will operate on higher-level IR to avoid the issue. # the pass will operate on higher-level IR to avoid the issue.
# TODO: https://github.com/vllm-project/vllm/issues/27894 # TODO: https://github.com/vllm-project/vllm/issues/27894
...@@ -1183,24 +1181,15 @@ class VllmConfig: ...@@ -1183,24 +1181,15 @@ class VllmConfig:
self.compilation_config.mode, self.compilation_config.mode,
) )
is_fullgraph = ( if self.parallel_config.pipeline_parallel_size > 1:
self.compilation_config.use_inductor_graph_partition
or len(self.compilation_config.splitting_ops or []) == 0
)
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
if "-rms_norm" not in self.compilation_config.custom_ops: if "-rms_norm" not in self.compilation_config.custom_ops:
self.compilation_config.custom_ops.append("+rms_norm") self.compilation_config.custom_ops.append("+rms_norm")
else: else:
regime = (
"Dynamo partition"
if not is_fullgraph
else "pipeline parallelism"
)
logger.warning_once( logger.warning_once(
"Sequence parallelism not supported with " "Sequence parallelism not supported with "
"native rms_norm when using %s, " "native rms_norm when using %s, "
"this will likely lead to an error.", "this will likely lead to an error.",
regime, "pipeline parallelism",
) )
# final check of cudagraph mode after all possible updates # final check of cudagraph mode after all possible updates
...@@ -1212,9 +1201,9 @@ class VllmConfig: ...@@ -1212,9 +1201,9 @@ class VllmConfig:
and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() # noqa: E501 and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() # noqa: E501
): ):
logger.warning_once( logger.warning_once(
"No piecewise cudagraph for executing cascade attention." "No piecewise cudagraph for executing cascade attention. "
" Will fall back to eager execution if a batch runs " "Will fall back to eager execution if a batch runs into "
"into cascade attentions." "cascade attentions."
) )
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation(): if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
......
...@@ -519,12 +519,8 @@ def is_residual_scattered_for_sp( ...@@ -519,12 +519,8 @@ def is_residual_scattered_for_sp(
"""Check if the residual tensor is scattered for sequence parallelism. """Check if the residual tensor is scattered for sequence parallelism.
The residual tensor is scattered across tensor parallel ranks when sequence The residual tensor is scattered across tensor parallel ranks when sequence
parallelism and tensor parallelism is enabled. parallelism and tensor parallelism is enabled. SP is only supported in
full-graph compilation mode.
This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
- In full-graph compilation mode (no splitting ops or using inductor graph
partition), SP is always applied
- Otherwise, SP is only applied for specific shapes in compile_sizes
""" """
if not vllm_config.compilation_config.pass_config.enable_sp: if not vllm_config.compilation_config.pass_config.enable_sp:
return False return False
...@@ -534,16 +530,13 @@ def is_residual_scattered_for_sp( ...@@ -534,16 +530,13 @@ def is_residual_scattered_for_sp(
if tp == 1: if tp == 1:
return False return False
assert (
vllm_config.compilation_config.use_inductor_graph_partition
or not vllm_config.compilation_config.splitting_ops
), "Sequence parallelism requires full-graph compilation"
# When sequence parallelism is enabled, we always pad num_input_tokens # When sequence parallelism is enabled, we always pad num_input_tokens
# to be a multiple of tensor_parallel_size (tp) earlier. # to be a multiple of tensor_parallel_size (tp) earlier.
assert num_input_tokens % tp == 0 assert num_input_tokens % tp == 0
if (
not vllm_config.compilation_config.splitting_ops
or vllm_config.compilation_config.use_inductor_graph_partition
):
return True return True
compile_sizes = vllm_config.compilation_config.compile_sizes
if compile_sizes is None:
return False
return num_input_tokens in compile_sizes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment