Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
225 additions
and
220 deletions
+225
-220
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+4
-3
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
.../entrypoints/openai/tool_parsers/internlm2_tool_parser.py
+3
-2
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
+6
-5
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+6
-5
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+7
-6
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+3
-2
vllm/entrypoints/openai/tool_parsers/utils.py
vllm/entrypoints/openai/tool_parsers/utils.py
+3
-3
vllm/entrypoints/score_utils.py
vllm/entrypoints/score_utils.py
+7
-7
vllm/envs.py
vllm/envs.py
+4
-4
vllm/forward_context.py
vllm/forward_context.py
+3
-3
vllm/logger.py
vllm/logger.py
+1
-1
vllm/logits_process.py
vllm/logits_process.py
+8
-8
vllm/outputs.py
vllm/outputs.py
+12
-12
vllm/sampling_params.py
vllm/sampling_params.py
+26
-27
vllm/sequence.py
vllm/sequence.py
+66
-66
vllm/tracing.py
vllm/tracing.py
+2
-1
vllm/utils.py
vllm/utils.py
+38
-38
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+9
-9
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+10
-11
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+7
-7
No files found.
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
View file @
cf069aa8
...
...
@@ -2,7 +2,8 @@
import
json
import
re
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
collections.abc
import
Sequence
from
typing
import
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
...
...
@@ -33,9 +34,9 @@ class Hermes2ProToolParser(ToolParser):
self
.
model_tokenizer
=
self
.
model_tokenizer
.
tokenizer
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
self
.
tool_call_start_token
:
str
=
"<tool_call>"
...
...
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
json
from
typing
import
Dict
,
Sequence
,
Union
from
collections.abc
import
Sequence
from
typing
import
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
...
...
@@ -90,7 +91,7 @@ class Internlm2ToolParser(ToolParser):
# tool calls are generated in an object in internlm2
# it does not support parallel tool calls
try
:
tool_call_arr
:
D
ict
=
partial_json_parser
.
loads
(
tool_call_arr
:
d
ict
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
...
...
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
View file @
cf069aa8
...
...
@@ -2,7 +2,8 @@
import
json
import
re
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
collections.abc
import
Sequence
from
typing
import
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
...
...
@@ -35,9 +36,9 @@ class JambaToolParser(ToolParser):
)
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
self
.
tool_calls_start_token
:
str
=
"<tool_calls>"
...
...
@@ -157,7 +158,7 @@ class JambaToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try
:
tool_call_arr
:
L
ist
[
D
ict
]
=
partial_json_parser
.
loads
(
tool_call_arr
:
l
ist
[
d
ict
]
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
...
...
@@ -165,7 +166,7 @@ class JambaToolParser(ToolParser):
# select as the current tool call the one we're on the state at
current_tool_call
:
D
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
current_tool_call
:
d
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
if
len
(
tool_call_arr
)
>
0
else
{}
# case -- if no tokens have been streamed for the tool, e.g.
...
...
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
View file @
cf069aa8
...
...
@@ -2,8 +2,9 @@
import
json
import
re
from
collections.abc
import
Sequence
from
json
import
JSONDecoder
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
typing
import
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
...
...
@@ -40,10 +41,10 @@ class Llama3JsonToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# streaming mode
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"<|python_tag|>"
self
.
bot_token_id
=
tokenizer
.
encode
(
self
.
bot_token
,
...
...
@@ -78,7 +79,7 @@ class Llama3JsonToolParser(ToolParser):
start_idx
+=
end_idx
+
len
(
'; '
)
function_call_arr
.
append
(
obj
)
tool_calls
:
L
ist
[
ToolCall
]
=
[
tool_calls
:
l
ist
[
ToolCall
]
=
[
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
...
...
@@ -152,7 +153,7 @@ class Llama3JsonToolParser(ToolParser):
return
None
# select as the current tool call the one we're on the state at
current_tool_call
:
D
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
current_tool_call
:
d
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
if
len
(
tool_call_arr
)
>
0
else
{}
# case -- if no tokens have been streamed for the tool, e.g.
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
cf069aa8
...
...
@@ -2,9 +2,10 @@
import
json
import
re
from
collections.abc
import
Sequence
from
random
import
choices
from
string
import
ascii_letters
,
digits
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
typing
import
Union
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
...
...
@@ -56,10 +57,10 @@ class MistralToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# streaming mode
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token_id
=
self
.
vocab
.
get
(
self
.
bot_token
)
...
...
@@ -104,7 +105,7 @@ class MistralToolParser(ToolParser):
function_call_arr
=
json
.
loads
(
raw_tool_call
)
# Tool Call
tool_calls
:
L
ist
[
MistralToolCall
]
=
[
tool_calls
:
l
ist
[
MistralToolCall
]
=
[
MistralToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
...
...
@@ -172,7 +173,7 @@ class MistralToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try
:
tool_call_arr
:
L
ist
[
D
ict
]
=
partial_json_parser
.
loads
(
tool_call_arr
:
l
ist
[
d
ict
]
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
...
...
@@ -180,7 +181,7 @@ class MistralToolParser(ToolParser):
# select as the current tool call the one we're on the state at
current_tool_call
:
D
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
current_tool_call
:
d
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
if
len
(
tool_call_arr
)
>
0
else
{}
# case -- if no tokens have been streamed for the tool, e.g.
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
cf069aa8
...
...
@@ -3,7 +3,8 @@
import
ast
import
json
import
re
from
typing
import
Any
,
Sequence
,
Tuple
,
Union
from
collections.abc
import
Sequence
from
typing
import
Any
,
Union
from
transformers
import
PreTrainedTokenizerBase
...
...
@@ -204,7 +205,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments
=
json
.
dumps
(
arguments
)))
def
_make_valid_python
(
text
:
str
)
->
Union
[
T
uple
[
str
,
str
],
None
]:
def
_make_valid_python
(
text
:
str
)
->
Union
[
t
uple
[
str
,
str
],
None
]:
bracket_stack
=
[]
for
index
,
char
in
enumerate
(
text
):
if
char
in
{
"["
,
"("
,
"{"
}:
...
...
vllm/entrypoints/openai/tool_parsers/utils.py
View file @
cf069aa8
...
...
@@ -2,7 +2,7 @@
import
json
from
json
import
JSONDecodeError
,
JSONDecoder
from
typing
import
Any
,
List
,
Tuple
from
typing
import
Any
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
...
...
@@ -82,7 +82,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str:
return
diff
def
find_all_indices
(
string
:
str
,
substring
:
str
)
->
L
ist
[
int
]:
def
find_all_indices
(
string
:
str
,
substring
:
str
)
->
l
ist
[
int
]:
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
...
...
@@ -99,7 +99,7 @@ def find_all_indices(string: str, substring: str) -> List[int]:
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
T
uple
[
Any
,
int
]:
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
t
uple
[
Any
,
int
]:
try
:
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
except
JSONDecodeError
as
e
:
...
...
vllm/entrypoints/score_utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Union
from
typing
import
Union
from
torch.nn
import
CosineSimilarity
...
...
@@ -10,12 +10,12 @@ from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
def
_cosine_similarity
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
embed_1
:
L
ist
[
PoolingRequestOutput
],
embed_2
:
L
ist
[
PoolingRequestOutput
],
)
->
L
ist
[
PoolingRequestOutput
]:
embed_1
:
l
ist
[
PoolingRequestOutput
],
embed_2
:
l
ist
[
PoolingRequestOutput
],
)
->
l
ist
[
PoolingRequestOutput
]:
scorer
=
CosineSimilarity
(
0
)
scores
:
Union
[
L
ist
[
PoolingRequestOutput
]]
=
[]
scores
:
Union
[
l
ist
[
PoolingRequestOutput
]]
=
[]
for
emb_1
,
emb_2
in
zip
(
embed_1
,
embed_2
):
pair_score
=
scorer
(
emb_1
.
outputs
.
data
,
emb_2
.
outputs
.
data
)
...
...
@@ -38,8 +38,8 @@ def _cosine_similarity(
def
_validate_score_input_lens
(
texts_1
:
Union
[
L
ist
[
str
],
L
ist
[
dict
]],
texts_2
:
Union
[
L
ist
[
str
],
L
ist
[
dict
]],
texts_1
:
Union
[
l
ist
[
str
],
l
ist
[
dict
]],
texts_2
:
Union
[
l
ist
[
str
],
l
ist
[
dict
]],
):
if
len
(
texts_1
)
>
1
and
len
(
texts_1
)
!=
len
(
texts_2
):
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
...
...
vllm/envs.py
View file @
cf069aa8
...
...
@@ -2,7 +2,7 @@
import
os
import
tempfile
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
if
TYPE_CHECKING
:
VLLM_HOST_IP
:
str
=
""
...
...
@@ -67,12 +67,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_RPC_TIMEOUT
:
int
=
10000
# ms
VLLM_PLUGINS
:
Optional
[
L
ist
[
str
]]
=
None
VLLM_PLUGINS
:
Optional
[
l
ist
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_DISABLED_KERNELS
:
L
ist
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
l
ist
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
...
...
@@ -123,7 +123,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# begin-env-vars-definition
environment_variables
:
D
ict
[
str
,
Callable
[[],
Any
]]
=
{
environment_variables
:
d
ict
[
str
,
Callable
[[],
Any
]]
=
{
# ================== Installation Time Env Vars ==================
...
...
vllm/forward_context.py
View file @
cf069aa8
...
...
@@ -4,7 +4,7 @@ import time
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
torch.distributed
as
dist
...
...
@@ -28,13 +28,13 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@
dataclass
class
ForwardContext
:
# copy from vllm_config.compilation_config.static_forward_context
attn_layers
:
D
ict
[
str
,
Any
]
attn_layers
:
d
ict
[
str
,
Any
]
# TODO: extend to support per-layer dynamic forward context
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
num_tokens_across_dp
:
Optional
[
L
ist
[
int
]]
=
None
# set dynamically for each forward pass
l
ist
[
int
]]
=
None
# set dynamically for each forward pass
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
...
vllm/logger.py
View file @
cf069aa8
...
...
@@ -109,7 +109,7 @@ def _configure_vllm_root_logger() -> None:
custom_config
=
json
.
loads
(
file
.
read
())
if
not
isinstance
(
custom_config
,
dict
):
raise
ValueError
(
"Invalid logging config. Expected
D
ict, got %s."
,
raise
ValueError
(
"Invalid logging config. Expected
d
ict, got %s."
,
type
(
custom_config
).
__name__
)
logging_config
=
custom_config
...
...
vllm/logits_process.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Callable
,
List
,
Tuple
,
Union
from
typing
import
Callable
,
Union
import
torch
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
LogitsProcessor
=
Union
[
Callable
[[
L
ist
[
int
],
torch
.
Tensor
],
torch
.
Tensor
],
Callable
[[
L
ist
[
int
],
L
ist
[
int
],
torch
.
Tensor
],
LogitsProcessor
=
Union
[
Callable
[[
l
ist
[
int
],
torch
.
Tensor
],
torch
.
Tensor
],
Callable
[[
l
ist
[
int
],
l
ist
[
int
],
torch
.
Tensor
],
torch
.
Tensor
]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
...
...
@@ -17,9 +17,9 @@ to sample from."""
def
get_bad_words_logits_processors
(
bad_words
:
L
ist
[
str
],
tokenizer
:
AnyTokenizer
)
->
L
ist
[
LogitsProcessor
]:
bad_words_ids
:
L
ist
[
L
ist
[
int
]]
=
list
()
bad_words
:
l
ist
[
str
],
tokenizer
:
AnyTokenizer
)
->
l
ist
[
LogitsProcessor
]:
bad_words_ids
:
l
ist
[
l
ist
[
int
]]
=
list
()
for
bad_word
in
bad_words
:
# To prohibit words both at the beginning
...
...
@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor:
_SMALLEST_LOGIT
=
float
(
"-inf"
)
_NEUTRAL_LOGIT
=
0.0
def
__init__
(
self
,
bad_words_ids
:
L
ist
[
L
ist
[
int
]]):
def
__init__
(
self
,
bad_words_ids
:
l
ist
[
l
ist
[
int
]]):
self
.
bad_words_ids
=
bad_words_ids
self
.
word_bias
:
torch
.
FloatTensor
=
None
def
__call__
(
self
,
past_tokens_ids
:
Union
[
L
ist
[
int
],
T
uple
[
int
]],
past_tokens_ids
:
Union
[
l
ist
[
int
],
t
uple
[
int
]],
logits
:
torch
.
FloatTensor
,
)
->
torch
.
Tensor
:
if
self
.
word_bias
is
None
:
...
...
vllm/outputs.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
time
from
collections.abc
import
MutableSequence
from
collections.abc
import
Sequence
as
GenericSequence
from
dataclasses
import
dataclass
from
typing
import
Dict
,
Generic
,
List
,
MutableSequence
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
typing
import
Generic
,
Optional
,
Union
import
torch
from
typing_extensions
import
TypeVar
,
deprecated
...
...
@@ -109,14 +109,14 @@ class RequestOutput:
self
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
L
ist
[
int
]],
prompt_token_ids
:
Optional
[
l
ist
[
int
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
L
ist
[
CompletionOutput
],
outputs
:
l
ist
[
CompletionOutput
],
finished
:
bool
,
metrics
:
Optional
[
RequestMetrics
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
encoder_prompt
:
Optional
[
str
]
=
None
,
encoder_prompt_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
encoder_prompt_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
num_cached_tokens
:
Optional
[
int
]
=
None
,
*
,
multi_modal_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
=
None
,
...
...
@@ -139,9 +139,9 @@ class RequestOutput:
cls
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
L
ist
[
int
]],
prompt_token_ids
:
Optional
[
l
ist
[
int
]],
text
:
str
,
token_ids
:
L
ist
[
int
],
token_ids
:
l
ist
[
int
],
logprobs
:
Optional
[
SampleLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
cumulative_logprob
:
Optional
[
float
],
...
...
@@ -189,7 +189,7 @@ class RequestOutput:
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
,
use_cache
:
bool
,
seq_id_to_seq_group
:
D
ict
[
str
,
SequenceGroupBase
]
seq_id_to_seq_group
:
d
ict
[
str
,
SequenceGroupBase
]
)
->
Optional
[
"RequestOutput"
]:
finished
=
seq_group
.
is_finished
()
...
...
@@ -363,12 +363,12 @@ class PoolingRequestOutput(Generic[_O]):
Args:
request_id (str): A unique identifier for the pooling request.
outputs (PoolingOutput): The pooling results for the given input.
prompt_token_ids (
L
ist[int]): A list of token IDs used in the prompt.
prompt_token_ids (
l
ist[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the pooling is completed.
"""
def
__init__
(
self
,
request_id
:
str
,
outputs
:
_O
,
prompt_token_ids
:
L
ist
[
int
],
finished
:
bool
):
prompt_token_ids
:
l
ist
[
int
],
finished
:
bool
):
self
.
request_id
=
request_id
self
.
prompt_token_ids
=
prompt_token_ids
self
.
finished
=
finished
...
...
@@ -407,7 +407,7 @@ class RequestOutputFactory:
@
staticmethod
def
create
(
seq_group
:
SequenceGroup
,
seq_id_to_seq_group
:
D
ict
[
str
,
SequenceGroupBase
],
seq_id_to_seq_group
:
d
ict
[
str
,
SequenceGroupBase
],
use_cache
:
bool
=
False
):
if
seq_group
.
pooled_data
is
not
None
:
return
PoolingRequestOutput
.
from_seq_group
(
seq_group
)
...
...
vllm/sampling_params.py
View file @
cf069aa8
...
...
@@ -4,11 +4,10 @@ import copy
from
dataclasses
import
dataclass
from
enum
import
Enum
,
IntEnum
from
functools
import
cached_property
from
typing
import
An
y
,
Dict
,
List
,
Optional
,
Set
,
Union
from
typing
import
An
notated
,
Any
,
Optional
,
Union
import
msgspec
from
pydantic
import
BaseModel
from
typing_extensions
import
Annotated
from
vllm.logger
import
init_logger
from
vllm.logits_process
import
LogitsProcessor
...
...
@@ -29,9 +28,9 @@ class SamplingType(IntEnum):
@
dataclass
class
GuidedDecodingParams
:
"""One of these fields will be used to build a logit processor."""
json
:
Optional
[
Union
[
str
,
D
ict
]]
=
None
json
:
Optional
[
Union
[
str
,
d
ict
]]
=
None
regex
:
Optional
[
str
]
=
None
choice
:
Optional
[
L
ist
[
str
]]
=
None
choice
:
Optional
[
l
ist
[
str
]]
=
None
grammar
:
Optional
[
str
]
=
None
json_object
:
Optional
[
bool
]
=
None
"""These are other options that can be set"""
...
...
@@ -40,9 +39,9 @@ class GuidedDecodingParams:
@
staticmethod
def
from_optional
(
json
:
Optional
[
Union
[
D
ict
,
BaseModel
,
str
]]
=
None
,
json
:
Optional
[
Union
[
d
ict
,
BaseModel
,
str
]]
=
None
,
regex
:
Optional
[
str
]
=
None
,
choice
:
Optional
[
L
ist
[
str
]]
=
None
,
choice
:
Optional
[
l
ist
[
str
]]
=
None
,
grammar
:
Optional
[
str
]
=
None
,
json_object
:
Optional
[
bool
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
...
...
@@ -72,7 +71,7 @@ class GuidedDecodingParams:
"""
return
(
self
.
backend
or
""
).
split
(
":"
)[
0
]
def
backend_options
(
self
)
->
L
ist
[
str
]:
def
backend_options
(
self
)
->
l
ist
[
str
]:
"""Return the backend options as a list of strings."""
if
not
self
.
backend
or
":"
not
in
self
.
backend
:
return
[]
...
...
@@ -144,12 +143,12 @@ class SamplingParams(
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
stop:
L
ist of strings that stop the generation when they are generated.
stop:
l
ist of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
stop_token_ids:
L
ist of tokens that stop the generation when they are
stop_token_ids:
l
ist of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
bad_words:
L
ist of words that are not allowed to be generated.
bad_words:
l
ist of words that are not allowed to be generated.
More precisely, only the last token of a corresponding
token sequence is not allowed when the next generated token
can complete the sequence.
...
...
@@ -172,7 +171,7 @@ class SamplingParams(
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors:
L
ist of functions that modify logits based on
logits_processors:
l
ist of functions that modify logits based on
previously generated tokens, and optionally prompt tokens as
a first argument.
truncate_prompt_tokens: If set to an integer k, will use only the last k
...
...
@@ -198,9 +197,9 @@ class SamplingParams(
top_k
:
int
=
-
1
min_p
:
float
=
0.0
seed
:
Optional
[
int
]
=
None
stop
:
Optional
[
Union
[
str
,
L
ist
[
str
]]]
=
None
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
stop
:
Optional
[
Union
[
str
,
l
ist
[
str
]]]
=
None
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
ignore_eos
:
bool
=
False
max_tokens
:
Optional
[
int
]
=
16
min_tokens
:
int
=
0
...
...
@@ -212,8 +211,8 @@ class SamplingParams(
detokenize
:
bool
=
True
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
# Optional[
L
ist[LogitsProcessor]] type. We use Any here because
# Optional[
L
ist[LogitsProcessor]] type is not supported by msgspec.
# Optional[
l
ist[LogitsProcessor]] type. We use Any here because
# Optional[
l
ist[LogitsProcessor]] type is not supported by msgspec.
logits_processors
:
Optional
[
Any
]
=
None
include_stop_str_in_output
:
bool
=
False
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
...
...
@@ -222,12 +221,12 @@ class SamplingParams(
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length
:
int
=
0
_all_stop_token_ids
:
S
et
[
int
]
=
msgspec
.
field
(
default_factory
=
set
)
_all_stop_token_ids
:
s
et
[
int
]
=
msgspec
.
field
(
default_factory
=
set
)
# Fields used to construct logits processors
guided_decoding
:
Optional
[
GuidedDecodingParams
]
=
None
logit_bias
:
Optional
[
D
ict
[
int
,
float
]]
=
None
allowed_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
logit_bias
:
Optional
[
d
ict
[
int
,
float
]]
=
None
allowed_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
@
staticmethod
def
from_optional
(
...
...
@@ -241,9 +240,9 @@ class SamplingParams(
top_k
:
int
=
-
1
,
min_p
:
float
=
0.0
,
seed
:
Optional
[
int
]
=
None
,
stop
:
Optional
[
Union
[
str
,
L
ist
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
,
stop
:
Optional
[
Union
[
str
,
l
ist
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
,
include_stop_str_in_output
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
max_tokens
:
Optional
[
int
]
=
16
,
...
...
@@ -253,13 +252,13 @@ class SamplingParams(
detokenize
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
L
ist
[
LogitsProcessor
]]
=
None
,
logits_processors
:
Optional
[
l
ist
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
,
guided_decoding
:
Optional
[
GuidedDecodingParams
]
=
None
,
logit_bias
:
Optional
[
Union
[
D
ict
[
int
,
float
],
D
ict
[
str
,
float
]]]
=
None
,
allowed_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
logit_bias
:
Optional
[
Union
[
d
ict
[
int
,
float
],
d
ict
[
str
,
float
]]]
=
None
,
allowed_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
)
->
"SamplingParams"
:
if
logit_bias
is
not
None
:
# Convert token_id to integer
...
...
@@ -435,7 +434,7 @@ class SamplingParams(
def
update_from_generation_config
(
self
,
generation_config
:
D
ict
[
str
,
Any
],
generation_config
:
d
ict
[
str
,
Any
],
model_eos_token_id
:
Optional
[
int
]
=
None
)
->
None
:
"""Update if there are non-default values from generation_config"""
...
...
@@ -468,7 +467,7 @@ class SamplingParams(
return
SamplingType
.
RANDOM
@
property
def
all_stop_token_ids
(
self
)
->
S
et
[
int
]:
def
all_stop_token_ids
(
self
)
->
s
et
[
int
]:
return
self
.
_all_stop_token_ids
def
clone
(
self
)
->
"SamplingParams"
:
...
...
vllm/sequence.py
View file @
cf069aa8
...
...
@@ -5,11 +5,11 @@ import enum
from
abc
import
ABC
,
abstractmethod
from
array
import
array
from
collections
import
defaultdict
from
collections.abc
import
Mapping
from
collections.abc
import
Sequence
as
GenericSequence
from
dataclasses
import
dataclass
,
field
from
functools
import
reduce
from
typing
import
Any
,
Callable
,
DefaultDict
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
msgspec
import
torch
...
...
@@ -50,9 +50,9 @@ class Logprob:
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs
=
L
ist
[
Optional
[
D
ict
[
int
,
Logprob
]]]
PromptLogprobs
=
l
ist
[
Optional
[
d
ict
[
int
,
Logprob
]]]
# {token_id -> logprob} for each sequence group.
SampleLogprobs
=
L
ist
[
D
ict
[
int
,
Logprob
]]
SampleLogprobs
=
l
ist
[
d
ict
[
int
,
Logprob
]]
class
SequenceStatus
(
enum
.
IntEnum
):
...
...
@@ -129,7 +129,7 @@ class SequenceDataDelta(
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Delta SequenceData to send to workers per step."""
# A new token to be appended to existing SequenceData.
new_output_token_ids
:
L
ist
[
int
]
new_output_token_ids
:
l
ist
[
int
]
# Overwriting existing `cumulative_logprob`
new_cumulative_logprob
:
float
# Overwriting existing `num_computed_tokens`.
...
...
@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct,
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
# NOTE: we cannot use Union[
L
ist, array] because msgspec cannot support
# NOTE: we cannot use Union[
l
ist, array] because msgspec cannot support
# union of 2 list types.
_prompt_token_ids
:
array
_output_token_ids
:
array
=
msgspec
.
field
(
...
...
@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct,
### The below fields should not be passed as an argument ###
_cumulative_logprob
:
float
=
0.0
_prompt_token_ids_tuple
:
T
uple
[
int
,
_prompt_token_ids_tuple
:
t
uple
[
int
,
...]
=
msgspec
.
field
(
default_factory
=
tuple
)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens
:
int
=
0
# The number of tokens with prefix cache hit.
_num_cached_tokens
:
int
=
0
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
_cached_all_token_ids
:
L
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_cached_all_token_ids
:
l
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
_new_appended_tokens
:
L
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_new_appended_tokens
:
l
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
# It is used to compute mrope_position_ids.
_mrope_position_delta
:
Optional
[
int
]
=
None
@
staticmethod
def
from_prompt_token_counts
(
*
token_counts
:
T
uple
[
int
,
int
])
->
"SequenceData"
:
*
token_counts
:
t
uple
[
int
,
int
])
->
"SequenceData"
:
"""
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.
...
...
@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct,
def
__post_init__
(
self
)
->
None
:
assert
self
.
_prompt_token_ids
.
typecode
==
"l"
assert
self
.
_output_token_ids
.
typecode
==
"l"
self
.
_prompt_token_ids_tuple
:
T
uple
[
int
,
...]
=
tuple
(
self
.
_prompt_token_ids_tuple
:
t
uple
[
int
,
...]
=
tuple
(
self
.
_prompt_token_ids
)
self
.
_update_cached_all_tokens
()
def
_update_cached_all_tokens
(
self
):
assert
isinstance
(
self
.
_prompt_token_ids
,
array
)
assert
isinstance
(
self
.
_output_token_ids
,
array
)
self
.
_cached_all_token_ids
:
L
ist
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_cached_all_token_ids
:
l
ist
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_output_token_ids
)
@
property
...
...
@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct,
return
self
.
_cumulative_logprob
@
property
def
prompt_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
prompt_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
_prompt_token_ids_tuple
@
prompt_token_ids
.
setter
...
...
@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct,
return
self
.
_prompt_token_ids
@
property
def
output_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
output_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
tuple
(
self
.
_output_token_ids
)
@
output_token_ids
.
setter
...
...
@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct,
def
get_output_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
def
get_token_ids
(
self
)
->
L
ist
[
int
]:
def
get_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
_cached_all_token_ids
def
get_prefix_token_ids
(
self
,
num_tokens
:
int
)
->
T
uple
[
T
uple
[
int
,
...],
Optional
[
T
uple
[
int
,
...]]]:
)
->
t
uple
[
t
uple
[
int
,
...],
Optional
[
t
uple
[
int
,
...]]]:
"""Get prefix tokens, and make the return value hashable"""
prompt_length
=
self
.
get_prompt_len
()
if
num_tokens
>
prompt_length
:
...
...
@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct,
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
def
get_prompt_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_prompt_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
prompt_token_ids
def
get_output_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
output_token_ids
def
get_delta_and_reset
(
self
)
->
SequenceDataDelta
:
...
...
@@ -432,7 +432,7 @@ class Sequence:
self
.
prefix_offset
=
0
self
.
read_offset
=
0
# Input + output tokens
self
.
tokens
:
Optional
[
L
ist
[
str
]]
=
None
self
.
tokens
:
Optional
[
l
ist
[
str
]]
=
None
@
property
def
n_blocks
(
self
)
->
int
:
...
...
@@ -443,7 +443,7 @@ class Sequence:
return
self
.
inputs
.
prompt
@
property
def
prompt_token_ids
(
self
)
->
L
ist
[
int
]:
def
prompt_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
inputs
.
prompt_token_ids
@
property
...
...
@@ -451,7 +451,7 @@ class Sequence:
return
self
.
inputs
.
prompt_embeds
@
property
def
token_type_ids
(
self
)
->
L
ist
[
int
]:
def
token_type_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
inputs
.
token_type_ids
@
property
...
...
@@ -463,7 +463,7 @@ class Sequence:
return
self
.
inputs
.
multi_modal_placeholders
@
property
def
mm_processor_kwargs
(
self
)
->
D
ict
[
str
,
Any
]:
def
mm_processor_kwargs
(
self
)
->
d
ict
[
str
,
Any
]:
return
self
.
inputs
.
mm_processor_kwargs
@
property
...
...
@@ -548,7 +548,7 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self
.
data
.
reset_state_for_recompute
()
def
append_token_id
(
self
,
token_id
:
int
,
logprobs
:
D
ict
[
int
,
def
append_token_id
(
self
,
token_id
:
int
,
logprobs
:
d
ict
[
int
,
Logprob
])
->
None
:
assert
token_id
in
logprobs
self
.
output_logprobs
.
append
(
logprobs
)
...
...
@@ -563,16 +563,16 @@ class Sequence:
def
get_output_len
(
self
)
->
int
:
return
self
.
data
.
get_output_len
()
def
get_token_ids
(
self
)
->
L
ist
[
int
]:
def
get_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
data
.
get_token_ids
()
def
get_prompt_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_prompt_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
data
.
get_prompt_token_ids
()
def
get_last_token_id
(
self
)
->
int
:
return
self
.
data
.
get_last_token_id
()
def
get_output_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
data
.
get_output_token_ids
()
def
get_cumulative_logprob
(
self
)
->
float
:
...
...
@@ -644,7 +644,7 @@ class SequenceGroup:
def
__init__
(
self
,
request_id
:
str
,
seqs
:
L
ist
[
Sequence
],
seqs
:
l
ist
[
Sequence
],
arrival_time
:
float
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -686,7 +686,7 @@ class SequenceGroup:
return
self
.
first_seq
.
prompt
@
property
def
prompt_token_ids
(
self
)
->
L
ist
[
int
]:
def
prompt_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
first_seq
.
prompt_token_ids
@
property
...
...
@@ -698,7 +698,7 @@ class SequenceGroup:
if
self
.
encoder_seq
is
not
None
else
None
)
@
property
def
encoder_prompt_token_ids
(
self
)
->
Optional
[
L
ist
[
int
]]:
def
encoder_prompt_token_ids
(
self
)
->
Optional
[
l
ist
[
int
]]:
# There are either 0 or 1 encoder sequences
# If one is present, its prompt token ids are
# distinct from the decoder's.
...
...
@@ -706,7 +706,7 @@ class SequenceGroup:
if
self
.
encoder_seq
is
not
None
else
None
)
@
property
def
token_type_ids
(
self
)
->
Optional
[
L
ist
[
int
]]:
def
token_type_ids
(
self
)
->
Optional
[
l
ist
[
int
]]:
return
self
.
first_seq
.
token_type_ids
@
property
...
...
@@ -726,7 +726,7 @@ class SequenceGroup:
return
{}
@
property
def
mm_processor_kwargs
(
self
)
->
D
ict
[
str
,
Any
]:
def
mm_processor_kwargs
(
self
)
->
d
ict
[
str
,
Any
]:
if
self
.
first_seq
.
multi_modal_data
:
return
self
.
first_seq
.
mm_processor_kwargs
elif
self
.
encoder_seq
is
not
None
:
...
...
@@ -823,7 +823,7 @@ class SequenceGroup:
def
get_seqs
(
self
,
status
:
Optional
[
SequenceStatus
]
=
None
,
)
->
L
ist
[
Sequence
]:
)
->
l
ist
[
Sequence
]:
if
status
is
None
:
return
self
.
seqs
...
...
@@ -838,7 +838,7 @@ class SequenceGroup:
def
get_encoder_seq
(
self
)
->
Optional
[
Sequence
]:
return
self
.
encoder_seq
def
get_finished_seqs
(
self
)
->
L
ist
[
Sequence
]:
def
get_finished_seqs
(
self
)
->
l
ist
[
Sequence
]:
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
is_finished
()
else
[]
...
...
@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta(
After sending the first SequenceGroupMetadata, vLLM scheduler
only sends delta to reduce the data payload size.
"""
seq_data_delta
:
D
ict
[
int
,
SequenceDataDelta
]
seq_data_delta
:
d
ict
[
int
,
SequenceDataDelta
]
request_id
:
str
block_tables
:
D
ict
[
int
,
L
ist
[
int
]]
block_tables
:
d
ict
[
int
,
l
ist
[
int
]]
is_prompt
:
bool
do_sample
:
bool
=
True
token_chunk_size
:
Optional
[
int
]
=
None
computed_block_nums
:
Optional
[
L
ist
[
int
]]
=
None
computed_block_nums
:
Optional
[
l
ist
[
int
]]
=
None
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
default_factory
=
lambda
:
SequenceGroupState
())
...
...
@@ -947,23 +947,23 @@ class SequenceGroupMetadata(
request_id
:
str
is_prompt
:
bool
seq_data
:
D
ict
[
int
,
SequenceData
]
seq_data
:
d
ict
[
int
,
SequenceData
]
sampling_params
:
Optional
[
SamplingParams
]
block_tables
:
D
ict
[
int
,
L
ist
[
int
]]
block_tables
:
d
ict
[
int
,
l
ist
[
int
]]
do_sample
:
bool
=
True
pooling_params
:
Optional
[
PoolingParams
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
computed_block_nums
:
Optional
[
L
ist
[
int
]]
=
None
computed_block_nums
:
Optional
[
l
ist
[
int
]]
=
None
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
default_factory
=
lambda
:
SequenceGroupState
())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# not allowing a union of 2 different dicts.
token_type_ids
:
Optional
[
L
ist
[
int
]]
=
None
token_type_ids
:
Optional
[
l
ist
[
int
]]
=
None
multi_modal_data
:
Optional
[
Any
]
=
None
multi_modal_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
=
None
mm_processor_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
mm_processor_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
cross_block_table
:
Optional
[
L
ist
[
int
]]
=
None
cross_block_table
:
Optional
[
l
ist
[
int
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
token_chunk_size
:
Optional
[
int
]
=
None
...
...
@@ -1042,7 +1042,7 @@ class SequenceOutput(
"""
parent_seq_id
:
int
output_token
:
int
logprobs
:
D
ict
[
int
,
Logprob
]
logprobs
:
d
ict
[
int
,
Logprob
]
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceOutput(parent_seq_id=
{
self
.
parent_seq_id
}
, "
...
...
@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput(
array_like
=
True
):
# type: ignore[call-arg]
"""The model output associated with a completion sequence group."""
__metaclass__
=
SequenceGroupOutput
samples
:
L
ist
[
SequenceOutput
]
samples
:
l
ist
[
SequenceOutput
]
# Prompt logprob for each prompt query token.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
...
...
@@ -1119,7 +1119,7 @@ class IntermediateTensors:
contains the hidden states and residuals for a request.
"""
tensors
:
D
ict
[
str
,
torch
.
Tensor
]
tensors
:
d
ict
[
str
,
torch
.
Tensor
]
def
__init__
(
self
,
tensors
):
# manually define this function, so that
...
...
@@ -1155,7 +1155,7 @@ class PoolerOutput(
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
outputs
:
L
ist
[
PoolingSequenceGroupOutput
]
outputs
:
l
ist
[
PoolingSequenceGroupOutput
]
def
__getitem__
(
self
,
idx
:
int
)
->
PoolingSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
...
...
@@ -1172,7 +1172,7 @@ class PoolerOutput(
def
get_all_seq_ids
(
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
])
->
L
ist
[
int
]:
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
])
->
l
ist
[
int
]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
...
...
@@ -1180,13 +1180,13 @@ def get_all_seq_ids(
def
get_all_seq_ids_and_request_ids
(
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
)
->
T
uple
[
L
ist
[
int
],
D
ict
[
str
,
S
et
[
int
]]]:
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
)
->
t
uple
[
l
ist
[
int
],
d
ict
[
str
,
s
et
[
int
]]]:
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids and a mapping from each request id to its sequence ids.
"""
seq_ids
:
L
ist
[
int
]
=
[]
request_id_seq_ids_mapping
:
D
efault
D
ict
[
str
,
S
et
[
int
]]
=
defaultdict
(
set
)
seq_ids
:
l
ist
[
int
]
=
[]
request_id_seq_ids_mapping
:
d
efault
d
ict
[
str
,
s
et
[
int
]]
=
defaultdict
(
set
)
for
sg
in
seq_group_metadata_list
:
for
seq_id
in
sg
.
seq_data
:
seq_ids
.
append
(
seq_id
)
...
...
@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it is used for last accepted tokens.
hidden_states
:
torch
.
Tensor
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list
:
Optional
[
L
ist
[
SequenceGroupMetadata
]]
=
None
seq_group_metadata_list
:
Optional
[
l
ist
[
SequenceGroupMetadata
]]
=
None
# Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
_seq_ids
:
L
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_seq_ids
:
l
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
def
__post_init__
(
self
):
if
self
.
seq_group_metadata_list
is
not
None
:
...
...
@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True,
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
@
property
def
seq_ids
(
self
)
->
L
ist
[
int
]:
def
seq_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
_seq_ids
def
update
(
self
,
hidden_states
:
torch
.
Tensor
,
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
],
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Update hidden states from target model invocation. Only used for
decode steps"""
...
...
@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
])
def
prune
(
self
,
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
])
->
None
:
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
])
->
None
:
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
# Currently this prunes all seq_ids not present in
...
...
@@ -1287,16 +1287,16 @@ class ExecuteModelRequest(
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list
:
L
ist
[
Union
[
SequenceGroupMetadata
,
seq_group_metadata_list
:
l
ist
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]]
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in
:
L
ist
[
T
uple
[
int
,
blocks_to_swap_in
:
l
ist
[
t
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out
:
L
ist
[
T
uple
[
int
,
blocks_to_swap_out
:
l
ist
[
t
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Blocks to copy. Source to dest block.
blocks_to_copy
:
L
ist
[
T
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
blocks_to_copy
:
l
ist
[
t
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Virtual engine ID for pipeline parallel.
virtual_engine
:
int
=
0
# The number of slots for lookahead decoding.
...
...
@@ -1310,7 +1310,7 @@ class ExecuteModelRequest(
# The step index for spec model input.
spec_step_idx
:
Optional
[
int
]
=
None
# Finished request ids since last step.
finished_requests_ids
:
L
ist
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
finished_requests_ids
:
l
ist
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async callback
...
...
@@ -1344,7 +1344,7 @@ class ExecuteModelRequest(
return
state
.
current_step
def
clone
(
self
,
seq_group_metadata_list
:
L
ist
[
Union
[
SequenceGroupMetadata
,
self
,
seq_group_metadata_list
:
l
ist
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]]
)
->
"ExecuteModelRequest"
:
"""Clone the request with a new sequence group metadata list."""
...
...
@@ -1371,13 +1371,13 @@ class SequenceGroupBase:
assembled_seq_group
:
Optional
[
SequenceGroup
]
=
None
# seq id to a unique index inside this group
seq_id_to_index
:
D
ict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
seq_id_to_index
:
d
ict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
# seq ids to be finished
to_be_finished
:
D
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
to_be_finished
:
d
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
# seq id to finished sequences
finished_reqs
:
D
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
finished_reqs
:
d
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
streaming
:
bool
=
False
...
...
vllm/tracing.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Mapping
,
Optional
from
collections.abc
import
Mapping
from
typing
import
Optional
from
vllm.logger
import
init_logger
from
vllm.utils
import
run_once
...
...
vllm/utils.py
View file @
cf069aa8
...
...
@@ -28,12 +28,12 @@ import warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
OrderedDict
,
UserDict
,
defaultdict
from
collections.abc
import
Hashable
,
Iterable
,
Mapping
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
Mapping
)
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Literal
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Optional
,
TypeVar
,
Union
)
from
uuid
import
uuid4
import
cloudpickle
...
...
@@ -400,7 +400,7 @@ def _next_task(iterator: AsyncGenerator[T, None],
async
def
merge_async_iterators
(
*
iterators
:
AsyncGenerator
[
T
,
None
],
)
->
AsyncGenerator
[
T
uple
[
int
,
T
],
None
]:
None
],
)
->
AsyncGenerator
[
t
uple
[
int
,
T
],
None
]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others.
...
...
@@ -433,7 +433,7 @@ async def merge_async_iterators(
async
def
collect_from_async_generator
(
iterator
:
AsyncGenerator
[
T
,
None
])
->
L
ist
[
T
]:
iterator
:
AsyncGenerator
[
T
,
None
])
->
l
ist
[
T
]:
"""Collect all items from an async generator into a list."""
items
=
[]
async
for
item
in
iterator
:
...
...
@@ -560,7 +560,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]:
return
None
def
update_environment_variables
(
envs
:
D
ict
[
str
,
str
]):
def
update_environment_variables
(
envs
:
d
ict
[
str
,
str
]):
for
k
,
v
in
envs
.
items
():
if
k
in
os
.
environ
and
os
.
environ
[
k
]
!=
v
:
logger
.
warning
(
...
...
@@ -569,7 +569,7 @@ def update_environment_variables(envs: Dict[str, str]):
os
.
environ
[
k
]
=
v
def
chunk_list
(
lst
:
L
ist
[
T
],
chunk_size
:
int
):
def
chunk_list
(
lst
:
l
ist
[
T
],
chunk_size
:
int
):
"""Yield successive chunk_size chunks from lst."""
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
):
yield
lst
[
i
:
i
+
chunk_size
]
...
...
@@ -642,7 +642,7 @@ def create_kv_caches_with_random_flash(
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
T
uple
[
L
ist
[
torch
.
Tensor
],
L
ist
[
torch
.
Tensor
]]:
)
->
t
uple
[
l
ist
[
torch
.
Tensor
],
l
ist
[
torch
.
Tensor
]]:
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
seed
)
...
...
@@ -650,8 +650,8 @@ def create_kv_caches_with_random_flash(
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
scale
=
head_size
**-
0.5
key_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
value_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
key_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
value_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
key_value_cache
=
torch
.
empty
(
size
=
key_value_cache_shape
,
...
...
@@ -679,7 +679,7 @@ def create_kv_caches_with_random(
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
T
uple
[
L
ist
[
torch
.
Tensor
],
L
ist
[
torch
.
Tensor
]]:
)
->
t
uple
[
l
ist
[
torch
.
Tensor
],
l
ist
[
torch
.
Tensor
]]:
if
cache_dtype
==
"fp8"
and
head_size
%
16
:
raise
ValueError
(
...
...
@@ -693,7 +693,7 @@ def create_kv_caches_with_random(
scale
=
head_size
**-
0.5
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_dtype
).
element_size
()
key_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
//
x
,
block_size
,
x
)
key_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
key_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
torch_dtype
,
...
...
@@ -708,7 +708,7 @@ def create_kv_caches_with_random(
key_caches
.
append
(
key_cache
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
value_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
torch_dtype
,
...
...
@@ -754,7 +754,7 @@ class DeviceMemoryProfiler:
def
make_ndarray_with_pad
(
x
:
L
ist
[
L
ist
[
T
]],
x
:
l
ist
[
l
ist
[
T
]],
pad
:
T
,
dtype
:
npt
.
DTypeLike
,
*
,
...
...
@@ -779,7 +779,7 @@ def make_ndarray_with_pad(
def
make_tensor_with_pad
(
x
:
L
ist
[
L
ist
[
T
]],
x
:
l
ist
[
l
ist
[
T
]],
pad
:
T
,
dtype
:
torch
.
dtype
,
*
,
...
...
@@ -831,7 +831,7 @@ def is_list_of(
typ
:
Union
[
type
[
T
],
tuple
[
type
[
T
],
...]],
*
,
check
:
Literal
[
"first"
,
"all"
]
=
"first"
,
)
->
TypeIs
[
L
ist
[
T
]]:
)
->
TypeIs
[
l
ist
[
T
]]:
if
not
isinstance
(
value
,
list
):
return
False
...
...
@@ -843,8 +843,8 @@ def is_list_of(
assert_never
(
check
)
JSONTree
=
Union
[
D
ict
[
str
,
"JSONTree[T]"
],
L
ist
[
"JSONTree[T]"
],
T
uple
[
"JSONTree[T]"
,
...],
T
]
JSONTree
=
Union
[
d
ict
[
str
,
"JSONTree[T]"
],
l
ist
[
"JSONTree[T]"
],
t
uple
[
"JSONTree[T]"
,
...],
T
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
...
...
@@ -859,7 +859,7 @@ def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
return
func
(
value
)
def
flatten_2d_lists
(
lists
:
L
ist
[
L
ist
[
T
]])
->
L
ist
[
T
]:
def
flatten_2d_lists
(
lists
:
l
ist
[
l
ist
[
T
]])
->
l
ist
[
T
]:
"""Flatten a list of lists to a single list."""
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
...
...
@@ -1226,7 +1226,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return
value
def
_pull_args_from_config
(
self
,
args
:
L
ist
[
str
])
->
L
ist
[
str
]:
def
_pull_args_from_config
(
self
,
args
:
l
ist
[
str
])
->
l
ist
[
str
]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
...
...
@@ -1291,7 +1291,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return
args
def
_load_config_file
(
self
,
file_path
:
str
)
->
L
ist
[
str
]:
def
_load_config_file
(
self
,
file_path
:
str
)
->
l
ist
[
str
]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
...
...
@@ -1313,9 +1313,9 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
%s supplied"
,
extension
)
# only expecting a flat dictionary of atomic types
processed_args
:
L
ist
[
str
]
=
[]
processed_args
:
l
ist
[
str
]
=
[]
config
:
D
ict
[
str
,
Union
[
int
,
str
]]
=
{}
config
:
d
ict
[
str
,
Union
[
int
,
str
]]
=
{}
try
:
with
open
(
file_path
)
as
config_file
:
config
=
yaml
.
safe_load
(
config_file
)
...
...
@@ -1399,7 +1399,7 @@ def resolve_mm_processor_kwargs(
*
,
requires_kw_only
:
bool
=
True
,
allow_var_kwargs
:
bool
=
False
,
)
->
D
ict
[
str
,
Any
]:
)
->
d
ict
[
str
,
Any
]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those that are not explicit keywords to the given callable (if one is
given; otherwise no filtering is done), then merges the kwarg dicts,
...
...
@@ -1440,7 +1440,7 @@ def get_allowed_kwarg_only_overrides(
*
,
requires_kw_only
:
bool
=
True
,
allow_var_kwargs
:
bool
=
False
,
)
->
D
ict
[
str
,
Any
]:
)
->
d
ict
[
str
,
Any
]:
"""
Given a callable which has one or more keyword only params and a dict
mapping param names to values, drop values that cannot be kwarg
...
...
@@ -1531,9 +1531,9 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class
LazyDict
(
Mapping
[
str
,
T
],
Generic
[
T
]):
def
__init__
(
self
,
factory
:
D
ict
[
str
,
Callable
[[],
T
]]):
def
__init__
(
self
,
factory
:
d
ict
[
str
,
Callable
[[],
T
]]):
self
.
_factory
=
factory
self
.
_dict
:
D
ict
[
str
,
T
]
=
{}
self
.
_dict
:
d
ict
[
str
,
T
]
=
{}
def
__getitem__
(
self
,
key
:
str
)
->
T
:
if
key
not
in
self
.
_dict
:
...
...
@@ -1552,9 +1552,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
return
len
(
self
.
_factory
)
class
ClassRegistry
(
UserDict
[
T
ype
[
T
],
_V
]):
class
ClassRegistry
(
UserDict
[
t
ype
[
T
],
_V
]):
def
__getitem__
(
self
,
key
:
T
ype
[
T
])
->
_V
:
def
__getitem__
(
self
,
key
:
t
ype
[
T
])
->
_V
:
for
cls
in
key
.
mro
():
if
cls
in
self
.
data
:
return
self
.
data
[
cls
]
...
...
@@ -1584,8 +1584,8 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
def
weak_ref_tensors
(
tensors
:
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
],
T
uple
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
],
T
uple
[
torch
.
Tensor
]]:
tensors
:
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
],
t
uple
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
],
t
uple
[
torch
.
Tensor
]]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
...
...
@@ -1857,7 +1857,7 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def
direct_register_custom_op
(
op_name
:
str
,
op_func
:
Callable
,
mutates_args
:
L
ist
[
str
],
mutates_args
:
l
ist
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
dispatch_key
:
str
=
"CUDA"
,
...
...
@@ -2177,8 +2177,8 @@ def get_mp_context():
def
bind_kv_cache
(
ctx
:
D
ict
[
str
,
Any
],
kv_cache
:
L
ist
[
L
ist
[
torch
.
Tensor
]],
# [virtual_engine][layer_index]
ctx
:
d
ict
[
str
,
Any
],
kv_cache
:
l
ist
[
l
ist
[
torch
.
Tensor
]],
# [virtual_engine][layer_index]
)
->
None
:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
...
...
@@ -2210,8 +2210,8 @@ def bind_kv_cache(
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
def
run_method
(
obj
:
Any
,
method
:
Union
[
str
,
bytes
,
Callable
],
args
:
T
uple
[
Any
],
kwargs
:
D
ict
[
str
,
Any
])
->
Any
:
def
run_method
(
obj
:
Any
,
method
:
Union
[
str
,
bytes
,
Callable
],
args
:
t
uple
[
Any
],
kwargs
:
d
ict
[
str
,
Any
])
->
Any
:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is a string, it will be converted to a method using getattr.
...
...
@@ -2263,7 +2263,7 @@ def import_pynvml():
return
pynvml
def
warn_for_unimplemented_methods
(
cls
:
T
ype
[
T
])
->
T
ype
[
T
]:
def
warn_for_unimplemented_methods
(
cls
:
t
ype
[
T
])
->
t
ype
[
T
]:
"""
A replacement for `abc.ABC`.
When we use `abc.ABC`, subclasses will fail to instantiate
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
numpy
as
np
import
torch
...
...
@@ -30,7 +30,7 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer
:
bool
=
True
@
staticmethod
def
get_supported_head_sizes
()
->
L
ist
[
int
]:
def
get_supported_head_sizes
()
->
l
ist
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
@
staticmethod
...
...
@@ -38,15 +38,15 @@ class FlashAttentionBackend(AttentionBackend):
return
"FLASH_ATTN_VLLM_V1"
@
staticmethod
def
get_impl_cls
()
->
T
ype
[
"FlashAttentionImpl"
]:
def
get_impl_cls
()
->
t
ype
[
"FlashAttentionImpl"
]:
return
FlashAttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
T
ype
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
t
ype
[
"AttentionMetadata"
]:
return
FlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
T
ype
[
"FlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
t
ype
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
staticmethod
...
...
@@ -55,7 +55,7 @@ class FlashAttentionBackend(AttentionBackend):
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
T
uple
[
int
,
...]:
)
->
t
uple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
...
...
@@ -158,10 +158,10 @@ class FlashAttentionImpl(AttentionImpl):
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
L
ist
[
float
]],
alibi_slopes
:
Optional
[
l
ist
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
None
:
...
...
@@ -381,7 +381,7 @@ def cascade_attention(
max_kv_len
:
int
,
softmax_scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
sliding_window
:
T
uple
[
int
,
int
],
sliding_window
:
t
uple
[
int
,
int
],
logits_soft_cap
:
float
,
block_table
:
torch
.
Tensor
,
common_prefix_len
:
int
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
cf069aa8
...
...
@@ -195,8 +195,7 @@ return curr_o @ W_O
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
)
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
...
...
@@ -250,11 +249,11 @@ class MLACommonBackend(AttentionBackend):
return
"TRITON_MLA_VLLM_V1"
@
staticmethod
def
get_metadata_cls
()
->
T
ype
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
t
ype
[
"AttentionMetadata"
]:
return
MLACommonMetadata
@
staticmethod
def
get_builder_cls
()
->
T
ype
[
"MLACommonMetadataBuilder"
]:
def
get_builder_cls
()
->
t
ype
[
"MLACommonMetadataBuilder"
]:
return
MLACommonMetadataBuilder
@
staticmethod
...
...
@@ -263,11 +262,11 @@ class MLACommonBackend(AttentionBackend):
block_size
:
int
,
num_kv_heads
:
int
,
# assumed to be 1 for MLA
head_size
:
int
,
)
->
T
uple
[
int
,
...]:
)
->
t
uple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
get_supported_head_sizes
()
->
L
ist
[
int
]:
def
get_supported_head_sizes
()
->
l
ist
[
int
]:
return
[
576
]
@
staticmethod
...
...
@@ -317,8 +316,8 @@ class MLACommonMetadata:
has_context
:
bool
=
False
context_chunk_cu_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_starts
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_seq_tot
:
Optional
[
L
ist
[
int
]]
=
None
context_chunk_max_seq_lens
:
Optional
[
L
ist
[
int
]]
=
None
context_chunk_seq_tot
:
Optional
[
l
ist
[
int
]]
=
None
context_chunk_max_seq_lens
:
Optional
[
l
ist
[
int
]]
=
None
chunked_prefill_workspace
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
...
...
@@ -538,10 +537,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
L
ist
[
float
]],
alibi_slopes
:
Optional
[
l
ist
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
D
ict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
d
ict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
# MLA Specific Arguments
...
...
@@ -634,7 +633,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
#
# returns input_group_shape, weight_group_shape
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
T
uple
[
T
uple
[
int
,
int
],
T
uple
[
int
,
int
]]:
t
uple
[
t
uple
[
int
,
int
],
t
uple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Optional
import
torch
...
...
@@ -25,21 +25,21 @@ class FlashMLABackend(MLACommonBackend):
return
"FLASHMLA_VLLM_V1"
@
staticmethod
def
get_metadata_cls
()
->
T
ype
[
"FlashMLAMetadata"
]:
def
get_metadata_cls
()
->
t
ype
[
"FlashMLAMetadata"
]:
return
FlashMLAMetadata
@
staticmethod
def
get_builder_cls
()
->
T
ype
[
"FlashMLAMetadataBuilder"
]:
def
get_builder_cls
()
->
t
ype
[
"FlashMLAMetadataBuilder"
]:
return
FlashMLAMetadataBuilder
@
staticmethod
def
get_impl_cls
()
->
T
ype
[
"FlashMLAImpl"
]:
def
get_impl_cls
()
->
t
ype
[
"FlashMLAImpl"
]:
return
FlashMLAImpl
@
dataclass
class
FlashMLAMetadata
(
MLACommonMetadata
):
decode_tile_scheduler_metadata
:
Optional
[
T
uple
[
torch
.
Tensor
,
decode_tile_scheduler_metadata
:
Optional
[
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
decode_num_splits
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -76,10 +76,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
L
ist
[
float
]],
alibi_slopes
:
Optional
[
l
ist
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
D
ict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
d
ict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
# MLA Specific Arguments
...
...
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment