OpenDAS / vllm_cscc
Commit 3a243095 (Unverified)
Authored Mar 25, 2024 by Antoni Baum
Committed by GitHub, Mar 25, 2024
Optimize `_get_ranks` in Sampler (#3623)
parent 64172a97
Showing 1 changed file with 17 additions and 10 deletions

vllm/model_executor/layers/sampler.py
...
...
@@ -506,22 +506,23 @@ def _sample(
     # sampling_tensors)


-def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.

     Args:
         x (torch.Tensor): 2D logprob tensor of shape (N, M)
                         where N is the no. of tokens and M is the vocab dim.
-        indices (List[int]): List of chosen token indices.
+        indices (torch.Tensor): List of chosen token indices.

     Returns:
         torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                     Each element in the returned tensor represents the rank
                     of the chosen token in the input logprob tensor.
     """
-    vals = x[range(len(x)), indices]
-    return (x > vals[:, None]).long().sum(1) + 1
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    return (x > vals[:, None]).long().sum(1).add_(1)


 def _get_logprobs(
...
...
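Note on what "rank" means here: going by the return expression of `_get_ranks`, a chosen token's rank is one plus the number of entries in its row that are strictly greater than the chosen entry, so the most likely token has rank 1 and tied values share a rank. A minimal pure-Python sketch of that definition, for illustration only; `reference_ranks` and the sample rows are hypothetical and not part of this commit:

```python
from typing import List


def reference_ranks(logprobs: List[List[float]], chosen: List[int]) -> List[int]:
    """Hypothetical reference: rank = 1 + count of strictly larger entries in the row."""
    ranks = []
    for row, token_id in zip(logprobs, chosen):
        chosen_value = row[token_id]
        ranks.append(1 + sum(1 for value in row if value > chosen_value))
    return ranks


# Made-up logprob rows, one per token position.
rows = [
    [-0.1, -2.5, -3.0],  # chosen index 0 is the largest entry
    [-1.2, -0.4, -2.0],  # chosen index 2 is the smallest entry
    [-0.7, -0.7, -3.0],  # chosen index 1 ties with index 0
]
print(reference_ranks(rows, [0, 2, 1]))  # [1, 3, 1]
```

The tensor version in the commit computes the same quantity for all rows at once by comparing `x` against the chosen values broadcast along the vocab dimension.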
@@ -561,12 +562,21 @@ def _get_logprobs(
         sample_idx += num_parent_seqs
     assert sample_idx == logprobs.size(0)

+    batched_logprobs_query_seq_indices_gpu = torch.tensor(
+        batched_logprobs_query_seq_indices, device=logprobs.device)
+    batched_logprobs_query_token_indices_gpu = torch.tensor(
+        batched_logprobs_query_token_indices, device=logprobs.device)
+
     # Batched query for logprobs of selected token
     batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices,
-        batched_logprobs_query_token_indices
+        batched_logprobs_query_seq_indices_gpu,
+        batched_logprobs_query_token_indices_gpu
     ]]

+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices_gpu],
+        batched_logprobs_query_token_indices_gpu)
+
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
...
...
@@ -578,10 +588,7 @@ def _get_logprobs(
         top_logprobs, top_token_ids = None, None

     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
-
-    batched_ranks_query_result = _get_ranks(
-        logprobs[batched_logprobs_query_seq_indices],
-        batched_logprobs_query_token_indices)
-
+    batched_ranks_query_result = batched_ranks_query_result.cpu()

     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
...
...
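Note on the `_get_logprobs` change: the index lists are converted to tensors on the logprobs device once, the same tensors are reused for both the selected-logprob query and the rank query, and the results are moved to the CPU afterwards. The sketch below reproduces that pattern in isolation, assuming PyTorch is installed. The logits, the index lists, and the name `get_ranks` are made up for illustration, and the selected-logprob query is written with plain tuple indexing rather than the double-bracket form used in the diff.

```python
import torch


def get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Same body as the new `_get_ranks` in this commit.
    rows = torch.arange(0, len(x), device=x.device, dtype=indices.dtype)
    vals = x[rows, indices]
    return (x > vals[:, None]).long().sum(1).add_(1)


# Fall back to the CPU so the sketch also runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical logits for two sequences over a three-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.0], [0.0, 3.0, 1.0]], device=device)
logprobs = torch.log_softmax(logits, dim=-1)

# Hypothetical query lists: (row, token id) pairs collected on the host.
seq_indices = [0, 1, 1]
token_indices = [1, 1, 2]

# Build the index tensors once, on the same device as logprobs.
seq_idx = torch.tensor(seq_indices, device=logprobs.device)
tok_idx = torch.tensor(token_indices, device=logprobs.device)

# Reuse them for both queries, then copy the results to the host.
selected = logprobs[seq_idx, tok_idx].cpu()
ranks = get_ranks(logprobs[seq_idx], tok_idx).cpu()

print(ranks.tolist())  # [2, 1, 2]
```

Whether this ordering is faster in practice depends on the batch size and the device; the commit title describes it only as an optimization, and no measurements appear on this page.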