Unverified commit b3da9427, authored by Tanmay Verma, committed by GitHub
Browse files

fix: Incrementally decode token to reduce the overhead from Processor (#1129)

parent 93702e44
......@@ -511,6 +511,13 @@ class BaseTensorrtLLMEngine:
if self._remote_prefill and self._server_type == ServerType.GEN:
ctx_response_obj = await self._get_remote_prefill_response(request)
yield TRTLLMWorkerResponse(
request_id=request.id,
prompt_token_ids=ctx_response_obj.prompt_token_ids,
outputs=[asdict(ctx_response_obj.outputs[0])],
finished=ctx_response_obj.finished,
).model_dump_json(exclude_unset=True)
worker_inputs = ctx_response_obj.prompt_token_ids
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
......
......@@ -31,6 +31,7 @@ from common.protocol import (
from common.utils import ConversationMessage
from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, tokenizer_factory
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
......@@ -41,9 +42,6 @@ from tensorrt_llm.serve.openai_protocol import (
ToolCall,
UsageInfo,
)
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
logger = logging.getLogger(__name__)
......@@ -57,22 +55,7 @@ class ChatProcessorMixin:
# model name for chat processor
self._model_name = self._engine_config.model_name
logger.info(f"Set model name: {self._model_name}")
# model for LLMAPI input
self._model = self._model_name
if self._engine_config.model_path:
self._model = self._engine_config.model_path
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_path
)
logger.info(f"Using model from path: {self._engine_config.model_path}")
else:
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_name
)
if self._engine_config.extra_args.get("tokenizer", None):
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.extra_args.get("tokenizer", None)
)
self._tokenizer = tokenizer_factory(self._model_name)
self.chat_processor = ChatProcessor(
self._model_name, self._tokenizer, using_engine_generator
)
......@@ -109,7 +92,7 @@ class BaseChatProcessor:
def __init__(
self,
model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: TokenizerBase,
):
self.model = model
self.tokenizer = tokenizer
......@@ -163,7 +146,7 @@ class ChatProcessor(BaseChatProcessor):
def __init__(
self,
model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: TokenizerBase,
using_engine_generator: bool = False,
):
super().__init__(model, tokenizer)
......@@ -268,9 +251,6 @@ class ChatProcessor(BaseChatProcessor):
choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True
if output.disaggregated_params is not None:
# Block the disaggregated params at processor level
pass
chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id,
......@@ -310,7 +290,7 @@ class ChatProcessor(BaseChatProcessor):
)
prompt = self.tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
tokenize=True,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
......@@ -318,16 +298,17 @@ class ChatProcessor(BaseChatProcessor):
**(request.chat_template_kwargs or {}),
)
sampling_params = request.to_sampling_params()
sampling_params._setup(self.tokenizer)
sampling_params.stop = None
return TRTLLMWorkerRequest(
id=request.id,
model=request.model,
prompt=prompt,
sampling_params=asdict(sampling_params),
streaming=request.stream,
conversation=conversation,
disaggregated_params=request.disaggregated_params,
# NOTE: don't include the first token (e.g. <s>) when searching for a prefix match. We might want to exclude all special tokens at some point.
tokens=Tokens(tokens=self.tokenizer.encode(prompt)[1:]),
tokens=Tokens(tokens=prompt),
)
async def postprocess(
......@@ -337,8 +318,6 @@ class ChatProcessor(BaseChatProcessor):
conversation,
):
first_iteration = True
last_text_len = 0
last_token_ids_len = 0
async for raw_response in engine_generator:
if self.using_engine_generator:
response = TRTLLMWorkerResponse(
......@@ -351,17 +330,10 @@ class ChatProcessor(BaseChatProcessor):
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
else:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
last_token_ids_len = response.outputs[0]["_last_token_ids_len"]
response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"]
response.outputs[0]["token_ids"][last_token_ids_len:]
)
# Need to keep track of the last text and token ids length
# to calculate the diff.
# TODO: This is a hack to get the diff. We should identify why
# the diff is not being calculated in the worker.
response.outputs[0]["_last_text_len"] = last_text_len
response.outputs[0]["_last_token_ids_len"] = last_token_ids_len
last_text_len = len(response.outputs[0]["text"])
last_token_ids_len = len(response.outputs[0]["token_ids"])
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_chat_stream_response(
......@@ -380,7 +352,7 @@ class CompletionsProcessor:
def __init__(
self,
model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: TokenizerBase,
):
self.model = model
self.tokenizer = tokenizer
......@@ -391,20 +363,15 @@ class CompletionsProcessor:
# len(response.outputs) is always 1
for gen_idx, output in enumerate(response.outputs):
delta_text = output.text_diff
text = output.text
if request.echo and not echoed[gen_idx]:
delta_text = request.prompt + delta_text
echoed[gen_idx] = True
text = request.prompt + text
choice = DynamoTRTLLMCompletionResponseStreamChoice(
index=gen_idx,
text=delta_text,
text=text,
index=output.index,
stop_reason=output.stop_reason,
finish_reason=output.finish_reason,
)
if output.disaggregated_params is not None:
# Block the disagg_params
pass
chunk = DynamoTRTLLMCompletionStreamResponse(
model=self.model,
choices=[choice],
......@@ -423,14 +390,16 @@ class CompletionsProcessor:
)
sampling_params = request.to_sampling_params()
sampling_params._setup(self.tokenizer)
sampling_params.stop = None
return TRTLLMWorkerRequest(
id=request.id,
model=request.model,
prompt=prompt,
streaming=request.stream,
sampling_params=asdict(sampling_params),
disaggregated_params=request.disaggregated_params,
tokens=Tokens(tokens=self.tokenizer.encode(prompt)[1:]),
tokens=Tokens(tokens=self.tokenizer.encode(prompt)),
)
async def postprocess(
......@@ -440,8 +409,12 @@ class CompletionsProcessor:
):
async for raw_response in engine_generator:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
last_token_ids_len = response.outputs[0]["_last_token_ids_len"]
response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"][last_token_ids_len:]
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_completion_stream_response(
request,
response,
......
......@@ -51,6 +51,7 @@ class LLMAPIConfig:
data = {
"pytorch_backend_config": self.pytorch_backend_config,
"kv_cache_config": self.kv_cache_config,
"skip_tokenizer_init": self.skip_tokenizer_init,
}
if self.extra_args:
data.update(self.extra_args)
......
......@@ -17,7 +17,7 @@ import base64
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, TypeAlias, Union
import torch
from common.utils import ConversationMessage
......@@ -70,31 +70,38 @@ class TRTLLMWorkerRequest(BaseModel):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
@dataclass(slots=True)
class Logprob:
"""Holds logprob and vocab rank for a token."""
logprob: float
rank: Optional[int] = None
# List of token_id_to_Logprob dict for prompt or generation texts
TokenLogprobs: TypeAlias = list[dict[int, Logprob]]
@dataclass
class TRTLLMWorkerResponseOutput:
index: int
text: str
token_ids: list[int]
logprobs: Optional[List[float]] = None
prompt_logprobs: Optional[List[float]] = None
text: str = ""
token_ids: Optional[List[int]] = field(default_factory=list)
cumulative_logprob: Optional[float] = None
logprobs: Optional[TokenLogprobs] = field(default_factory=list)
prompt_logprobs: Optional[TokenLogprobs] = field(default_factory=list)
finish_reason: Optional[Literal["stop", "length", "timeout", "cancelled"]] = None
stop_reason: Optional[Union[int, str]] = None
generation_logits: Optional[torch.Tensor] = None
disaggregated_params: Optional[DisaggregatedParams] = None
_last_text_len: int = field(default=0)
_last_token_ids_len: int = field(default=0)
_last_logprobs_len: int = field(default=0)
_incremental_states: Optional[dict] = field(default=None)
_postprocess_result: Optional[Any] = field(default=None)
text_diff: str = field(default="")
length: int = field(default=0)
def __post_init__(self):
self.text_diff = self.text[self._last_text_len :]
self.length = len(self.token_ids)
# hidden fields for tracking the diffs
_last_text_len: int = field(default=0, init=True, repr=False)
_last_token_ids_len: int = field(default=0, init=True, repr=False)
_last_logprobs_len: int = field(default=0, init=True, repr=False)
_incremental_states: Optional[dict] = field(default=None, init=True, repr=False)
# the result of result_handler passed to postprocess workers
_postprocess_result: Any = None
class TRTLLMWorkerResponse(BaseModel):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment