Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b3da9427
Unverified
Commit
b3da9427
authored
May 20, 2025
by
Tanmay Verma
Committed by
GitHub
May 20, 2025
Browse files
fix: Incrementally decode token to reduce the overhead from Processor (#1129)
parent
93702e44
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
69 deletions
+57
-69
examples/tensorrt_llm/common/base_engine.py
examples/tensorrt_llm/common/base_engine.py
+7
-0
examples/tensorrt_llm/common/chat_processor.py
examples/tensorrt_llm/common/chat_processor.py
+25
-52
examples/tensorrt_llm/common/parser.py
examples/tensorrt_llm/common/parser.py
+1
-0
examples/tensorrt_llm/common/protocol.py
examples/tensorrt_llm/common/protocol.py
+24
-17
No files found.
examples/tensorrt_llm/common/base_engine.py
View file @
b3da9427
...
@@ -511,6 +511,13 @@ class BaseTensorrtLLMEngine:
...
@@ -511,6 +511,13 @@ class BaseTensorrtLLMEngine:
if
self
.
_remote_prefill
and
self
.
_server_type
==
ServerType
.
GEN
:
if
self
.
_remote_prefill
and
self
.
_server_type
==
ServerType
.
GEN
:
ctx_response_obj
=
await
self
.
_get_remote_prefill_response
(
request
)
ctx_response_obj
=
await
self
.
_get_remote_prefill_response
(
request
)
yield
TRTLLMWorkerResponse
(
request_id
=
request
.
id
,
prompt_token_ids
=
ctx_response_obj
.
prompt_token_ids
,
outputs
=
[
asdict
(
ctx_response_obj
.
outputs
[
0
])],
finished
=
ctx_response_obj
.
finished
,
).
model_dump_json
(
exclude_unset
=
True
)
worker_inputs
=
ctx_response_obj
.
prompt_token_ids
worker_inputs
=
ctx_response_obj
.
prompt_token_ids
disaggregated_params
=
(
disaggregated_params
=
(
DisaggregatedTypeConverter
.
to_llm_disaggregated_params
(
DisaggregatedTypeConverter
.
to_llm_disaggregated_params
(
...
...
examples/tensorrt_llm/common/chat_processor.py
View file @
b3da9427
...
@@ -31,6 +31,7 @@ from common.protocol import (
...
@@ -31,6 +31,7 @@ from common.protocol import (
from
common.utils
import
ConversationMessage
from
common.utils
import
ConversationMessage
from
openai.types.chat
import
ChatCompletionMessageParam
from
openai.types.chat
import
ChatCompletionMessageParam
from
tensorrt_llm.llmapi.llm
import
RequestOutput
from
tensorrt_llm.llmapi.llm
import
RequestOutput
from
tensorrt_llm.llmapi.tokenizer
import
TokenizerBase
,
tokenizer_factory
from
tensorrt_llm.serve.openai_protocol
import
(
from
tensorrt_llm.serve.openai_protocol
import
(
ChatCompletionLogProbs
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionLogProbsContent
,
...
@@ -41,9 +42,6 @@ from tensorrt_llm.serve.openai_protocol import (
...
@@ -41,9 +42,6 @@ from tensorrt_llm.serve.openai_protocol import (
ToolCall
,
ToolCall
,
UsageInfo
,
UsageInfo
,
)
)
from
transformers
import
AutoTokenizer
from
transformers.tokenization_utils
import
PreTrainedTokenizer
from
transformers.tokenization_utils_fast
import
PreTrainedTokenizerFast
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -57,22 +55,7 @@ class ChatProcessorMixin:
...
@@ -57,22 +55,7 @@ class ChatProcessorMixin:
# model name for chat processor
# model name for chat processor
self
.
_model_name
=
self
.
_engine_config
.
model_name
self
.
_model_name
=
self
.
_engine_config
.
model_name
logger
.
info
(
f
"Set model name:
{
self
.
_model_name
}
"
)
logger
.
info
(
f
"Set model name:
{
self
.
_model_name
}
"
)
# model for LLMAPI input
self
.
_tokenizer
=
tokenizer_factory
(
self
.
_model_name
)
self
.
_model
=
self
.
_model_name
if
self
.
_engine_config
.
model_path
:
self
.
_model
=
self
.
_engine_config
.
model_path
self
.
_tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_engine_config
.
model_path
)
logger
.
info
(
f
"Using model from path:
{
self
.
_engine_config
.
model_path
}
"
)
else
:
self
.
_tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_engine_config
.
model_name
)
if
self
.
_engine_config
.
extra_args
.
get
(
"tokenizer"
,
None
):
self
.
_tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
_engine_config
.
extra_args
.
get
(
"tokenizer"
,
None
)
)
self
.
chat_processor
=
ChatProcessor
(
self
.
chat_processor
=
ChatProcessor
(
self
.
_model_name
,
self
.
_tokenizer
,
using_engine_generator
self
.
_model_name
,
self
.
_tokenizer
,
using_engine_generator
)
)
...
@@ -109,7 +92,7 @@ class BaseChatProcessor:
...
@@ -109,7 +92,7 @@ class BaseChatProcessor:
def
__init__
(
def
__init__
(
self
,
self
,
model
:
str
,
model
:
str
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
TokenizerBase
,
):
):
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
...
@@ -163,7 +146,7 @@ class ChatProcessor(BaseChatProcessor):
...
@@ -163,7 +146,7 @@ class ChatProcessor(BaseChatProcessor):
def
__init__
(
def
__init__
(
self
,
self
,
model
:
str
,
model
:
str
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
TokenizerBase
,
using_engine_generator
:
bool
=
False
,
using_engine_generator
:
bool
=
False
,
):
):
super
().
__init__
(
model
,
tokenizer
)
super
().
__init__
(
model
,
tokenizer
)
...
@@ -268,9 +251,6 @@ class ChatProcessor(BaseChatProcessor):
...
@@ -268,9 +251,6 @@ class ChatProcessor(BaseChatProcessor):
choice
.
finish_reason
=
output
.
finish_reason
choice
.
finish_reason
=
output
.
finish_reason
choice
.
stop_reason
=
output
.
stop_reason
choice
.
stop_reason
=
output
.
stop_reason
finish_reason_sent
[
i
]
=
True
finish_reason_sent
[
i
]
=
True
if
output
.
disaggregated_params
is
not
None
:
# Block the disaggregated params at processor level
pass
chunk
=
DynamoTRTLLMChatCompletionStreamResponse
(
chunk
=
DynamoTRTLLMChatCompletionStreamResponse
(
id
=
request_id
,
id
=
request_id
,
...
@@ -310,7 +290,7 @@ class ChatProcessor(BaseChatProcessor):
...
@@ -310,7 +290,7 @@ class ChatProcessor(BaseChatProcessor):
)
)
prompt
=
self
.
tokenizer
.
apply_chat_template
(
prompt
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
conversation
=
conversation
,
tokenize
=
Fals
e
,
tokenize
=
Tru
e
,
add_generation_prompt
=
request
.
add_generation_prompt
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
documents
=
request
.
documents
,
...
@@ -318,16 +298,17 @@ class ChatProcessor(BaseChatProcessor):
...
@@ -318,16 +298,17 @@ class ChatProcessor(BaseChatProcessor):
**
(
request
.
chat_template_kwargs
or
{}),
**
(
request
.
chat_template_kwargs
or
{}),
)
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
()
sampling_params
.
_setup
(
self
.
tokenizer
)
sampling_params
.
stop
=
None
return
TRTLLMWorkerRequest
(
return
TRTLLMWorkerRequest
(
id
=
request
.
id
,
id
=
request
.
id
,
model
=
request
.
model
,
model
=
request
.
model
,
prompt
=
prompt
,
sampling_params
=
asdict
(
sampling_params
),
sampling_params
=
asdict
(
sampling_params
),
streaming
=
request
.
stream
,
conversation
=
conversation
,
conversation
=
conversation
,
disaggregated_params
=
request
.
disaggregated_params
,
disaggregated_params
=
request
.
disaggregated_params
,
# NOTE: dont include the first token (e.g. <s>) when searching for a prefix match. We might want to exclude all special tokens at some point.
tokens
=
Tokens
(
tokens
=
prompt
),
tokens
=
Tokens
(
tokens
=
self
.
tokenizer
.
encode
(
prompt
)[
1
:]),
)
)
async
def
postprocess
(
async
def
postprocess
(
...
@@ -337,8 +318,6 @@ class ChatProcessor(BaseChatProcessor):
...
@@ -337,8 +318,6 @@ class ChatProcessor(BaseChatProcessor):
conversation
,
conversation
,
):
):
first_iteration
=
True
first_iteration
=
True
last_text_len
=
0
last_token_ids_len
=
0
async
for
raw_response
in
engine_generator
:
async
for
raw_response
in
engine_generator
:
if
self
.
using_engine_generator
:
if
self
.
using_engine_generator
:
response
=
TRTLLMWorkerResponse
(
response
=
TRTLLMWorkerResponse
(
...
@@ -351,17 +330,10 @@ class ChatProcessor(BaseChatProcessor):
...
@@ -351,17 +330,10 @@ class ChatProcessor(BaseChatProcessor):
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
else
:
else
:
response
=
TRTLLMWorkerResponse
.
model_validate_json
(
raw_response
.
data
())
response
=
TRTLLMWorkerResponse
.
model_validate_json
(
raw_response
.
data
())
last_token_ids_len
=
response
.
outputs
[
0
][
"_last_token_ids_len"
]
response
.
outputs
[
0
][
"text"
]
=
self
.
tokenizer
.
decode
(
response
.
outputs
[
0
][
"text"
]
=
self
.
tokenizer
.
decode
(
response
.
outputs
[
0
][
"token_ids"
]
response
.
outputs
[
0
][
"token_ids"
]
[
last_token_ids_len
:]
)
)
# Need to keep track of the last text and token ids length
# to calculate the diff.
# TODO: This is a hack to get the diff. We should identify why
# the diff is not being calculated in the worker.
response
.
outputs
[
0
][
"_last_text_len"
]
=
last_text_len
response
.
outputs
[
0
][
"_last_token_ids_len"
]
=
last_token_ids_len
last_text_len
=
len
(
response
.
outputs
[
0
][
"text"
])
last_token_ids_len
=
len
(
response
.
outputs
[
0
][
"token_ids"
])
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
response_data
=
self
.
create_chat_stream_response
(
response_data
=
self
.
create_chat_stream_response
(
...
@@ -380,7 +352,7 @@ class CompletionsProcessor:
...
@@ -380,7 +352,7 @@ class CompletionsProcessor:
def
__init__
(
def
__init__
(
self
,
self
,
model
:
str
,
model
:
str
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
,
tokenizer
:
TokenizerBase
,
):
):
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
...
@@ -391,20 +363,15 @@ class CompletionsProcessor:
...
@@ -391,20 +363,15 @@ class CompletionsProcessor:
# len(response.outputs) is always 1
# len(response.outputs) is always 1
for
gen_idx
,
output
in
enumerate
(
response
.
outputs
):
for
gen_idx
,
output
in
enumerate
(
response
.
outputs
):
delta_
text
=
output
.
text
_diff
text
=
output
.
text
if
request
.
echo
and
not
echoed
[
gen_idx
]:
if
request
.
echo
and
not
echoed
[
gen_idx
]:
delta_text
=
request
.
prompt
+
delta_text
text
=
request
.
prompt
+
text
echoed
[
gen_idx
]
=
True
choice
=
DynamoTRTLLMCompletionResponseStreamChoice
(
choice
=
DynamoTRTLLMCompletionResponseStreamChoice
(
index
=
gen_idx
,
text
=
text
,
text
=
delta_t
ex
t
,
index
=
output
.
ind
ex
,
stop_reason
=
output
.
stop_reason
,
stop_reason
=
output
.
stop_reason
,
finish_reason
=
output
.
finish_reason
,
finish_reason
=
output
.
finish_reason
,
)
)
if
output
.
disaggregated_params
is
not
None
:
# Block the disagg_params
pass
chunk
=
DynamoTRTLLMCompletionStreamResponse
(
chunk
=
DynamoTRTLLMCompletionStreamResponse
(
model
=
self
.
model
,
model
=
self
.
model
,
choices
=
[
choice
],
choices
=
[
choice
],
...
@@ -423,14 +390,16 @@ class CompletionsProcessor:
...
@@ -423,14 +390,16 @@ class CompletionsProcessor:
)
)
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
()
sampling_params
.
_setup
(
self
.
tokenizer
)
sampling_params
.
stop
=
None
return
TRTLLMWorkerRequest
(
return
TRTLLMWorkerRequest
(
id
=
request
.
id
,
id
=
request
.
id
,
model
=
request
.
model
,
model
=
request
.
model
,
prompt
=
prompt
,
streaming
=
request
.
stream
,
sampling_params
=
asdict
(
sampling_params
),
sampling_params
=
asdict
(
sampling_params
),
disaggregated_params
=
request
.
disaggregated_params
,
disaggregated_params
=
request
.
disaggregated_params
,
tokens
=
Tokens
(
tokens
=
self
.
tokenizer
.
encode
(
prompt
)
[
1
:]
),
tokens
=
Tokens
(
tokens
=
self
.
tokenizer
.
encode
(
prompt
)),
)
)
async
def
postprocess
(
async
def
postprocess
(
...
@@ -440,8 +409,12 @@ class CompletionsProcessor:
...
@@ -440,8 +409,12 @@ class CompletionsProcessor:
):
):
async
for
raw_response
in
engine_generator
:
async
for
raw_response
in
engine_generator
:
response
=
TRTLLMWorkerResponse
.
model_validate_json
(
raw_response
.
data
())
response
=
TRTLLMWorkerResponse
.
model_validate_json
(
raw_response
.
data
())
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
last_token_ids_len
=
response
.
outputs
[
0
][
"_last_token_ids_len"
]
response
.
outputs
[
0
][
"text"
]
=
self
.
tokenizer
.
decode
(
response
.
outputs
[
0
][
"token_ids"
][
last_token_ids_len
:]
)
response
.
outputs
=
[
TRTLLMWorkerResponseOutput
(
**
response
.
outputs
[
0
])]
response_data
=
self
.
create_completion_stream_response
(
response_data
=
self
.
create_completion_stream_response
(
request
,
request
,
response
,
response
,
...
...
examples/tensorrt_llm/common/parser.py
View file @
b3da9427
...
@@ -51,6 +51,7 @@ class LLMAPIConfig:
...
@@ -51,6 +51,7 @@ class LLMAPIConfig:
data
=
{
data
=
{
"pytorch_backend_config"
:
self
.
pytorch_backend_config
,
"pytorch_backend_config"
:
self
.
pytorch_backend_config
,
"kv_cache_config"
:
self
.
kv_cache_config
,
"kv_cache_config"
:
self
.
kv_cache_config
,
"skip_tokenizer_init"
:
self
.
skip_tokenizer_init
,
}
}
if
self
.
extra_args
:
if
self
.
extra_args
:
data
.
update
(
self
.
extra_args
)
data
.
update
(
self
.
extra_args
)
...
...
examples/tensorrt_llm/common/protocol.py
View file @
b3da9427
...
@@ -17,7 +17,7 @@ import base64
...
@@ -17,7 +17,7 @@ import base64
import
time
import
time
import
uuid
import
uuid
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
List
,
Literal
,
Optional
,
TypeAlias
,
Union
import
torch
import
torch
from
common.utils
import
ConversationMessage
from
common.utils
import
ConversationMessage
...
@@ -70,31 +70,38 @@ class TRTLLMWorkerRequest(BaseModel):
...
@@ -70,31 +70,38 @@ class TRTLLMWorkerRequest(BaseModel):
disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
Field
(
default
=
None
)
disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
Field
(
default
=
None
)
@
dataclass
(
slots
=
True
)
class
Logprob
:
"""Holds logprob and vocab rank for a token."""
logprob
:
float
rank
:
Optional
[
int
]
=
None
# List of token_id_to_Logprob dict for prompt or generation texts
TokenLogprobs
:
TypeAlias
=
list
[
dict
[
int
,
Logprob
]]
@
dataclass
@
dataclass
class
TRTLLMWorkerResponseOutput
:
class
TRTLLMWorkerResponseOutput
:
index
:
int
index
:
int
text
:
str
text
:
str
=
""
token_ids
:
list
[
int
]
token_ids
:
Optional
[
List
[
int
]]
=
field
(
default_factory
=
list
)
logprobs
:
Optional
[
List
[
float
]]
=
None
prompt_logprobs
:
Optional
[
List
[
float
]]
=
None
cumulative_logprob
:
Optional
[
float
]
=
None
cumulative_logprob
:
Optional
[
float
]
=
None
logprobs
:
Optional
[
TokenLogprobs
]
=
field
(
default_factory
=
list
)
prompt_logprobs
:
Optional
[
TokenLogprobs
]
=
field
(
default_factory
=
list
)
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"timeout"
,
"cancelled"
]]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
,
"timeout"
,
"cancelled"
]]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]]
=
None
generation_logits
:
Optional
[
torch
.
Tensor
]
=
None
generation_logits
:
Optional
[
torch
.
Tensor
]
=
None
disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
None
disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
None
_last_text_len
:
int
=
field
(
default
=
0
)
# hidden fields for tracking the diffs
_last_token_ids_len
:
int
=
field
(
default
=
0
)
_last_text_len
:
int
=
field
(
default
=
0
,
init
=
True
,
repr
=
False
)
_last_logprobs_len
:
int
=
field
(
default
=
0
)
_last_token_ids_len
:
int
=
field
(
default
=
0
,
init
=
True
,
repr
=
False
)
_incremental_states
:
Optional
[
dict
]
=
field
(
default
=
None
)
_last_logprobs_len
:
int
=
field
(
default
=
0
,
init
=
True
,
repr
=
False
)
_postprocess_result
:
Optional
[
Any
]
=
field
(
default
=
None
)
_incremental_states
:
Optional
[
dict
]
=
field
(
default
=
None
,
init
=
True
,
repr
=
False
)
# the result of result_handler passed to postprocess workers
text_diff
:
str
=
field
(
default
=
""
)
_postprocess_result
:
Any
=
None
length
:
int
=
field
(
default
=
0
)
def
__post_init__
(
self
):
self
.
text_diff
=
self
.
text
[
self
.
_last_text_len
:]
self
.
length
=
len
(
self
.
token_ids
)
class
TRTLLMWorkerResponse
(
BaseModel
):
class
TRTLLMWorkerResponse
(
BaseModel
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment