Unverified commit c5487e2b, authored by Matt and committed by GitHub
Browse files

[Bugfix] Fix potential EAGLE spec decode segfault during graph capture (#32818)


Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
parent 6437ff1f
......@@ -1222,10 +1222,14 @@ class SpecDecodeBaseProposer:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded
if use_cudagraphs:
cudagraph_runtime_mode, batch_desc = (
self.cudagraph_dispatcher.dispatch(num_tokens_dp_padded)
)
num_input_tokens = batch_desc.num_tokens
else:
cudagraph_runtime_mode = CUDAGraphMode.NONE
num_input_tokens = num_tokens_dp_padded
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment