Unverified Commit d2c919dc authored by realliujiaxu's avatar realliujiaxu Committed by GitHub
Browse files

[bugfix] fix bug when top_logprobs=0 with spec decoding (#30059)


Signed-off-by: default avatarrealliujiaxu <realliujiaxu@163.com>
parent f3237f3f
......@@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
),
],
)
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
top_logprobs: int,
):
"""Spec decode logprobs should match those of the base model.
......@@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
prompt = "Hello world " * 50
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256
......
......@@ -111,7 +111,7 @@ def create_sampling_metadata(
top_p=top_p,
top_k=top_k,
generators=generators,
max_num_logprobs=0,
max_num_logprobs=None,
no_penalties=no_penalties,
prompt_token_ids=prompt_token_ids,
frequency_penalties=frequency_penalties,
......
......@@ -145,7 +145,7 @@ class RejectionSampler(nn.Module):
)
logprobs_tensors = None
if sampling_metadata.max_num_logprobs:
if sampling_metadata.max_num_logprobs is not None:
logprobs_tensors = self._get_logprobs_tensors(
sampling_metadata.max_num_logprobs,
metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment