Commit 9ff617d7 authored by zhuwenwen
Browse files

[Perf] Change default CUDAGraphMode from FULL_AND_PIECEWISE to PIECEWISE

parent fd8764b3
......@@ -61,6 +61,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from vllm.platforms import current_platform
from vllm import envs
if TYPE_CHECKING:
from _typeshed import DataclassInstance
......@@ -371,15 +372,19 @@ class VllmConfig:
if self.compilation_config.cudagraph_mode is None:
if envs.VLLM_USE_V1 and self.compilation_config.level \
== CompilationLevel.PIECEWISE:
# default to full and piecewise for most models
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.FULL_AND_PIECEWISE
# pooling models and encoder-decoder models
# do not support full cudagraphs
if self.model_config is not None and \
(self.model_config.pooler_config is not None
or self.model_config.is_encoder_decoder):
if not envs.VLLM_USE_PIECEWISE:
# default to full and piecewise for most models
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.FULL_AND_PIECEWISE
# pooling models and encoder-decoder models
# do not support full cudagraphs
if self.model_config is not None and \
(self.model_config.pooler_config is not None
or self.model_config.is_encoder_decoder):
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
......
......@@ -14,7 +14,6 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
from vllm import envs
if TYPE_CHECKING:
from vllm.config import VllmConfig
......@@ -57,7 +56,7 @@ class CUDAGraphMode(enum.Enum):
def max_cudagraph_mode(self) -> 'CUDAGraphMode':
return CUDAGraphMode(max(
self.value) if not envs.VLLM_USE_PIECEWISE else min(self.value)) if self.separate_routine() else self
self.value)) if self.separate_routine() else self
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
......
......@@ -1657,7 +1657,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")),
# vLLM will use piecewise
"VLLM_USE_PIECEWISE":
lambda: (os.environ.get("VLLM_USE_PIECEWISE", "False").lower() in
lambda: (os.environ.get("VLLM_USE_PIECEWISE", "True").lower() in
("true", "1")),
# vllm will use encoding_dsv32.py for dpsk-v32
"VLLM_USE_V32_ENCODE":
......
......@@ -3041,7 +3041,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
seq_lens = max_query_len
if not envs.VLLM_USE_PIECEWISE:
seq_lens = max_query_len
else:
# Make sure max_model_len is used at the graph capture time.
seq_lens = self.max_model_len
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
......@@ -3662,25 +3666,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
CUDAGraphMode.FULL_DECODE_ONLY
logger.warning(msg)
# check that if we are doing decode full-cudagraphs it is supported
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER):
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"{min_cg_support})")
if (self.compilation_config.level == CompilationLevel.PIECEWISE and
(self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition)):
msg += "; setting cudagraph_mode=PIECEWISE because "\
"attention is compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE because "\
"attention is not compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
if not envs.VLLM_USE_PIECEWISE:
# check that if we are doing decode full-cudagraphs it is supported
if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER):
msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_builder_name} backend (support: "
f"{min_cg_support})")
if (self.compilation_config.level == CompilationLevel.PIECEWISE and
(self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition)):
msg += "; setting cudagraph_mode=PIECEWISE because "\
"attention is compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE because "\
"attention is not compiled piecewise"
cudagraph_mode = self.compilation_config.cudagraph_mode = \
CUDAGraphMode.NONE
logger.warning(msg)
# check that if we are doing spec-decode + decode full-cudagraphs it is
# supported
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment