Unverified Commit 81661da7 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix `min_tokens` when `eos_token_id` is None (#4389)


Co-authored-by: default avatarDefTruth <31974251+deftruth@users.noreply.github.com>
parent dfea1731
......@@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens,
eos_token_id=0,
*,
stop_token_ids: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
......@@ -216,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs=prompt_logprobs,
)
sampling_params.eos_token_id = eos_token_id
sampling_params.all_stop_token_ids.add(eos_token_id)
return sampling_params
def create_sequence_data(num_input=3, num_generated=0):
......@@ -461,10 +461,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_row)):
tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids:
tokens_to_check.extend(sampling_params.stop_token_ids)
tokens_to_check = set(tokens_to_check)
tokens_to_check = sampling_params.all_stop_token_ids
if should_penalize:
for token_id in tokens_to_check:
......
......@@ -431,9 +431,10 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens
# Add the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config(
self.generation_config_fields)
......
......@@ -169,19 +169,17 @@ def _apply_min_tokens_penalty(
start_idx = sample_indices[0]
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)
seqs_to_penalize.append(j)
if seqs_to_penalize:
# convert to the index into logits
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
# use set() to remove any duplicates
token_ids_to_penalize = set(sampling_params.stop_token_ids +
[sampling_params.eos_token_id])
seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
......
......@@ -185,8 +185,8 @@ class SamplingParams:
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# injected by the engine
self.eos_token_id = None
# eos_token_id is added to this by the engine
self.all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None:
if self.n < 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment