Unverified Commit 2110557d authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix cuda graph for MLPSpeculator (#5875)


Co-authored-by: default avatarAbhinav Goyal <abhinav.goyal@flipkart.com>
parent b9e84259
...@@ -52,7 +52,6 @@ if __name__ == "__main__": ...@@ -52,7 +52,6 @@ if __name__ == "__main__":
speculative_model="ibm-fms/llama-13b-accelerator", speculative_model="ibm-fms/llama-13b-accelerator",
# These are currently required for MLPSpeculator decoding # These are currently required for MLPSpeculator decoding
use_v2_block_manager=True, use_v2_block_manager=True,
enforce_eager=True,
) )
print("With speculation") print("With speculation")
......
...@@ -1020,10 +1020,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1020,10 +1020,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
if model_input.is_prompt:
assert model_input.sampling_metadata is not None assert model_input.sampling_metadata is not None
hidden_states = hidden_states.index_select( indices = model_input.sampling_metadata.selected_token_indices
0, model_input.sampling_metadata.selected_token_indices) if model_input.is_prompt:
hidden_states = hidden_states.index_select(0, indices)
elif decode_meta.use_cuda_graph:
hidden_states = hidden_states[:len(indices)]
output.hidden_states = hidden_states output.hidden_states = hidden_states
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment