Commit d3fa2342 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Perf] Change default CUDAGraphMode from FULL_AND_PIECEWISE to PIECEWISE

parent 55989b60
......@@ -78,7 +78,7 @@ class CUDAGraphMode(enum.Enum):
return self.has_mode(CUDAGraphMode.PIECEWISE)
def max_cudagraph_mode(self) -> "CUDAGraphMode":
return CUDAGraphMode(max(self.value) if not envs.VLLM_USE_PIECEWISE else min(self.value)) if self.separate_routine() else self
return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
......
......@@ -694,28 +694,31 @@ class VllmConfig:
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if model_config := self.model_config:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
if not envs.VLLM_USE_PIECEWISE:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
):
logger.info_once(
"Encoder-decoder models do not support %s. "
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
......
......@@ -4109,7 +4109,11 @@ class GPUModelRunner(
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
seq_lens = max_query_len # type: ignore[assignment]
if not envs.VLLM_USE_PIECEWISE:
seq_lens = max_query_len
else:
# Make sure max_model_len is used at the graph capture time.
seq_lens = self.max_model_len
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
......@@ -4825,35 +4829,36 @@ class GPUModelRunner(
logger.warning(msg)
# check that if we are doing decode full-cudagraphs it is supported
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition
if not envs.VLLM_USE_PIECEWISE:
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER
):
msg += (
"; setting cudagraph_mode=PIECEWISE because "
"attention is compiled piecewise"
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.PIECEWISE
)
else:
msg += (
"; setting cudagraph_mode=NONE because "
"attention is not compiled piecewise"
)
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.NONE
)
logger.warning(msg)
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition
):
msg += (
"; setting cudagraph_mode=PIECEWISE because "
"attention is compiled piecewise"
)
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.PIECEWISE
)
else:
msg += (
"; setting cudagraph_mode=NONE because "
"attention is not compiled piecewise"
)
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.NONE
)
logger.warning(msg)
# check that if we are doing spec-decode + decode full-cudagraphs it is
# supported
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment