Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f205c098
Unverified
Commit
f205c098
authored
Aug 29, 2024
by
Jonas M. Kübler
Committed by
GitHub
Aug 28, 2024
Browse files
[Bugfix] Unify rank computation across regular decoding and speculative decoding (#7899)
parent
ef99a787
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
3 deletions
+22
-3
tests/spec_decode/test_utils.py
tests/spec_decode/test_utils.py
+20
-1
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+2
-2
No files found.
tests/spec_decode/test_utils.py
View file @
f205c098
...
...
@@ -4,10 +4,12 @@ import pytest
import
torch
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.sampler
import
_get_ranks
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
from
vllm.sequence
import
SequenceGroupMetadata
,
get_all_seq_ids
from
vllm.spec_decode.util
import
split_batch_by_proposal_len
from
vllm.spec_decode.util
import
(
get_sampled_token_logprobs
,
split_batch_by_proposal_len
)
def
test_get_all_seq_ids
():
...
...
@@ -126,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method):
return
sampler
else
:
raise
ValueError
(
f
"Invalid sampler name
{
acceptance_sampler_method
}
"
)
def test_get_sampled_token_logprobs():
    """Check that speculative-decoding ranks agree with regular ranks on ties.

    Every vocabulary entry carries the same log-probability, so the sampled
    token is tied with the other entry. The ranks returned by
    ``get_sampled_token_logprobs`` must equal those returned by ``_get_ranks``
    for the same (flattened) inputs.
    """
    # All entries share one value, so the probabilities match exactly.
    tied_logprobs = torch.tensor(
        [[[-.1, -.1]] * 2])  # shape (num_steps, batch_size, vocab_size)
    chosen_tokens = torch.tensor([[1, 0]])  # shape (num_steps, batch_size)

    # Speculative-decoding path: only the ranks are needed here.
    ranks_spec_dec, _ = get_sampled_token_logprobs(tied_logprobs,
                                                   chosen_tokens)

    # Regular decoding path operates on flattened inputs: two rows of
    # log-probabilities and a 1-D tensor of sampled token ids.
    flat_logprobs = tied_logprobs.reshape((2, -1))
    flat_tokens = chosen_tokens.reshape(-1)
    ranks_regular = _get_ranks(flat_logprobs, flat_tokens)

    assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
vllm/spec_decode/util.py
View file @
f205c098
...
...
@@ -43,8 +43,8 @@ def get_sampled_token_logprobs(
sampled_token_ids
,
]
expanded_selected_logprobs
=
selected_logprobs
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
vocab_size
)
sampled_token_ids_ranks
=
(
logprob_tensor
>
=
expanded_selected_logprobs
).
sum
(
-
1
)
sampled_token_ids_ranks
=
(
logprob_tensor
>
expanded_selected_logprobs
).
sum
(
-
1
)
.
add_
(
1
)
return
sampled_token_ids_ranks
,
selected_logprobs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment