Unverified Commit b3da9427 authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

fix: Incrementally decode token to reduce the overhead from Processor (#1129)

parent 93702e44
...@@ -511,6 +511,13 @@ class BaseTensorrtLLMEngine: ...@@ -511,6 +511,13 @@ class BaseTensorrtLLMEngine:
if self._remote_prefill and self._server_type == ServerType.GEN: if self._remote_prefill and self._server_type == ServerType.GEN:
ctx_response_obj = await self._get_remote_prefill_response(request) ctx_response_obj = await self._get_remote_prefill_response(request)
yield TRTLLMWorkerResponse(
request_id=request.id,
prompt_token_ids=ctx_response_obj.prompt_token_ids,
outputs=[asdict(ctx_response_obj.outputs[0])],
finished=ctx_response_obj.finished,
).model_dump_json(exclude_unset=True)
worker_inputs = ctx_response_obj.prompt_token_ids worker_inputs = ctx_response_obj.prompt_token_ids
disaggregated_params = ( disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params( DisaggregatedTypeConverter.to_llm_disaggregated_params(
......
...@@ -31,6 +31,7 @@ from common.protocol import ( ...@@ -31,6 +31,7 @@ from common.protocol import (
from common.utils import ConversationMessage from common.utils import ConversationMessage
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, tokenizer_factory
from tensorrt_llm.serve.openai_protocol import ( from tensorrt_llm.serve.openai_protocol import (
ChatCompletionLogProbs, ChatCompletionLogProbs,
ChatCompletionLogProbsContent, ChatCompletionLogProbsContent,
...@@ -41,9 +42,6 @@ from tensorrt_llm.serve.openai_protocol import ( ...@@ -41,9 +42,6 @@ from tensorrt_llm.serve.openai_protocol import (
ToolCall, ToolCall,
UsageInfo, UsageInfo,
) )
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -57,22 +55,7 @@ class ChatProcessorMixin: ...@@ -57,22 +55,7 @@ class ChatProcessorMixin:
# model name for chat processor # model name for chat processor
self._model_name = self._engine_config.model_name self._model_name = self._engine_config.model_name
logger.info(f"Set model name: {self._model_name}") logger.info(f"Set model name: {self._model_name}")
# model for LLMAPI input self._tokenizer = tokenizer_factory(self._model_name)
self._model = self._model_name
if self._engine_config.model_path:
self._model = self._engine_config.model_path
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_path
)
logger.info(f"Using model from path: {self._engine_config.model_path}")
else:
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.model_name
)
if self._engine_config.extra_args.get("tokenizer", None):
self._tokenizer = AutoTokenizer.from_pretrained(
self._engine_config.extra_args.get("tokenizer", None)
)
self.chat_processor = ChatProcessor( self.chat_processor = ChatProcessor(
self._model_name, self._tokenizer, using_engine_generator self._model_name, self._tokenizer, using_engine_generator
) )
...@@ -109,7 +92,7 @@ class BaseChatProcessor: ...@@ -109,7 +92,7 @@ class BaseChatProcessor:
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: TokenizerBase,
): ):
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -163,7 +146,7 @@ class ChatProcessor(BaseChatProcessor): ...@@ -163,7 +146,7 @@ class ChatProcessor(BaseChatProcessor):
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: TokenizerBase,
using_engine_generator: bool = False, using_engine_generator: bool = False,
): ):
super().__init__(model, tokenizer) super().__init__(model, tokenizer)
...@@ -268,9 +251,6 @@ class ChatProcessor(BaseChatProcessor): ...@@ -268,9 +251,6 @@ class ChatProcessor(BaseChatProcessor):
choice.finish_reason = output.finish_reason choice.finish_reason = output.finish_reason
choice.stop_reason = output.stop_reason choice.stop_reason = output.stop_reason
finish_reason_sent[i] = True finish_reason_sent[i] = True
if output.disaggregated_params is not None:
# Block the disaggregated params at processor level
pass
chunk = DynamoTRTLLMChatCompletionStreamResponse( chunk = DynamoTRTLLMChatCompletionStreamResponse(
id=request_id, id=request_id,
...@@ -310,7 +290,7 @@ class ChatProcessor(BaseChatProcessor): ...@@ -310,7 +290,7 @@ class ChatProcessor(BaseChatProcessor):
) )
prompt = self.tokenizer.apply_chat_template( prompt = self.tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
tokenize=False, tokenize=True,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts, tools=tool_dicts,
documents=request.documents, documents=request.documents,
...@@ -318,16 +298,17 @@ class ChatProcessor(BaseChatProcessor): ...@@ -318,16 +298,17 @@ class ChatProcessor(BaseChatProcessor):
**(request.chat_template_kwargs or {}), **(request.chat_template_kwargs or {}),
) )
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
sampling_params._setup(self.tokenizer)
sampling_params.stop = None
return TRTLLMWorkerRequest( return TRTLLMWorkerRequest(
id=request.id, id=request.id,
model=request.model, model=request.model,
prompt=prompt,
sampling_params=asdict(sampling_params), sampling_params=asdict(sampling_params),
streaming=request.stream,
conversation=conversation, conversation=conversation,
disaggregated_params=request.disaggregated_params, disaggregated_params=request.disaggregated_params,
# NOTE: dont include the first token (e.g. <s>) when searching for a prefix match. We might want to exclude all special tokens at some point. tokens=Tokens(tokens=prompt),
tokens=Tokens(tokens=self.tokenizer.encode(prompt)[1:]),
) )
async def postprocess( async def postprocess(
...@@ -337,8 +318,6 @@ class ChatProcessor(BaseChatProcessor): ...@@ -337,8 +318,6 @@ class ChatProcessor(BaseChatProcessor):
conversation, conversation,
): ):
first_iteration = True first_iteration = True
last_text_len = 0
last_token_ids_len = 0
async for raw_response in engine_generator: async for raw_response in engine_generator:
if self.using_engine_generator: if self.using_engine_generator:
response = TRTLLMWorkerResponse( response = TRTLLMWorkerResponse(
...@@ -351,17 +330,10 @@ class ChatProcessor(BaseChatProcessor): ...@@ -351,17 +330,10 @@ class ChatProcessor(BaseChatProcessor):
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])] response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
else: else:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data()) response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
last_token_ids_len = response.outputs[0]["_last_token_ids_len"]
response.outputs[0]["text"] = self.tokenizer.decode( response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"] response.outputs[0]["token_ids"][last_token_ids_len:]
) )
# Need to keep track of the last text and token ids length
# to calculate the diff.
# TODO: This is a hack to get the diff. We should identify why
# the diff is not being calculated in the worker.
response.outputs[0]["_last_text_len"] = last_text_len
response.outputs[0]["_last_token_ids_len"] = last_token_ids_len
last_text_len = len(response.outputs[0]["text"])
last_token_ids_len = len(response.outputs[0]["token_ids"])
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])] response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_chat_stream_response( response_data = self.create_chat_stream_response(
...@@ -380,7 +352,7 @@ class CompletionsProcessor: ...@@ -380,7 +352,7 @@ class CompletionsProcessor:
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: TokenizerBase,
): ):
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -391,20 +363,15 @@ class CompletionsProcessor: ...@@ -391,20 +363,15 @@ class CompletionsProcessor:
# len(response.outputs) is always 1 # len(response.outputs) is always 1
for gen_idx, output in enumerate(response.outputs): for gen_idx, output in enumerate(response.outputs):
delta_text = output.text_diff text = output.text
if request.echo and not echoed[gen_idx]: if request.echo and not echoed[gen_idx]:
delta_text = request.prompt + delta_text text = request.prompt + text
echoed[gen_idx] = True
choice = DynamoTRTLLMCompletionResponseStreamChoice( choice = DynamoTRTLLMCompletionResponseStreamChoice(
index=gen_idx, text=text,
text=delta_text, index=output.index,
stop_reason=output.stop_reason, stop_reason=output.stop_reason,
finish_reason=output.finish_reason, finish_reason=output.finish_reason,
) )
if output.disaggregated_params is not None:
# Block the disagg_params
pass
chunk = DynamoTRTLLMCompletionStreamResponse( chunk = DynamoTRTLLMCompletionStreamResponse(
model=self.model, model=self.model,
choices=[choice], choices=[choice],
...@@ -423,14 +390,16 @@ class CompletionsProcessor: ...@@ -423,14 +390,16 @@ class CompletionsProcessor:
) )
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
sampling_params._setup(self.tokenizer)
sampling_params.stop = None
return TRTLLMWorkerRequest( return TRTLLMWorkerRequest(
id=request.id, id=request.id,
model=request.model, model=request.model,
prompt=prompt, streaming=request.stream,
sampling_params=asdict(sampling_params), sampling_params=asdict(sampling_params),
disaggregated_params=request.disaggregated_params, disaggregated_params=request.disaggregated_params,
tokens=Tokens(tokens=self.tokenizer.encode(prompt)[1:]), tokens=Tokens(tokens=self.tokenizer.encode(prompt)),
) )
async def postprocess( async def postprocess(
...@@ -440,8 +409,12 @@ class CompletionsProcessor: ...@@ -440,8 +409,12 @@ class CompletionsProcessor:
): ):
async for raw_response in engine_generator: async for raw_response in engine_generator:
response = TRTLLMWorkerResponse.model_validate_json(raw_response.data()) response = TRTLLMWorkerResponse.model_validate_json(raw_response.data())
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
last_token_ids_len = response.outputs[0]["_last_token_ids_len"]
response.outputs[0]["text"] = self.tokenizer.decode(
response.outputs[0]["token_ids"][last_token_ids_len:]
)
response.outputs = [TRTLLMWorkerResponseOutput(**response.outputs[0])]
response_data = self.create_completion_stream_response( response_data = self.create_completion_stream_response(
request, request,
response, response,
......
...@@ -51,6 +51,7 @@ class LLMAPIConfig: ...@@ -51,6 +51,7 @@ class LLMAPIConfig:
data = { data = {
"pytorch_backend_config": self.pytorch_backend_config, "pytorch_backend_config": self.pytorch_backend_config,
"kv_cache_config": self.kv_cache_config, "kv_cache_config": self.kv_cache_config,
"skip_tokenizer_init": self.skip_tokenizer_init,
} }
if self.extra_args: if self.extra_args:
data.update(self.extra_args) data.update(self.extra_args)
......
...@@ -17,7 +17,7 @@ import base64 ...@@ -17,7 +17,7 @@ import base64
import time import time
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, Union from typing import Any, List, Literal, Optional, TypeAlias, Union
import torch import torch
from common.utils import ConversationMessage from common.utils import ConversationMessage
...@@ -70,31 +70,38 @@ class TRTLLMWorkerRequest(BaseModel): ...@@ -70,31 +70,38 @@ class TRTLLMWorkerRequest(BaseModel):
disaggregated_params: Optional[DisaggregatedParams] = Field(default=None) disaggregated_params: Optional[DisaggregatedParams] = Field(default=None)
@dataclass(slots=True)
class Logprob:
"""Holds logprob and vocab rank for a token."""
logprob: float
rank: Optional[int] = None
# List of token_id_to_Logprob dict for prompt or generation texts
TokenLogprobs: TypeAlias = list[dict[int, Logprob]]
@dataclass @dataclass
class TRTLLMWorkerResponseOutput: class TRTLLMWorkerResponseOutput:
index: int index: int
text: str text: str = ""
token_ids: list[int] token_ids: Optional[List[int]] = field(default_factory=list)
logprobs: Optional[List[float]] = None
prompt_logprobs: Optional[List[float]] = None
cumulative_logprob: Optional[float] = None cumulative_logprob: Optional[float] = None
logprobs: Optional[TokenLogprobs] = field(default_factory=list)
prompt_logprobs: Optional[TokenLogprobs] = field(default_factory=list)
finish_reason: Optional[Literal["stop", "length", "timeout", "cancelled"]] = None finish_reason: Optional[Literal["stop", "length", "timeout", "cancelled"]] = None
stop_reason: Optional[Union[int, str]] = None stop_reason: Optional[Union[int, str]] = None
generation_logits: Optional[torch.Tensor] = None generation_logits: Optional[torch.Tensor] = None
disaggregated_params: Optional[DisaggregatedParams] = None disaggregated_params: Optional[DisaggregatedParams] = None
_last_text_len: int = field(default=0) # hidden fields for tracking the diffs
_last_token_ids_len: int = field(default=0) _last_text_len: int = field(default=0, init=True, repr=False)
_last_logprobs_len: int = field(default=0) _last_token_ids_len: int = field(default=0, init=True, repr=False)
_incremental_states: Optional[dict] = field(default=None) _last_logprobs_len: int = field(default=0, init=True, repr=False)
_postprocess_result: Optional[Any] = field(default=None) _incremental_states: Optional[dict] = field(default=None, init=True, repr=False)
# the result of result_handler passed to postprocess workers
text_diff: str = field(default="") _postprocess_result: Any = None
length: int = field(default=0)
def __post_init__(self):
self.text_diff = self.text[self._last_text_len :]
self.length = len(self.token_ids)
class TRTLLMWorkerResponse(BaseModel): class TRTLLMWorkerResponse(BaseModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment