Unverified Commit cc0eaf12 authored by Travis Johnson's avatar Travis Johnson Committed by GitHub
Browse files

[Bugfix] spec decode handle None entries in topk args in create_sequence_group_output (#7232)


Signed-off-by: default avatarTravis Johnson <tsjohnso@us.ibm.com>
parent 955b5191
......@@ -343,3 +343,78 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
b=baseline_rank_to_logprob[rank],
abs_tol=1e-1,
)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"max_logprobs": 6,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
sampling_params = SamplingParams(
# Use smaller output len for fast test
max_tokens=7,
ignore_eos=True,
temperature=0.0,
logprobs=2,
)
spec_batch_logprobs = get_logprobs_from_llm_generator(
test_llm_generator, prompts, sampling_params)
baseline_batch_logprobs = get_logprobs_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
assert len(baseline_batch_logprobs) == len(prompts)
assert len(spec_batch_logprobs) == len(prompts)
# For each sequence in the batch.
for _, (baseline_logprobs, spec_logprobs) in enumerate(
zip(baseline_batch_logprobs, spec_batch_logprobs)):
assert len(spec_logprobs) == len(baseline_logprobs)
# For each generated position of the sequence.
for _, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):
assert len(spec_pos_logprobs) == 1
spec_top_token_id = list(spec_pos_logprobs)[0]
spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
assert spec_top_logprob.logprob == 0.0
assert spec_top_logprob.rank == -1
# check that the chosen token matches the base model
baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
assert baseline_logprob.rank == 1
assert spec_top_logprob.decoded_token \
== baseline_logprob.decoded_token
......@@ -64,23 +64,25 @@ def create_sequence_group_output(
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[int]): The list of top-k token ids.
topk_logprobs (List[float]): The list of top-k logprobs.
topk_token_ids (List[Optional[int]]): The list of top-k token ids.
topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[Optional[int], Logprob] = {
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_ids[topk_logprob_index]: Logprob(
logprob=topk_logprobs[topk_logprob_index],
rank=topk_logprob_index + 1,
topk_token_id: Logprob(
logprob=topk_logprob if topk_logprob is not None else 0.0,
rank=topk_index + 1,
)
for topk_logprob_index, _ in enumerate(topk_token_ids)
for topk_index, (topk_token_id, topk_logprob) \
in enumerate(zip(topk_token_ids, topk_logprobs)) \
if topk_token_id is not None
})
return CompletionSequenceGroupOutput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment