Unverified commit d2c919dc, authored by realliujiaxu, committed by GitHub
Browse files

[bugfix] fix bug when top_logprobs=0 with spec decoding (#30059)


Signed-off-by: realliujiaxu <realliujiaxu@163.com>
parent f3237f3f
...@@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): ...@@ -528,9 +528,11 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
), ),
], ],
) )
@pytest.mark.parametrize("top_logprobs", [0, 3])
def test_spec_decode_logprobs( def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode, logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str], model_setup: tuple[str, str, str],
top_logprobs: int,
): ):
"""Spec decode logprobs should match those of the base model. """Spec decode logprobs should match those of the base model.
...@@ -543,7 +545,7 @@ def test_spec_decode_logprobs( ...@@ -543,7 +545,7 @@ def test_spec_decode_logprobs(
prompt = "Hello world " * 50 prompt = "Hello world " * 50
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
) )
method, model_name, spec_model_name = model_setup method, model_name, spec_model_name = model_setup
max_model_len = 256 max_model_len = 256
......
...@@ -111,7 +111,7 @@ def create_sampling_metadata( ...@@ -111,7 +111,7 @@ def create_sampling_metadata(
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
generators=generators, generators=generators,
max_num_logprobs=0, max_num_logprobs=None,
no_penalties=no_penalties, no_penalties=no_penalties,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
frequency_penalties=frequency_penalties, frequency_penalties=frequency_penalties,
......
...@@ -145,7 +145,7 @@ class RejectionSampler(nn.Module): ...@@ -145,7 +145,7 @@ class RejectionSampler(nn.Module):
) )
logprobs_tensors = None logprobs_tensors = None
if sampling_metadata.max_num_logprobs: if sampling_metadata.max_num_logprobs is not None:
logprobs_tensors = self._get_logprobs_tensors( logprobs_tensors = self._get_logprobs_tensors(
sampling_metadata.max_num_logprobs, sampling_metadata.max_num_logprobs,
metadata, metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment