Unverified commit 99b4cf5f, authored by Travis Johnson and committed by GitHub
Browse files

[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (#7218)


Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
parent e02ac556
...@@ -19,8 +19,12 @@ With those tests, we can say at least, MLPSpeculator would not break the ...@@ -19,8 +19,12 @@ With those tests, we can say at least, MLPSpeculator would not break the
correctness for the target model outputs. correctness for the target model outputs.
""" """
from unittest.mock import patch
import pytest import pytest
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
from .conftest import (run_equality_correctness_test, from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test) run_greedy_equality_correctness_test)
...@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator, ...@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
force_output_len=True) force_output_len=True)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # Block budget: 2 for the small prompt plus 256 // 8 for generation.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,
        # Eager mode skips cuda graph recording, which keeps the test fast.
        "enforce_eager": True,
        # Spec decode requires the v2 block manager.
        "use_v2_block_manager": True,
        # Precision
        "dtype": PRECISION,
        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{"speculative_model": SPEC_MODEL}])
# A small output length keeps the test fast.
@pytest.mark.parametrize("output_len", [128])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
                                                 test_llm_generator,
                                                 batch_size: int,
                                                 output_len: int):
    """Check greedy equality when the vocab dimension is padded.

    ``pad_vocab_size`` is replaced, for the duration of the comparison,
    by a wrapper that always pads to 32064. The baseline and speculative
    generators are then compared with
    ``run_greedy_equality_correctness_test``.
    """
    patch_target = (
        "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size")

    # Default pad_to is 64 and the test model has a vocab_size of 32000,
    # so the caller-supplied pad_to is deliberately ignored in favour of
    # 32064.
    def _pad_to_32064(vocab_size, pad_to=None):
        return pad_vocab_size(vocab_size, pad_to=32064)

    with patch(patch_target, _pad_to_32064):
        run_greedy_equality_correctness_test(
            baseline_llm_generator,
            test_llm_generator,
            batch_size,
            max_output_len=output_len,
            force_output_len=True,
        )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
......
...@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module): ...@@ -91,7 +91,7 @@ class LogitsProcessor(nn.Module):
logits = tensor_model_parallel_all_gather(logits) logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any). # Remove paddings in vocab (if any).
if logits is not None: if logits is not None:
logits = logits[:, :self.org_vocab_size] logits = logits[..., :self.org_vocab_size]
return logits return logits
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
...@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds # Only perform shape/dtype/device checking in strict mode, as it adds
# overhead. # overhead.
if self._strict_mode: if self._strict_mode:
self._raise_if_incorrect_input(target_probs, bonus_token_ids, self._raise_if_incorrect_input(target_probs, draft_token_ids,
draft_probs, draft_token_ids) bonus_token_ids, draft_probs)
accepted, recovered_token_ids = ( accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling( self._batch_modified_rejection_sampling(
......
...@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module): ...@@ -175,13 +175,14 @@ class MLPSpeculator(nn.Module):
states.add_(z, alpha=self.emb_weight / self.state_weight) states.add_(z, alpha=self.emb_weight / self.state_weight)
states = self.activation(self.ln[head_index](states)) # b k d states = self.activation(self.ln[head_index](states)) # b k d
# TODO: not yet supporting top_k_tokens_per_head
previous_hidden_states = states previous_hidden_states = states
# TODO: not yet supporting top_k_tokens_per_head
states = states.flatten(0, 1)
logits = self.logits_processor(self.head[head_index], states, logits = self.logits_processor(self.head[head_index], states,
sampling_metadata) sampling_metadata)
output = self.sampler(logits.flatten(0, 1), sampling_metadata) output = self.sampler(logits, sampling_metadata)
last_tokens = output.sampled_token_ids last_tokens = output.sampled_token_ids
next_tokens.append(output) next_tokens.append(output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment