Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
47db6ec8
Unverified
Commit
47db6ec8
authored
Nov 12, 2024
by
zifeitong
Committed by
GitHub
Nov 12, 2024
Browse files
[Frontend] Add per-request number of cached token stats (#10174)
parent
176fcb1c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
89 additions
and
23 deletions
+89
-23
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+22
-2
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+1
-0
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+5
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+5
-0
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+6
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+22
-13
vllm/outputs.py
vllm/outputs.py
+13
-6
vllm/sequence.py
vllm/sequence.py
+12
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-0
No files found.
tests/prefix_caching/test_prefix_caching.py
View file @
47db6ec8
...
@@ -27,6 +27,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
...
@@ -27,6 +27,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"cached_position"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"cached_position"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
def
test_mixed_requests
(
def
test_mixed_requests
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -36,11 +37,12 @@ def test_mixed_requests(
...
@@ -36,11 +37,12 @@ def test_mixed_requests(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
cached_position
:
int
,
cached_position
:
int
,
block_size
:
int
,
monkeypatch
,
monkeypatch
,
)
->
None
:
)
->
None
:
"""
"""
Test the case when some sequences have the prefix cache hit
Test the case when some sequences have the prefix cache hit
and the others don't. The cached position determines where
and the others don't. The cached position determines where
the sequence is at among the batch of prefills.
the sequence is at among the batch of prefills.
"""
"""
override_backend_env_variable
(
monkeypatch
,
backend
)
override_backend_env_variable
(
monkeypatch
,
backend
)
...
@@ -53,12 +55,30 @@ def test_mixed_requests(
...
@@ -53,12 +55,30 @@ def test_mixed_requests(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
enable_prefix_caching
=
True
,
enable_prefix_caching
=
True
,
block_size
=
block_size
,
)
as
vllm_model
:
)
as
vllm_model
:
# Run the first prompt so the cache is populated
# Run the first prompt so the cache is populated
vllm_outputs
=
vllm_model
.
generate_greedy
([
cached_prompt
],
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
([
cached_prompt
],
max_tokens
)
# Run all the promopts
# Run all the promopts
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
greedy_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
)
req_outputs
=
vllm_model
.
model
.
generate
(
example_prompts
,
greedy_params
)
# Verify number of cached tokens
for
i
in
range
(
len
(
req_outputs
)):
if
i
==
cached_position
:
expected_num_cached_tokens
=
(
len
(
req_outputs
[
i
].
prompt_token_ids
)
//
block_size
)
*
block_size
else
:
expected_num_cached_tokens
=
0
assert
req_outputs
[
i
].
num_cached_tokens
==
expected_num_cached_tokens
vllm_outputs
=
[
(
output
.
prompt_token_ids
+
list
(
output
.
outputs
[
0
].
token_ids
),
output
.
prompt
+
output
.
outputs
[
0
].
text
)
for
output
in
req_outputs
]
check_outputs_equal
(
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
...
...
vllm/entrypoints/openai/api_server.py
View file @
47db6ec8
...
@@ -540,6 +540,7 @@ def init_app_state(
...
@@ -540,6 +540,7 @@ def init_app_state(
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
tool_parser
=
args
.
tool_call_parser
,
tool_parser
=
args
.
tool_call_parser
,
enable_prompt_tokens_details
=
args
.
enable_prompt_tokens_details
,
)
if
model_config
.
task
==
"generate"
else
None
)
if
model_config
.
task
==
"generate"
else
None
state
.
openai_serving_completion
=
OpenAIServingCompletion
(
state
.
openai_serving_completion
=
OpenAIServingCompletion
(
engine_client
,
engine_client
,
...
...
vllm/entrypoints/openai/cli_args.py
View file @
47db6ec8
...
@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...
@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default
=
False
,
default
=
False
,
help
=
"Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
help
=
"Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
)
parser
.
add_argument
(
"--enable-prompt-tokens-details"
,
action
=
'store_true'
,
default
=
False
,
help
=
"If set to True, enable prompt_tokens_details in usage."
)
return
parser
return
parser
...
...
vllm/entrypoints/openai/protocol.py
View file @
47db6ec8
...
@@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
...
@@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
data
:
List
[
ModelCard
]
=
Field
(
default_factory
=
list
)
data
:
List
[
ModelCard
]
=
Field
(
default_factory
=
list
)
class
PromptTokenUsageInfo
(
OpenAIBaseModel
):
cached_tokens
:
Optional
[
int
]
=
None
class
UsageInfo
(
OpenAIBaseModel
):
class
UsageInfo
(
OpenAIBaseModel
):
prompt_tokens
:
int
=
0
prompt_tokens
:
int
=
0
total_tokens
:
int
=
0
total_tokens
:
int
=
0
completion_tokens
:
Optional
[
int
]
=
0
completion_tokens
:
Optional
[
int
]
=
0
prompt_tokens_details
:
Optional
[
PromptTokenUsageInfo
]
=
None
class
RequestResponseMetadata
(
BaseModel
):
class
RequestResponseMetadata
(
BaseModel
):
...
...
vllm/entrypoints/openai/run_batch.py
View file @
47db6ec8
...
@@ -78,6 +78,11 @@ def parse_args():
...
@@ -78,6 +78,11 @@ def parse_args():
help
=
"Port number for the Prometheus metrics server "
help
=
"Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set)."
,
"(only needed if enable-metrics is set)."
,
)
)
parser
.
add_argument
(
"--enable-prompt-tokens-details"
,
action
=
'store_true'
,
default
=
False
,
help
=
"If set to True, enable prompt_tokens_details in usage."
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -217,6 +222,7 @@ async def main(args):
...
@@ -217,6 +222,7 @@ async def main(args):
prompt_adapters
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
None
,
chat_template
=
None
,
enable_prompt_tokens_details
=
args
.
enable_prompt_tokens_details
,
)
if
model_config
.
task
==
"generate"
else
None
)
if
model_config
.
task
==
"generate"
else
None
openai_serving_embedding
=
OpenAIServingEmbedding
(
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
engine
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
47db6ec8
...
@@ -18,8 +18,8 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -18,8 +18,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaFunctionCall
,
DeltaMessage
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
RequestResponseMetadata
,
DeltaToolCall
,
ErrorResponse
,
FunctionCall
,
PromptTokenUsageInfo
,
ToolCall
,
UsageInfo
)
RequestResponseMetadata
,
ToolCall
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
LoRAModulePath
,
OpenAIServing
,
OpenAIServing
,
...
@@ -49,7 +49,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -49,7 +49,8 @@ class OpenAIServingChat(OpenAIServing):
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
return_tokens_as_token_ids
:
bool
=
False
,
return_tokens_as_token_ids
:
bool
=
False
,
enable_auto_tools
:
bool
=
False
,
enable_auto_tools
:
bool
=
False
,
tool_parser
:
Optional
[
str
]
=
None
):
tool_parser
:
Optional
[
str
]
=
None
,
enable_prompt_tokens_details
:
bool
=
False
):
super
().
__init__
(
engine_client
=
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
base_model_paths
=
base_model_paths
,
base_model_paths
=
base_model_paths
,
...
@@ -80,6 +81,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -80,6 +81,8 @@ class OpenAIServingChat(OpenAIServing):
f
"tool_parser:'
{
tool_parser
}
' which has not "
f
"tool_parser:'
{
tool_parser
}
' which has not "
"been registered"
)
from
e
"been registered"
)
from
e
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
async
def
create_chat_completion
(
async
def
create_chat_completion
(
self
,
self
,
request
:
ChatCompletionRequest
,
request
:
ChatCompletionRequest
,
...
@@ -252,6 +255,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -252,6 +255,7 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens
=
[
0
]
*
num_choices
previous_num_tokens
=
[
0
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
num_prompt_tokens
=
0
num_prompt_tokens
=
0
num_cached_tokens
=
None
if
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
if
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
tool_choice_function_name
=
request
.
tool_choice
.
function
.
name
tool_choice_function_name
=
request
.
tool_choice
.
function
.
name
...
@@ -305,6 +309,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -305,6 +309,7 @@ class OpenAIServingChat(OpenAIServing):
# the result_generator, it needs to be sent as the FIRST
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
# response (by the try...catch).
if
first_iteration
:
if
first_iteration
:
num_cached_tokens
=
res
.
num_cached_tokens
# Send first response for each request.n (index) with
# Send first response for each request.n (index) with
# the role
# the role
role
=
self
.
get_chat_request_role
(
request
)
role
=
self
.
get_chat_request_role
(
request
)
...
@@ -530,11 +535,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -530,11 +535,13 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage
# is sent, send the usage
if
include_usage
:
if
include_usage
:
completion_tokens
=
sum
(
previous_num_tokens
)
completion_tokens
=
sum
(
previous_num_tokens
)
final_usage
=
UsageInfo
(
final_usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
completion_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
num_prompt_tokens
+
total_tokens
=
num_prompt_tokens
+
completion_tokens
,
completion_tokens
)
)
if
self
.
enable_prompt_tokens_details
and
num_cached_tokens
:
final_usage
.
prompt_tokens_details
=
PromptTokenUsageInfo
(
cached_tokens
=
num_cached_tokens
)
final_usage_chunk
=
ChatCompletionStreamResponse
(
final_usage_chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
id
=
request_id
,
...
@@ -702,11 +709,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -702,11 +709,13 @@ class OpenAIServingChat(OpenAIServing):
num_prompt_tokens
+=
len
(
final_res
.
encoder_prompt_token_ids
)
num_prompt_tokens
+=
len
(
final_res
.
encoder_prompt_token_ids
)
num_generated_tokens
=
sum
(
num_generated_tokens
=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
num_generated_tokens
,
completion_tokens
=
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
num_generated_tokens
)
)
if
self
.
enable_prompt_tokens_details
and
final_res
.
num_cached_tokens
:
usage
.
prompt_tokens_details
=
PromptTokenUsageInfo
(
cached_tokens
=
final_res
.
num_cached_tokens
)
request_metadata
.
final_usage_info
=
usage
request_metadata
.
final_usage_info
=
usage
...
...
vllm/outputs.py
View file @
47db6ec8
...
@@ -83,10 +83,11 @@ class RequestOutput:
...
@@ -83,10 +83,11 @@ class RequestOutput:
finished: Whether the whole request is finished.
finished: Whether the whole request is finished.
metrics: Metrics associated with the request.
metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output.
lora_request: The LoRA request that was used to generate the output.
encoder_prompt: The encoder prompt string of the request;
encoder_prompt: The encoder prompt string of the request.
None if decoder-only
None if decoder-only.
encoder_prompt_token_ids: The token IDs of the encoder prompt;
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -101,6 +102,7 @@ class RequestOutput:
...
@@ -101,6 +102,7 @@ class RequestOutput:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
encoder_prompt
:
Optional
[
str
]
=
None
,
encoder_prompt
:
Optional
[
str
]
=
None
,
encoder_prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
encoder_prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
num_cached_tokens
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
prompt
=
prompt
self
.
prompt
=
prompt
...
@@ -112,6 +114,7 @@ class RequestOutput:
...
@@ -112,6 +114,7 @@ class RequestOutput:
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
encoder_prompt
=
encoder_prompt
self
.
encoder_prompt
=
encoder_prompt
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
self
.
num_cached_tokens
=
num_cached_tokens
@
classmethod
@
classmethod
def
new
(
def
new
(
...
@@ -192,6 +195,8 @@ class RequestOutput:
...
@@ -192,6 +195,8 @@ class RequestOutput:
outputs
=
[]
outputs
=
[]
include_prompt
=
True
include_prompt
=
True
# num_cached_tokens should be the same for all the sequences
num_cached_tokens
=
None
for
i
,
seq
in
enumerate
(
top_n_seqs
):
for
i
,
seq
in
enumerate
(
top_n_seqs
):
output_text
=
seq
.
get_output_text_to_return
(
output_text
=
seq
.
get_output_text_to_return
(
text_buffer_length
,
delta
)
text_buffer_length
,
delta
)
...
@@ -199,6 +204,7 @@ class RequestOutput:
...
@@ -199,6 +204,7 @@ class RequestOutput:
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
num_output_tokens
=
1
if
isinstance
(
output_token_ids
,
num_output_tokens
=
1
if
isinstance
(
output_token_ids
,
int
)
else
len
(
output_token_ids
)
int
)
else
len
(
output_token_ids
)
num_cached_tokens
=
seq
.
data
.
get_num_cached_tokens
()
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
...
@@ -272,7 +278,7 @@ class RequestOutput:
...
@@ -272,7 +278,7 @@ class RequestOutput:
init_args
=
(
seq_group
.
request_id
,
prompt
,
prompt_token_ids
,
init_args
=
(
seq_group
.
request_id
,
prompt
,
prompt_token_ids
,
prompt_logprobs
,
outputs
,
finished
,
seq_group
.
metrics
,
prompt_logprobs
,
outputs
,
finished
,
seq_group
.
metrics
,
seq_group
.
lora_request
,
encoder_prompt
,
seq_group
.
lora_request
,
encoder_prompt
,
encoder_prompt_token_ids
)
encoder_prompt_token_ids
,
num_cached_tokens
)
if
use_cache
:
if
use_cache
:
request_output
=
seq_group
.
cached_request_output
request_output
=
seq_group
.
cached_request_output
...
@@ -293,7 +299,8 @@ class RequestOutput:
...
@@ -293,7 +299,8 @@ class RequestOutput:
f
"outputs=
{
self
.
outputs
}
, "
f
"outputs=
{
self
.
outputs
}
, "
f
"finished=
{
self
.
finished
}
, "
f
"finished=
{
self
.
finished
}
, "
f
"metrics=
{
self
.
metrics
}
, "
f
"metrics=
{
self
.
metrics
}
, "
f
"lora_request=
{
self
.
lora_request
}
)"
)
f
"lora_request=
{
self
.
lora_request
}
, "
f
"num_cached_tokens=
{
self
.
num_cached_tokens
}
)"
)
class
EmbeddingRequestOutput
:
class
EmbeddingRequestOutput
:
...
...
vllm/sequence.py
View file @
47db6ec8
...
@@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
...
@@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
...]
=
msgspec
.
field
(
default_factory
=
tuple
)
...]
=
msgspec
.
field
(
default_factory
=
tuple
)
# The number of tokens that are computed (that run against the model).
# The number of tokens that are computed (that run against the model).
_num_computed_tokens
:
int
=
0
_num_computed_tokens
:
int
=
0
# The number of tokens with prefix cache hit.
_num_cached_tokens
:
int
=
0
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
_cached_all_token_ids
:
List
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_cached_all_token_ids
:
List
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
...
@@ -323,6 +325,14 @@ class SequenceData(msgspec.Struct,
...
@@ -323,6 +325,14 @@ class SequenceData(msgspec.Struct,
if
self
.
get_num_uncomputed_tokens
()
==
0
:
if
self
.
get_num_uncomputed_tokens
()
==
0
:
self
.
_stage
=
SequenceStage
.
DECODE
self
.
_stage
=
SequenceStage
.
DECODE
def
get_num_cached_tokens
(
self
)
->
int
:
"""Return the number of tokens with prefix cache hit."""
return
self
.
_num_cached_tokens
def
update_num_cached_tokens
(
self
,
num_cached_tokens
:
int
):
"""Update the number of tokens with prefix cache hit."""
self
.
_num_cached_tokens
=
num_cached_tokens
def
reset_state_for_recompute
(
self
)
->
None
:
def
reset_state_for_recompute
(
self
)
->
None
:
"""Reset the number of computed tokens from this sequence. It is
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from
supposed to be called when a sequence needs to be started from
...
@@ -379,7 +389,7 @@ class SequenceData(msgspec.Struct,
...
@@ -379,7 +389,7 @@ class SequenceData(msgspec.Struct,
class
Sequence
:
class
Sequence
:
"""Stores the data, status, and block information of a sequence.
"""Stores the data, status, and block information of a sequence.
The sequence is constructed from the :data:`DecoderOnlyInputs`
The sequence is constructed from the :data:`DecoderOnlyInputs`
(for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
(for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
instance passed in through the :code:`inputs` constructor argument.
instance passed in through the :code:`inputs` constructor argument.
...
@@ -906,7 +916,7 @@ class SequenceGroupMetadata(
...
@@ -906,7 +916,7 @@ class SequenceGroupMetadata(
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
mm_processor_kwargs: Multimodal input processor / mapper overrides.
mm_processor_kwargs: Multimodal input processor / mapper overrides.
encoder_seq_data: Optional sequence data for encoder prompt
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
unless you are working with an encoder/decoder
model.
model.
cross_block_table: Optional cross-attention block table associated
cross_block_table: Optional cross-attention block table associated
...
...
vllm/worker/model_runner.py
View file @
47db6ec8
...
@@ -542,6 +542,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -542,6 +542,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# this may be larger than the sequence length if chunked
# this may be larger than the sequence length if chunked
# prefill is enabled.
# prefill is enabled.
prefix_cache_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prefix_cache_len
=
len
(
computed_block_nums
)
*
self
.
block_size
seq_group_metadata
.
seq_data
[
inter_data
.
seq_ids
[
seq_idx
]].
update_num_cached_tokens
(
prefix_cache_len
)
# The number of so far computed prompt tokens in this sequence.
# The number of so far computed prompt tokens in this sequence.
context_len
=
inter_data
.
context_lens
[
seq_idx
]
context_len
=
inter_data
.
context_lens
[
seq_idx
]
# The total number of prompt tokens in this sequence.
# The total number of prompt tokens in this sequence.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment