Commit d3fa2342 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Perf] Change default CUDAGraphMode from FULL_AND_PIECEWISE to PIECEWISE

parent 55989b60
......@@ -78,7 +78,7 @@ class CUDAGraphMode(enum.Enum):
return self.has_mode(CUDAGraphMode.PIECEWISE)
def max_cudagraph_mode(self) -> "CUDAGraphMode":
return CUDAGraphMode(max(self.value) if not envs.VLLM_USE_PIECEWISE else min(self.value)) if self.separate_routine() else self
return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
def has_full_cudagraphs(self) -> bool:
return self.max_cudagraph_mode() == CUDAGraphMode.FULL
......
......@@ -694,6 +694,7 @@ class VllmConfig:
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if model_config := self.model_config:
if not envs.VLLM_USE_PIECEWISE:
if (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
and model_config.pooler_config is not None
......@@ -716,6 +717,8 @@ class VllmConfig:
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
......
......@@ -4109,7 +4109,11 @@ class GPUModelRunner(
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
else:
seq_lens = max_query_len # type: ignore[assignment]
if not envs.VLLM_USE_PIECEWISE:
seq_lens = max_query_len
else:
# Make sure max_model_len is used at the graph capture time.
seq_lens = self.max_model_len
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
......@@ -4825,6 +4829,7 @@ class GPUModelRunner(
logger.warning(msg)
# check that if we are doing decode full-cudagraphs it is supported
if not envs.VLLM_USE_PIECEWISE:
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment