"vscode:/vscode.git/clone" did not exist on "bc92107b03a3a87c6b41d98ac9581749f0800073"
Commit 8f8f96a6 (unverified), authored by Lianmin Zheng, committed by GitHub

Fix the perf regression due to additional_stop_token_ids (#1773)

parent 05b3bf5e
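Reading from the hunks below: request-level stop_token_ids used to default to an empty set, and the tokenizer's additional_stop_token_ids were merged into it for every request, so every decoded token paid for set lookups even when no stop tokens were configured. After this commit, stop_token_ids stays None when unset, the <|eom_id|> handling lives on the tokenizer, and the finish check guards each lookup with a truthiness test. A minimal standalone sketch of the hot-path difference (token ids are made up):

# Standalone sketch, not part of the commit: the old unconditional set lookup
# versus the new guarded one.
old_stop_token_ids = set()    # old default: always a set, often empty
new_stop_token_ids = None     # new default: None when nothing is configured

last_token_id = 128009

# Old per-token check: membership test runs even when the set is empty.
matched_old = last_token_id in old_stop_token_ids

# New per-token check: a cheap truthiness test skips the lookup entirely.
matched_new = False
if new_stop_token_ids:
    matched_new = last_token_id in new_stop_token_ids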
@@ -164,7 +164,7 @@ def get_tokenizer(
             "slowdown. Consider using a fast tokenizer instead."
         )
-    handle_additional_stop_token_ids(tokenizer)
+    attach_additional_stop_token_ids(tokenizer)
     return tokenizer
@@ -184,11 +184,11 @@ def get_processor(
         **kwargs,
     )
-    handle_additional_stop_token_ids(processor.tokenizer)
+    attach_additional_stop_token_ids(processor.tokenizer)
     return processor


-def handle_additional_stop_token_ids(tokenizer):
+def attach_additional_stop_token_ids(tokenizer):
     # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
     if "<|eom_id|>" in tokenizer.get_added_vocab():
         tokenizer.additional_stop_token_ids = set(
...
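The helper body is truncated above after "set(". For reference, a hedged reconstruction of the attach pattern; the exact expression inside set(...) and the else branch are assumptions, inferred from how additional_stop_token_ids is consumed later in this diff (a set when <|eom_id|> exists, falsy otherwise):

# Hedged reconstruction; the real body is cut off above, so the lookup of the
# <|eom_id|> token id is an assumption.
def attach_additional_stop_token_ids(tokenizer):
    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
    if "<|eom_id|>" in tokenizer.get_added_vocab():
        tokenizer.additional_stop_token_ids = set(
            [tokenizer.get_added_vocab()["<|eom_id|>"]]
        )
    else:
        # Falsy default so callers can guard with a plain truthiness check.
        tokenizer.additional_stop_token_ids = None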
@@ -42,11 +42,11 @@ class Sampler(nn.Module):
         logits = logits.contiguous()

         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
-            exit(1) if crash_on_warning else None
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
+            exit(1) if crash_on_warning else None

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
...
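The only change in this hunk is ordering: the exit moves below the logging and NaN patching, so the warning is actually emitted before the process exits. A standalone sketch of the resulting control flow (crash_on_warning appears to be defined elsewhere in the real file; here it is a parameter):

import logging

import torch

logger = logging.getLogger(__name__)

def sanitize_logits(logits: torch.Tensor, crash_on_warning: bool = False) -> torch.Tensor:
    # Log and patch NaNs first, then optionally crash, mirroring the new order.
    if torch.any(torch.isnan(logits)):
        logger.warning("Detected errors during sampling! NaN in the logits.")
        logits = torch.where(
            torch.isnan(logits), torch.full_like(logits, -1e5), logits
        )
        if crash_on_warning:
            exit(1)
    return logits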
@@ -334,15 +334,20 @@ class Req:
         last_token_id = self.output_ids[-1]

-        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        matched_eos = False
+
+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids

         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids

         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
...
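Pulled out as a standalone function, the rewritten stop-token check looks roughly like this (a sketch; sampling_params and tokenizer stand in for the attributes used above):

# Sketch of the new matching logic from Req above; both containers may be
# None now, so each lookup is guarded by a truthiness test.
def matches_stop_token(last_token_id, sampling_params, tokenizer) -> bool:
    matched_eos = False

    # Check stop token ids (skipped entirely when the request set none)
    if sampling_params.stop_token_ids:
        matched_eos = last_token_id in sampling_params.stop_token_ids

    if tokenizer is not None:
        matched_eos |= last_token_id == tokenizer.eos_token_id
        if tokenizer.additional_stop_token_ids:
            matched_eos |= last_token_id in tokenizer.additional_stop_token_ids

    return matched_eos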
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
             sequences=[
                 torch.tensor(
-                    data=list(
-                        req.sampling_params.stop_token_ids
-                        | {req.tokenizer.eos_token_id}
+                    data=(
+                        list(
+                            (req.sampling_params.stop_token_ids or set())
+                            | (req.tokenizer.additional_stop_token_ids or set())
+                            | {req.tokenizer.eos_token_id}
+                        )
                     ),
                     dtype=torch.int64,
                     device=self.orchestrator.device,
...
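The "or set()" guards exist because both stop_token_ids and additional_stop_token_ids may now be None, and a set union with None raises TypeError. A quick standalone check (token ids are made up):

# Why the "or set()" normalization is needed once the fields become Optional.
stop_token_ids = None                 # request configured no stop tokens
additional_stop_token_ids = {128008}  # e.g. <|eom_id|> attached on the tokenizer
eos_token_id = 128009

merged = (
    (stop_token_ids or set())
    | (additional_stop_token_ids or set())
    | {eos_token_id}
)
assert merged == {128008, 128009}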
@@ -50,10 +50,10 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
-        if stop_token_ids is None:
-            self.stop_token_ids = set()
-        else:
+        if stop_token_ids:
             self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
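The constructor change is the core of the fix: a falsy stop_token_ids (None or an empty container) now stays None instead of becoming an empty set, which is what lets the per-token checks above short-circuit. As a standalone sketch:

# Standalone sketch of the new default handling in SamplingParams.__init__.
def normalize_stop_token_ids(stop_token_ids):
    # Falsy input maps to None rather than set(), so downstream code can
    # skip membership tests with a cheap truthiness check.
    return set(stop_token_ids) if stop_token_ids else None

assert normalize_stop_token_ids(None) is None
assert normalize_stop_token_ids([]) is None
assert normalize_stop_token_ids([128009]) == {128009}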
@@ -134,10 +134,6 @@ class SamplingParams:
                 stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len

-        # Process stop token ids
-        if tokenizer and tokenizer.additional_stop_token_ids:
-            self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
-
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
...
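With the removal above, tokenizer-level stop ids are no longer copied into every request's stop_token_ids; consumers read tokenizer.additional_stop_token_ids directly (see the Req and penalizer hunks), so the merge happens once at tokenizer load time rather than once per request.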