Unverified Commit df78aeef authored by Yizhou's avatar Yizhou Committed by GitHub
Browse files

Refactor: Move CUDA graph dispatch logic earlier (#27382)


Signed-off-by: default avatarYizhou Liu <liu_yizhou@outlook.com>
parent 7df331c6
...@@ -3740,6 +3740,31 @@ class GPUModelRunner( ...@@ -3740,6 +3740,31 @@ class GPUModelRunner(
dp_rank = self.parallel_config.data_parallel_rank dp_rank = self.parallel_config.data_parallel_rank
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank]) num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(
num_tokens=num_tokens_after_padding,
uniform_decode=uniform_decode,
has_lora=activate_lora and self.lora_config is not None,
)
)
if not is_profile
else (CUDAGraphMode.NONE, None)
)
if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for cudagraph capture
assert (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode == _cg_mode
), (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
)
else:
cudagraph_runtime_mode = _cg_mode
attn_metadata: PerLayerAttnMetadata | None = None attn_metadata: PerLayerAttnMetadata | None = None
# If force_attention is True, we always capture attention. Otherwise, # If force_attention is True, we always capture attention. Otherwise,
...@@ -3814,31 +3839,6 @@ class GPUModelRunner( ...@@ -3814,31 +3839,6 @@ class GPUModelRunner(
num_tokens_after_padding, None, False num_tokens_after_padding, None, False
) )
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(
num_tokens=num_tokens_after_padding,
uniform_decode=uniform_decode,
has_lora=activate_lora and self.lora_config is not None,
)
)
if not is_profile
else (CUDAGraphMode.NONE, None)
)
if cudagraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for cudagraph capture
assert (
cudagraph_runtime_mode == CUDAGraphMode.NONE
or cudagraph_runtime_mode == _cg_mode
), (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
)
else:
cudagraph_runtime_mode = _cg_mode
if ubatch_slices is not None: if ubatch_slices is not None:
# Adjust values to reflect a single ubatch. # Adjust values to reflect a single ubatch.
# TODO(sage,lucas): this is cruft that should be addressed in # TODO(sage,lucas): this is cruft that should be addressed in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment