Unverified Commit 28873a27 authored by Aman Gupta Karmani's avatar Aman Gupta Karmani Committed by GitHub
Browse files

Improve _prune_hidden_states micro-benchmark (#707)

parent 0080d832
......@@ -100,7 +100,8 @@ def _prune_hidden_states(
start_idx += prompt_len
last_token_indicies.extend(
range(start_idx, start_idx + input_metadata.num_generation_tokens))
return hidden_states[last_token_indicies]
return hidden_states.index_select(
0, torch.tensor(last_token_indicies, device=hidden_states.device))
def _get_penalties(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment