Unverified commit 5f7bb584, authored by jiqing-feng and committed by GitHub
Browse files

Fix typical acceptance sampler with correct recovered token ids (#8562)

parent b05f5c92
...@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str): ...@@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str):
# Next only keep the first 2 draft tokens same as the zero temperature # Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the # tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the # response we will expect the first 2 tokens to be the same as the
# draft tokens and the rest as -1 # draft tokens and the recovered token and rest as -1
draft_token_ids_to_replace = get_draft_token_ids( draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids) batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat( draft_token_ids = torch.cat(
...@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str): ...@@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str):
assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1) assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
assert torch.all(output_token_ids[:, -3:] == -1) assert torch.all(output_token_ids[:, -3:] == -1)
...@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str): ...@@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
@pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_replacement_token_ids(seed: int, device: str): def test_get_recovered_token_ids(seed: int, device: str):
""" """
Test the TypicalAcceptanceSampler's method for generating Test the TypicalAcceptanceSampler's method for generating
replacement token IDs. replacement token IDs.
This test verifies that the `_replacement_token_ids` method of the This test verifies that the `_get_recovered_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used TypicalAcceptanceSampler correctly identifies the token IDs to be used
as replacements based on the target probability distribution. as recovered token IDs based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch. tokens with the highest probability for each sequence in the batch.
""" """
...@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str): ...@@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str):
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device) typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones( expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
(batch_size, k), dtype=torch.long)
expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
dim=1)
actual_replacement_tokens = ( actual_replacement_tokens = (
typical_acceptance_sampler._replacement_token_ids(target_probs)) typical_acceptance_sampler._get_recovered_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens) assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
...@@ -80,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -80,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
target_probs = target_with_bonus_probs[:, :-1] target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs, accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids) draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs) recovered_token_ids = self._get_recovered_token_ids(target_probs)
output_token_ids = self._create_output(accepted, recovered_token_ids, output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids, draft_token_ids,
bonus_token_ids) bonus_token_ids)
...@@ -148,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -148,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
accepted_mask = candidates_prob > threshold accepted_mask = candidates_prob > threshold
return accepted_mask return accepted_mask
def _replacement_token_ids(self, target_probs): def _get_recovered_token_ids(self, target_probs):
""" """
Generate one replacement token ID for each sequence based on target The recovered token ids will fill the first unmatched token
probabilities. The replacement token is used as the fallback option by the target token.
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the token IDs to be replaced by selecting the
token with the highest probability for each sequence in the first
position. The rest of the output is filled with -1.
Parameters Parameters
---------- ----------
...@@ -168,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -168,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Returns Returns
------- -------
torch.Tensor torch.Tensor
A tensor of shape (batch_size, k) with the replacement A tensor of shape (batch_size, k) with the recovered token
token IDs. Only the first column is set, and the rest of the ids which are selected from target probs.
columns are filled with -1.
""" """
max_indices = torch.argmax(target_probs[:, 0, :], dim=1) max_indices = torch.argmax(target_probs, dim=-1)
output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
dtype=self.token_id_dtype, return max_indices
device=target_probs.device)
output[:, 0] = max_indices
return output
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment