Unverified commit 20f7cc4c, authored by Dan Lord, committed by GitHub

Add `skip_special_tokens` sampling params (#1186)

parent 649aa730
@@ -387,7 +387,7 @@ class LLMEngine:
                 child_seqs.append((parent, parent))

         for seq, _ in child_seqs:
-            self._decode_sequence(seq)
+            self._decode_sequence(seq, seq_group.sampling_params)
             self._check_stop(seq, seq_group.sampling_params)
         # Non-beam search case
@@ -621,7 +621,8 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now

-    def _decode_sequence(self, seq: Sequence) -> None:
+    def _decode_sequence(self, seq: Sequence,
+                         sampling_params: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
@@ -630,7 +631,7 @@ class LLMEngine:
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
-            skip_special_tokens=True,
+            skip_special_tokens=sampling_params.skip_special_tokens,
         )
         if seq.tokens is None:
             seq.tokens = new_tokens
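The engine change threads the per-request flag down to the incremental detokenizer instead of hard-coding `skip_special_tokens=True`. A minimal sketch of what the flag controls, using a Hugging Face tokenizer directly (the model name and token ids here are illustrative, not from this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Encode a prompt and append the EOS token id to mimic a finished generation.
token_ids = tokenizer.encode("Hello world") + [tokenizer.eos_token_id]

# Behavior before this commit: special tokens are always dropped.
print(tokenizer.decode(token_ids, skip_special_tokens=True))
# -> "Hello world"

# With skip_special_tokens=False, the EOS marker survives in the output text.
print(tokenizer.decode(token_ids, skip_special_tokens=False))
# -> "Hello world<|endoftext|>"
```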
@@ -225,6 +225,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
             top_k=request.top_k,
             ignore_eos=request.ignore_eos,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
@@ -426,6 +427,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             max_tokens=request.max_tokens,
             logprobs=request.logprobs,
             use_beam_search=request.use_beam_search,
+            skip_special_tokens=request.skip_special_tokens,
         )
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
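With both OpenAI-compatible endpoints forwarding the field, clients can opt out of stripping special tokens per request. A hedged usage sketch (assumes a vLLM API server already running on localhost:8000; the model name is a placeholder):

```python
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        # New in this commit: keep special tokens (e.g. EOS) in the output text.
        "skip_special_tokens": False,
    },
)
print(response.json()["choices"][0]["text"])
```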
@@ -71,6 +71,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class CompletionRequest(BaseModel):
@@ -96,6 +97,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: Optional[bool] = False
     use_beam_search: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    skip_special_tokens: Optional[bool] = True


 class LogProbs(BaseModel):
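Because the new pydantic field defaults to `True`, existing clients that omit it keep the old behavior. A minimal sketch of the default semantics with a stand-in model (not the full `CompletionRequest`, which has many more fields):

```python
from typing import Optional

from pydantic import BaseModel


class MiniRequest(BaseModel):
    # Mirrors the field added to ChatCompletionRequest / CompletionRequest.
    skip_special_tokens: Optional[bool] = True


print(MiniRequest().skip_special_tokens)  # True: old behavior preserved
print(MiniRequest(skip_special_tokens=False).skip_special_tokens)  # False: opt out per request
```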
@@ -60,6 +60,8 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
+        skip_special_tokens: Whether to skip special tokens in the output.
+            Defaults to true.
     """

     def __init__(
@@ -79,6 +81,7 @@ class SamplingParams:
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -103,6 +106,7 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.skip_special_tokens = skip_special_tokens

         self._verify_args()
         if self.use_beam_search:
@@ -196,4 +200,5 @@ class SamplingParams:
                 f"stop={self.stop}, "
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
-                f"logprobs={self.logprobs})")
+                f"logprobs={self.logprobs}, "
+                f"skip_special_tokens={self.skip_special_tokens})")