Unverified Commit 3bbff9e5 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

Fix 1D query issue from `_prune_hidden_states` (#3539)

parent 6ebd02bd
...@@ -77,7 +77,6 @@ def _prune_hidden_states( ...@@ -77,7 +77,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0, return hidden_states.index_select(0,
sampling_metadata.selected_token_indices) sampling_metadata.selected_token_indices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment