Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b3da9427
"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "74fcd4a9b949cb236928b003a154cf698e38a295"
Unverified
Commit
b3da9427
authored
May 20, 2025
by
Tanmay Verma
Committed by
GitHub
May 20, 2025
Browse files
fix: Incrementally decode token to reduce the overhead from Processor (#1129)
parent
93702e44
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
69 deletions
+57
-69
examples/tensorrt_llm/common/base_engine.py
examples/tensorrt_llm/common/base_engine.py
+7
-0
examples/tensorrt_llm/common/chat_processor.py
examples/tensorrt_llm/common/chat_processor.py
+25
-52
examples/tensorrt_llm/common/parser.py
examples/tensorrt_llm/common/parser.py
+1
-0
examples/tensorrt_llm/common/protocol.py
examples/tensorrt_llm/common/protocol.py
+24
-17
No files found.
examples/tensorrt_llm/common/base_engine.py
View file @
b3da9427
...
...
@@ -511,6 +511,13 @@ class BaseTensorrtLLMEngine:
if
self
.
_remote_prefill
and
self
.
_server_type
==
ServerType
.
GEN
:
ctx_response_obj
=
await
self
.
_get_remote_prefill_response
(
request
)
yield
TRTLLMWorkerResponse
(
request_id
=
request
.
id
,
prompt_token_ids
=
ctx_response_obj
.
prompt_token_ids
,
outputs
=
[
asdict
(
ctx_response_obj
.
outputs
[
0
])],
finished
=
ctx_response_obj
.
finished
,
).
model_dump_json
(
exclude_unset
=
True
)
worker_inputs
=
ctx_response_obj
.
prompt_token_ids
disaggregated_params
=
(
DisaggregatedTypeConverter
.
to_llm_disaggregated_params
(
...
...
examples/tensorrt_llm/common/chat_processor.py
View file @
b3da9427
...
...
@@ -31,6 +31,7 @@ from common.protocol import (
from
common.utils
import
ConversationMessage
from
openai.types.chat
import
ChatCompletionMessageParam
from
tensorrt_llm.llmapi.llm
import
RequestOutput
from
tensorrt_llm.llmapi.tokenizer
import
TokenizerBase
,
tokenizer_factory
from
tensorrt_llm.serve.openai_protocol
import
(
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
...
...
@@ -41,9 +42,6 @@ from tensorrt_llm.serve.openai_protocol import (
ToolCall
,
UsageInfo
,
)
from
transformers
import
AutoTokenizer
from
transformers.tokenization_utils
import
PreTrainedTokenizer
from
transformers.tokenization_utils_fast
import
PreTrainedTokenizerFast
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -57,22 +55,7 @@ class ChatProcessorMixin:
# model name for chat processor
self
.
_model_name
=
self
.
_engine_config
.
model_name
logger
.
info
(
f
"Set model name:
{
self
.
_model_name
}
"
)
# model for LLMAPI input
self
.
_model
=
self
.
_model_name
if
self
.
_engine_config
.
model_path
:
self
.
_model
=
self
.
_engine_config
.
model_path
self
.
_tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_engine_config
.
model_path
)
logger
.
info
(
f
"Using model from path:
{
self
.
_engine_config
.
model_path
}
"
)
else
:
self
.
_tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_engine_config
.
model_name
)
if
self
.
_engine_config
.
extra_args
.
get
(
"tokenizer"
,
None
):
self
.
_tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_engine_config
.
extra_args
.
get
(
"tokenizer"
,
None
)
)
self
.
_tokenizer
=
tokenizer_factory
(
self
.
_model_name
)
self
.
chat_processor
=
ChatProcessor
(
self
.
_model_name
,
self
.
_tokenizer
,
using_engine_generator
)
...
...
@@ -109,7 +92,7 @@ class BaseChatProcessor:
def
__init__
(
self
,
model
:
str
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
TokenizerBase
,
):
self
.
model
=
model
self
.
tokenizer
=
tokenizer
...
...
@@ -163,7 +146,7 @@ class ChatProcessor(BaseChatProcessor):
def
__init__
(
self
,
model
:
str
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
TokenizerBase
,
using_engine_generator
:
bool
=
False
,
):
super
().
__init__
(
model
,
tokenizer
)
...
...
@@ -268,9 +251,6 @@ class ChatProcessor(BaseChatProcessor):
choice
.
finish_reason
=
output
.
finish_reason
choice
.
stop_reason
=
output
.
stop_reason
finish_reason_sent
[
i
]
=
True
if
output
.
disaggregated_params
is
not
None
:
# Block the disaggregated params at processor level
pass
chunk
=
DynamoTRTLLMChatCompletionStreamResponse
(
id
=
request_id
,
...
...
@@ -310,7 +290,7 @@ class ChatProcessor(BaseChatProcessor):
)
prompt
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
tokenize
=
Fals
e
,
tokenize
=
Tru
e
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
...
...
@@ -318,16 +298,17 @@ class ChatProcessor(BaseChatProcessor):
**
(
request
.
chat_template_kwargs
or
{}),
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
.
_setup
(
self
.
tokenizer
)
sampling_params
.
stop
=
None
return
TRTLLMWorkerRequest
(
id
=
request
.
id
,
model
=
request
.
model
,
prompt
=
prompt
,
sampling_params
=
asdict
(
sampling_params
),
streaming
=
request
.
stream
,
conversation
=
conversation
,
disaggregated_params
=
request
.
disaggregated_params
,
# NOTE: dont include the first token (e.g. <s>) when searching for a prefix match. We might want to exclude all special tokens at some point.
tokens
=
Tokens
(
tokens
=
self
.
tokenizer
.
encode
(
prompt
)[
1
:]),
tokens
=
Tokens
(
tokens
=
prompt
),
)
async
def
postprocess
(
...
...
@@ -337,8 +318,6 @@ class ChatProcessor(BaseChatProcessor):
conversation
,
):
first_iteration
=
True
last_text_len
=
0
last_token_ids_len
=
0
async
for
raw_response
in
engine_generator
:
if
self
.
using_engine_generator
:
response
=
TRTLLMWorkerResponse
(
...
...
@@ -351,17 +330,10 @@ class ChatProcessor(BaseChatProcessor):
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
else
:
response
=
TRTLLMWorkerResponse
.
model_validate_json
(
raw_response
.
data
())
last_token_ids_len
=
response
.
outputs
[
0
][
"_last_token_ids_len"
]
response
.
outputs
[
0
][
"text"
]
=
self
.
tokenizer
.
decode
(
response
.
outputs
[
0
][
"token_ids"
]
response
.
outputs
[
0
][
"token_ids"
]
[
last_token_ids_len
:]
)
# Need to keep track of the last text and token ids length
# to calculate the diff.
# TODO: This is a hack to get the diff. We should identify why
# the diff is not being calculated in the worker.
response
.
outputs
[
0
][
"_last_text_len"
]
=
last_text_len
response
.
outputs
[
0
][
"_last_token_ids_len"
]
=
last_token_ids_len
last_text_len
=
len
(
response
.
outputs
[
0
][
"text"
])
last_token_ids_len
=
len
(
response
.
outputs
[
0
][
"token_ids"
])
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
response_data
=
self
.
create_chat_stream_response
(
...
...
@@ -380,7 +352,7 @@ class CompletionsProcessor:
def
__init__
(
self
,
model
:
str
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
TokenizerBase
,
):
self
.
model
=
model
self
.
tokenizer
=
tokenizer
...
...
@@ -391,20 +363,15 @@ class CompletionsProcessor:
# len(response.outputs) is always 1
for
gen_idx
,
output
in
enumerate
(
response
.
outputs
):
delta_
text
=
output
.
text
_diff
text
=
output
.
text
if
request
.
echo
and
not
echoed
[
gen_idx
]:
delta_text
=
request
.
prompt
+
delta_text
echoed
[
gen_idx
]
=
True
text
=
request
.
prompt
+
text
choice
=
DynamoTRTLLMCompletionResponseStreamChoice
(
index
=
gen_idx
,
text
=
delta_t
ex
t
,
text
=
text
,
index
=
output
.
ind
ex
,
stop_reason
=
output
.
stop_reason
,
finish_reason
=
output
.
finish_reason
,
)
if
output
.
disaggregated_params
is
not
None
:
# Block the disagg_params
pass
chunk
=
DynamoTRTLLMCompletionStreamResponse
(
model
=
self
.
model
,
choices
=
[
choice
],
...
...
@@ -423,14 +390,16 @@ class CompletionsProcessor:
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
.
_setup
(
self
.
tokenizer
)
sampling_params
.
stop
=
None
return
TRTLLMWorkerRequest
(
id
=
request
.
id
,
model
=
request
.
model
,
prompt
=
prompt
,
streaming
=
request
.
stream
,
sampling_params
=
asdict
(
sampling_params
),
disaggregated_params
=
request
.
disaggregated_params
,
tokens
=
Tokens
(
tokens
=
self
.
tokenizer
.
encode
(
prompt
)
[
1
:]
),
tokens
=
Tokens
(
tokens
=
self
.
tokenizer
.
encode
(
prompt
)),
)
async
def
postprocess
(
...
...
@@ -440,8 +409,12 @@ class CompletionsProcessor:
):
async
for
raw_response
in
engine_generator
:
response
=
TRTLLMWorkerResponse
.
model_validate_json
(
raw_response
.
data
())
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
last_token_ids_len
=
response
.
outputs
[
0
][
"_last_token_ids_len"
]
response
.
outputs
[
0
][
"text"
]
=
self
.
tokenizer
.
decode
(
response
.
outputs
[
0
][
"token_ids"
][
last_token_ids_len
:]
)
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
response_data
=
self
.
create_completion_stream_response
(
request
,
response
,
...
...
examples/tensorrt_llm/common/parser.py
View file @
b3da9427
...
...
@@ -51,6 +51,7 @@ class LLMAPIConfig:
data
=
{
"pytorch_backend_config"
:
self
.
pytorch_backend_config
,
"kv_cache_config"
:
self
.
kv_cache_config
,
"skip_tokenizer_init"
:
self
.
skip_tokenizer_init
,
}
if
self
.
extra_args
:
data
.
update
(
self
.
extra_args
)
...
...
examples/tensorrt_llm/common/protocol.py
View file @
b3da9427
...
...
@@ -17,7 +17,7 @@ import base64
import
time
import
uuid
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
List
,
Literal
,
Optional
,
TypeAlias
,
Union
import
torch
from
common.utils
import
ConversationMessage
...
...
@@ -70,31 +70,38 @@ class TRTLLMWorkerRequest(BaseModel):
disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
Field
(
default
=
None
)
@
dataclass
(
slots
=
True
)
class
Logprob
:
"""Holds logprob and vocab rank for a token."""
logprob
:
float
rank
:
Optional
[
int
]
=
None
# List of token_id_to_Logprob dict for prompt or generation texts
TokenLogprobs
:
TypeAlias
=
list
[
dict
[
int
,
Logprob
]]
@
dataclass
class
TRTLLMWorkerResponseOutput
:
index
:
int
text
:
str
token_ids
:
list
[
int
]
logprobs
:
Optional
[
List
[
float
]]
=
None
prompt_logprobs
:
Optional
[
List
[
float
]]
=
None
text
:
str
=
""
token_ids
:
Optional
[
List
[
int
]]
=
field
(
default_factory
=
list
)
cumulative_logprob
:
Optional
[
float
]
=
None
logprobs
:
Optional
[
TokenLogprobs
]
=
field
(
default_factory
=
list
)
prompt_logprobs
:
Optional
[
TokenLogprobs
]
=
field
(
default_factory
=
list
)
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"timeout"
,
"cancelled"
]]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
generation_logits
:
Optional
[
torch
.
Tensor
]
=
None
disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
None
_last_text_len
:
int
=
field
(
default
=
0
)
_last_token_ids_len
:
int
=
field
(
default
=
0
)
_last_logprobs_len
:
int
=
field
(
default
=
0
)
_incremental_states
:
Optional
[
dict
]
=
field
(
default
=
None
)
_postprocess_result
:
Optional
[
Any
]
=
field
(
default
=
None
)
text_diff
:
str
=
field
(
default
=
""
)
length
:
int
=
field
(
default
=
0
)
def
__post_init__
(
self
):
self
.
text_diff
=
self
.
text
[
self
.
_last_text_len
:]
self
.
length
=
len
(
self
.
token_ids
)
# hidden fields for tracking the diffs
_last_text_len
:
int
=
field
(
default
=
0
,
init
=
True
,
repr
=
False
)
_last_token_ids_len
:
int
=
field
(
default
=
0
,
init
=
True
,
repr
=
False
)
_last_logprobs_len
:
int
=
field
(
default
=
0
,
init
=
True
,
repr
=
False
)
_incremental_states
:
Optional
[
dict
]
=
field
(
default
=
None
,
init
=
True
,
repr
=
False
)
# the result of result_handler passed to postprocess workers
_postprocess_result
:
Any
=
None
class
TRTLLMWorkerResponse
(
BaseModel
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment