Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f7bb584
Unverified
Commit
5f7bb584
authored
Sep 24, 2024
by
jiqing-feng
Committed by
GitHub
Sep 23, 2024
Browse files
Fix typical acceptance sampler with correct recovered token ids (#8562)
parent
b05f5c92
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
28 deletions
+17
-28
tests/samplers/test_typical_acceptance_sampler.py
tests/samplers/test_typical_acceptance_sampler.py
+8
-9
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+9
-19
No files found.
tests/samplers/test_typical_acceptance_sampler.py
View file @
5f7bb584
...
@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
...
@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
# Next only keep the first 2 draft tokens same as the zero temperature
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
# response we will expect the first 2 tokens to be the same as the
# draft tokens and the rest as -1
# draft tokens and the recovered token and rest as -1
draft_token_ids_to_replace
=
get_draft_token_ids
(
draft_token_ids_to_replace
=
get_draft_token_ids
(
batch_size
,
k
,
vocab_size
,
zero_temperature_token_ids
)
batch_size
,
k
,
vocab_size
,
zero_temperature_token_ids
)
draft_token_ids
=
torch
.
cat
(
draft_token_ids
=
torch
.
cat
(
...
@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
...
@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
0
]
==
batch_size
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
output_token_ids
.
shape
[
1
]
==
(
k
+
1
)
assert
torch
.
all
(
output_token_ids
[:,
:
2
]
==
draft_token_ids
[:,
:
2
])
assert
torch
.
all
(
output_token_ids
[:,
:
2
]
==
draft_token_ids
[:,
:
2
])
assert
torch
.
all
(
output_token_ids
[:,
2
]
==
target_with_bonus_probs
.
argmax
(
-
1
)[:,
2
])
assert
torch
.
all
(
output_token_ids
[:,
-
3
:]
==
-
1
)
assert
torch
.
all
(
output_token_ids
[:,
-
3
:]
==
-
1
)
...
@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
...
@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"seed"
,
list
(
range
(
10
)))
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_
replacement
_token_ids
(
seed
:
int
,
device
:
str
):
def
test_
get_recovered
_token_ids
(
seed
:
int
,
device
:
str
):
"""
"""
Test the TypicalAcceptanceSampler's method for generating
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
replacement token IDs.
This test verifies that the `_replacement_token_ids` method of the
This test verifies that the `_get_recovered_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as replacements based on the target probability distribution.
as recovered token IDs based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the
Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch.
tokens with the highest probability for each sequence in the batch.
"""
"""
...
@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
...
@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
=
get_acceptance_sampler
(
strict_mode
=
True
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
typical_acceptance_sampler
.
init_gpu_tensors
(
device
=
device
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
target_probs
=
torch
.
rand
(
batch_size
,
k
,
vocab_size
,
dtype
=
torch
.
float32
)
expected_replacement_tokens
=
-
torch
.
ones
(
expected_replacement_tokens
=
torch
.
argmax
(
target_probs
,
dim
=-
1
)
(
batch_size
,
k
),
dtype
=
torch
.
long
)
expected_replacement_tokens
[:,
0
]
=
torch
.
argmax
(
target_probs
[:,
0
,
:],
dim
=
1
)
actual_replacement_tokens
=
(
actual_replacement_tokens
=
(
typical_acceptance_sampler
.
_
replacement
_token_ids
(
target_probs
))
typical_acceptance_sampler
.
_
get_recovered
_token_ids
(
target_probs
))
assert
torch
.
all
(
expected_replacement_tokens
==
actual_replacement_tokens
)
assert
torch
.
all
(
expected_replacement_tokens
==
actual_replacement_tokens
)
vllm/model_executor/layers/typical_acceptance_sampler.py
View file @
5f7bb584
...
@@ -80,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -80,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
target_probs
=
target_with_bonus_probs
[:,
:
-
1
]
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
accepted
=
self
.
_evaluate_accepted_tokens
(
target_probs
,
draft_token_ids
)
draft_token_ids
)
recovered_token_ids
=
self
.
_
replacement
_token_ids
(
target_probs
)
recovered_token_ids
=
self
.
_
get_recovered
_token_ids
(
target_probs
)
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
draft_token_ids
,
bonus_token_ids
)
bonus_token_ids
)
...
@@ -148,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -148,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
accepted_mask
=
candidates_prob
>
threshold
accepted_mask
=
candidates_prob
>
threshold
return
accepted_mask
return
accepted_mask
def
_
replacement
_token_ids
(
self
,
target_probs
):
def
_
get_recovered
_token_ids
(
self
,
target_probs
):
"""
"""
Generate one replacement token ID for each sequence based on target
probabilities. The replacement token is used as the fallback option
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the token IDs to be replaced by selecting the
token with the highest probability for each sequence in the first
position. The rest of the output is filled with -1.
The recovered token ids will fill the first unmatched token
by the target token.
Parameters
Parameters
----------
----------
...
@@ -168,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
...
@@ -168,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Returns
Returns
-------
-------
torch.Tensor
torch.Tensor
A tensor of shape (batch_size, k) with the replacement
token IDs. Only the first column is set, and the rest of the
columns are filled with -1.
A tensor of shape (batch_size, k) with the recovered token
ids which are selected from target probs.
"""
"""
max_indices
=
torch
.
argmax
(
target_probs
[:,
0
,
:],
dim
=
1
)
max_indices
=
torch
.
argmax
(
target_probs
,
dim
=-
1
)
output
=
-
torch
.
ones
((
target_probs
.
shape
[
0
],
target_probs
.
shape
[
1
]),
dtype
=
self
.
token_id_dtype
,
return
max_indices
device
=
target_probs
.
device
)
output
[:,
0
]
=
max_indices
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment