Unverified Commit b59dd19b authored by Angela Yi's avatar Angela Yi Committed by GitHub
Browse files

[compile] Enable sequence parallelism for full cuda graph without specifying compile sizes (#26681)


Signed-off-by: default avatarangelayi <yiangela7@gmail.com>
parent 3e051bda
...@@ -431,8 +431,15 @@ class AsyncTPPass(VllmPatternMatcherPass): ...@@ -431,8 +431,15 @@ class AsyncTPPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: int | None) -> bool: def is_applicable(self, shape: int | None) -> bool:
# only do replace for specific shapes # This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
if (
not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition
):
return True
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0 return shape is not None and shape % tp_size == 0
......
...@@ -96,7 +96,7 @@ class InductorPass(CustomGraphPass): ...@@ -96,7 +96,7 @@ class InductorPass(CustomGraphPass):
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest() return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_shape(self, shape: int | None): def is_applicable(self, shape: int | None):
return True return True
......
...@@ -71,9 +71,11 @@ class PostGradPassManager(CustomGraphPass): ...@@ -71,9 +71,11 @@ class PostGradPassManager(CustomGraphPass):
shape = get_pass_context().runtime_shape shape = get_pass_context().runtime_shape
for pass_ in self.passes: for pass_ in self.passes:
if pass_.is_applicable_for_shape(shape): if pass_.is_applicable(shape):
pass_(graph) pass_(graph)
VllmInductorPass.dump_prefix += 1 VllmInductorPass.dump_prefix += 1
else:
logger.debug("Skipping %s with shape %s", pass_, shape)
# post-cleanup goes before fix_functionalization # post-cleanup goes before fix_functionalization
# because it requires a functional graph # because it requires a functional graph
......
...@@ -482,7 +482,25 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -482,7 +482,25 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
).register(self.patterns) ).register(self.patterns)
self.dump_patterns(config, self.patterns) self.dump_patterns(config, self.patterns)
def is_applicable_for_shape(self, shape: int | None) -> bool: def is_applicable(self, shape: int | None) -> bool:
# When sequence parallelism is enabled, the residual tensor from RMSNorm
# needs to be split along the sequence dimension. However, this dimension
# is symbolic during piecewise compilation, and splitting symbolic shapes
# is not supported.
#
# This pass is therefore only applied when the sequence dimension is
# concrete:
# 1. In full-graph compilation mode (no Dynamo splitting ops are used).
# For this case we always pad num_tokens to be a multiple of
# tensor_parallel_size, so there's no need to check shape % tp_size == 0.
# 2. For specific shape provided during compilation (e.g., from
# `compile_sizes`), which must be divisible by the tensor-parallel
# size.
if (
not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition
):
return True
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return shape is not None and shape % tp_size == 0 return shape is not None and shape % tp_size == 0
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import functools import functools
import operator import operator
import time import time
import weakref
from typing import ClassVar from typing import ClassVar
import regex as re import regex as re
...@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass): ...@@ -28,6 +29,7 @@ class VllmInductorPass(InductorPass):
"""Keep track of pass index for debug dump ordering.""" """Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
self.compilation_config = weakref.proxy(config.compilation_config)
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
self.model_dtype = config.model_config.dtype if config.model_config else None self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None self.device = config.device_config.device if config.device_config else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment