Unverified Commit 1e3e76b6 authored by pyc96's avatar pyc96 Committed by GitHub
Browse files

[Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due...


[Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (#14237)
Signed-off-by: default avatarpyc96 <pychen96@gmail.com>
parent 53ea6ad8
......@@ -302,6 +302,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
outputs.append(output)
if self.return_hidden_states and is_fallback:
if use_cuda_graph:
indices = model_input.sampling_metadata\
.selected_token_indices
output.hidden_states = hidden_states[:len(indices)]
else:
output.hidden_states = hidden_states
if model_input.attn_metadata.num_prefills == 0 \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment