Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3a243095
Unverified
Commit
3a243095
authored
Mar 25, 2024
by
Antoni Baum
Committed by
GitHub
Mar 25, 2024
Browse files
Optimize `_get_ranks` in Sampler (#3623)
parent
64172a97
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
10 deletions
+17
-10
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+17
-10
No files found.
vllm/model_executor/layers/sampler.py
View file @
3a243095
...
@@ -506,22 +506,23 @@ def _sample(
...
@@ -506,22 +506,23 @@ def _sample(
# sampling_tensors)
# sampling_tensors)
def
_get_ranks
(
x
:
torch
.
Tensor
,
indices
:
List
[
int
]
)
->
torch
.
Tensor
:
def
_get_ranks
(
x
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
This function calculates the ranks of the chosen tokens in a logprob tensor.
This function calculates the ranks of the chosen tokens in a logprob tensor.
Args:
Args:
x (torch.Tensor): 2D logprob tensor of shape (N, M)
x (torch.Tensor): 2D logprob tensor of shape (N, M)
where N is the no. of tokens and M is the vocab dim.
where N is the no. of tokens and M is the vocab dim.
indices (
List[int]
): List of chosen token indices.
indices (
torch.Tensor
): List of chosen token indices.
Returns:
Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank
Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor.
of the chosen token in the input logprob tensor.
"""
"""
vals
=
x
[
range
(
len
(
x
)),
indices
]
vals
=
x
[
torch
.
arange
(
0
,
len
(
x
),
device
=
x
.
device
,
dtype
=
indices
.
dtype
),
return
(
x
>
vals
[:,
None
]).
long
().
sum
(
1
)
+
1
indices
]
return
(
x
>
vals
[:,
None
]).
long
().
sum
(
1
).
add_
(
1
)
def
_get_logprobs
(
def
_get_logprobs
(
...
@@ -561,12 +562,21 @@ def _get_logprobs(
...
@@ -561,12 +562,21 @@ def _get_logprobs(
sample_idx
+=
num_parent_seqs
sample_idx
+=
num_parent_seqs
assert
sample_idx
==
logprobs
.
size
(
0
)
assert
sample_idx
==
logprobs
.
size
(
0
)
batched_logprobs_query_seq_indices_gpu
=
torch
.
tensor
(
batched_logprobs_query_seq_indices
,
device
=
logprobs
.
device
)
batched_logprobs_query_token_indices_gpu
=
torch
.
tensor
(
batched_logprobs_query_token_indices
,
device
=
logprobs
.
device
)
# Batched query for logprobs of selected token
# Batched query for logprobs of selected token
batched_logprobs_query_result
=
logprobs
[[
batched_logprobs_query_result
=
logprobs
[[
batched_logprobs_query_seq_indices
,
batched_logprobs_query_seq_indices
_gpu
,
batched_logprobs_query_token_indices
batched_logprobs_query_token_indices
_gpu
]]
]]
batched_ranks_query_result
=
_get_ranks
(
logprobs
[
batched_logprobs_query_seq_indices_gpu
],
batched_logprobs_query_token_indices_gpu
)
# Batched query for logprobs of topk tokens
# Batched query for logprobs of topk tokens
if
largest_num_logprobs
>
0
:
if
largest_num_logprobs
>
0
:
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
...
@@ -578,10 +588,7 @@ def _get_logprobs(
...
@@ -578,10 +588,7 @@ def _get_logprobs(
top_logprobs
,
top_token_ids
=
None
,
None
top_logprobs
,
top_token_ids
=
None
,
None
batched_logprobs_query_result
=
batched_logprobs_query_result
.
cpu
()
batched_logprobs_query_result
=
batched_logprobs_query_result
.
cpu
()
batched_ranks_query_result
=
batched_ranks_query_result
.
cpu
()
batched_ranks_query_result
=
_get_ranks
(
logprobs
[
batched_logprobs_query_seq_indices
],
batched_logprobs_query_token_indices
)
# Gather results
# Gather results
result_prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
result_prompt_logprobs
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment