Unverified commit e8671783, authored by Woosuk Kwon, committed by GitHub

Incrementally decode output tokens (#121)

parent aedba6d5
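Before this commit, the server re-decoded each running sequence's full list of output token ids on every step (tokenizer.batch_decode in _decode_sequences); after it, each Sequence caches its decoded token strings and only the newest id is converted per step. A minimal sketch of the two paths against a Hugging Face tokenizer; the model name and sample text are illustrative and not taken from the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # illustrative model choice
token_ids = tokenizer.encode("Incremental detokenization in action")

# Old path: every scheduler step re-decodes the sequence's full id prefix.
for step in range(1, len(token_ids) + 1):
    text_old = tokenizer.decode(token_ids[:step], skip_special_tokens=True)

# New path: cache token strings; only the newest id goes through id-to-token
# conversion, mirroring Sequence.output_tokens plus detokenize_incrementally.
cached_tokens = []
for token_id in token_ids:
    cached_tokens.append(tokenizer.convert_ids_to_tokens(token_id))
    text_new = tokenizer.convert_tokens_to_string(cached_tokens)

# For ordinary tokenizers both paths end at the same final text.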
@@ -291,7 +291,7 @@ class Scheduler:
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             # Append a new token to the sequence.
             output = seq_outputs[seq.seq_id]
-            seq.append_token(output.output_token, output.logprobs)
+            seq.append_token_id(output.output_token, output.logprobs)
         return self.running.copy()
 
     def free_seq(self, seq: Sequence) -> None:
...
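Here output.output_token is the sampled token id and output.logprobs maps each candidate id returned by the sampler to its log-probability; append_token_id (renamed from append_token in this commit) asserts that the sampled id is among them. A tiny self-contained illustration of that contract, with invented values:

sampled_id = 995                                   # hypothetical sampled token id
logprobs = {995: -0.7, 1917: -1.9, 50256: -4.2}    # hypothetical candidate id -> logprob
assert sampled_id in logprobs                      # the invariant append_token_id checks
chosen_logprob = logprobs[sampled_id]              # fed into SequenceData.cumulative_logprob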
@@ -24,7 +24,7 @@ class SequenceData:
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0
 
-    def append_token(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
         self.cumulative_logprob += logprob
...
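cumulative_logprob adds up the sampled token's log-probability at every step, so it stays equal to the log of the product of the per-token probabilities. A quick numeric illustration with invented values:

import math

step_logprobs = [-0.5, -1.2, -0.3]             # invented per-token log-probabilities
cumulative_logprob = sum(step_logprobs)        # -2.0, what SequenceData accumulates step by step
sequence_prob = math.exp(cumulative_logprob)   # ~0.135, the product of the per-token probabilities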
@@ -64,6 +64,7 @@ class Sequence:
         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_tokens: List[str] = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -92,11 +93,15 @@ class Sequence:
             last_block.append_tokens(token_ids[:num_empty_slots])
             token_ids = token_ids[num_empty_slots:]
 
-    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token_id(
+        self,
+        token_id: int,
+        logprobs: Dict[int, float],
+    ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token(token_id, logprobs[token_id])
+        self.data.append_token_id(token_id, logprobs[token_id])
 
     def get_len(self) -> int:
         return self.data.get_len()
...
@@ -14,7 +14,8 @@ from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.server.tokenizer_utils import (get_tokenizer,
+                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -185,18 +186,17 @@
         return request_outputs
 
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Batch-decode the sequence outputs.
-        seqs: List[Sequence] = []
+        # Decode the sequence outputs.
         for seq_group in seq_groups:
-            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
-        output_tokens_per_seq = []
-        for seq in seqs:
-            output_tokens_per_seq.append(seq.get_output_token_ids())
-        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
-                                                   skip_special_tokens=True)
-        # Update the sequences with the output texts.
-        for seq, output_text in zip(seqs, output_texts):
-            seq.output_text = output_text
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                new_token, new_output_text = detokenize_incrementally(
+                    self.tokenizer,
+                    seq.output_tokens,
+                    seq.get_last_token_id(),
+                    skip_special_tokens=True,
+                )
+                seq.output_tokens.append(new_token)
+                seq.output_text = new_output_text
 
     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
...
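Note the division of labor: detokenize_incrementally leaves the cached token list untouched and returns the new token string, so the server appends it and overwrites output_text each step. A standalone sketch of the same per-step loop outside the server, assuming this commit's cacheflow package and transformers are installed; the model name and text are only examples:

from cacheflow.server.tokenizer_utils import detokenize_incrementally, get_tokenizer

tokenizer = get_tokenizer("gpt2")      # example model, not part of the commit
output_tokens = []                     # per-sequence cache, as in Sequence.output_tokens
output_text = ""
for new_id in tokenizer.encode(" incremental decoding"):   # stand-in for sampled ids
    new_token, output_text = detokenize_incrementally(
        tokenizer, output_tokens, new_id, skip_special_tokens=True)
    output_tokens.append(new_token)    # the caller, not the helper, extends the cache
print(output_text)                     # prints the accumulated decoded text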
-from typing import Union
+from typing import List, Tuple, Union
 
 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
 
+from cacheflow.logger import init_logger
+
+logger = init_logger(__name__)
+
 _MODEL_TYPES_WITH_SLOW_TOKENIZER = [
     # LLaMA fast tokenizer has a bug related to protobuf.
     # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
@@ -17,5 +21,62 @@ def get_tokenizer(
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        if getattr(kwargs, "use_fast", False) == True:
+            raise ValueError(
+                f"Cannot use the fast tokenizer for {config.model_type} due to "
+                "bugs in the fast tokenizer.")
+        logger.info(
+            f"Using the slow tokenizer for {config.model_type} due to bugs in "
+            "the fast tokenizer. This could potentially lead to performance "
+            "degradation.")
         kwargs["use_fast"] = False
     return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+
+
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prev_output_tokens: List[str],
+    new_token_id: int,
+    skip_special_tokens: bool,
+) -> Tuple[str, str]:
+    """Detokenizes the new token in conjunction with the previous output tokens.
+
+    NOTE: This function does not update prev_output_tokens.
+
+    Returns:
+        new_token: The new token as a string.
+        output_text: The new output text as a string.
+    """
+    new_token = tokenizer.convert_ids_to_tokens(
+        new_token_id, skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_output_tokens + [new_token]
+
+    # Convert the tokens to a string.
+    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
+    # then we can directly use `convert_tokens_to_string`.
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
+        output_text = tokenizer.convert_tokens_to_string(output_tokens)
+        return new_token, output_text
+
+    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts = []
+    current_sub_text = []
+    for token in output_tokens:
+        if skip_special_tokens and token in tokenizer.all_special_ids:
+            continue
+        if token in tokenizer.added_tokens_encoder:
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    output_text = " ".join(sub_texts)
+    return new_token, output_text
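One caveat in the get_tokenizer guard added in this hunk: kwargs is a plain dict collected through **kwargs, and getattr on a dict ignores its keys and just returns the default, so the use_fast check can never fire as written. A small sketch of the behavior, with the dict-style lookup that expresses the intended test (not part of this commit):

kwargs = {"use_fast": True}                  # what get_tokenizer(model, use_fast=True) collects
print(getattr(kwargs, "use_fast", False))    # False: dict keys are not attributes, guard never trips
print(kwargs.get("use_fast", False))         # True: the lookup the guard intends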