Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
225 additions
and
220 deletions
+225
-220
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+4
-3
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
.../entrypoints/openai/tool_parsers/internlm2_tool_parser.py
+3
-2
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
+6
-5
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+6
-5
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+7
-6
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+3
-2
vllm/entrypoints/openai/tool_parsers/utils.py
vllm/entrypoints/openai/tool_parsers/utils.py
+3
-3
vllm/entrypoints/score_utils.py
vllm/entrypoints/score_utils.py
+7
-7
vllm/envs.py
vllm/envs.py
+4
-4
vllm/forward_context.py
vllm/forward_context.py
+3
-3
vllm/logger.py
vllm/logger.py
+1
-1
vllm/logits_process.py
vllm/logits_process.py
+8
-8
vllm/outputs.py
vllm/outputs.py
+12
-12
vllm/sampling_params.py
vllm/sampling_params.py
+26
-27
vllm/sequence.py
vllm/sequence.py
+66
-66
vllm/tracing.py
vllm/tracing.py
+2
-1
vllm/utils.py
vllm/utils.py
+38
-38
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+9
-9
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+10
-11
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+7
-7
No files found.
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
View file @
cf069aa8
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
import
json
import
json
import
re
import
re
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
collections.abc
import
Sequence
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
...
@@ -33,9 +34,9 @@ class Hermes2ProToolParser(ToolParser):
...
@@ -33,9 +34,9 @@ class Hermes2ProToolParser(ToolParser):
self
.
model_tokenizer
=
self
.
model_tokenizer
.
tokenizer
self
.
model_tokenizer
=
self
.
model_tokenizer
.
tokenizer
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_id
:
int
=
-
1
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
]
# map what has been streamed for each tool so far to a list
self
.
tool_call_start_token
:
str
=
"<tool_call>"
self
.
tool_call_start_token
:
str
=
"<tool_call>"
...
...
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
from
typing
import
Dict
,
Sequence
,
Union
from
collections.abc
import
Sequence
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
...
@@ -90,7 +91,7 @@ class Internlm2ToolParser(ToolParser):
...
@@ -90,7 +91,7 @@ class Internlm2ToolParser(ToolParser):
# tool calls are generated in an object in inernlm2
# tool calls are generated in an object in inernlm2
# it's not support parallel tool calls
# it's not support parallel tool calls
try
:
try
:
tool_call_arr
:
D
ict
=
partial_json_parser
.
loads
(
tool_call_arr
:
d
ict
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
)
parsable_arr
,
flags
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
...
...
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
View file @
cf069aa8
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
import
json
import
json
import
re
import
re
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
collections.abc
import
Sequence
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
...
@@ -35,9 +36,9 @@ class JambaToolParser(ToolParser):
...
@@ -35,9 +36,9 @@ class JambaToolParser(ToolParser):
)
)
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_id
:
int
=
-
1
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
]
# map what has been streamed for each tool so far to a list
self
.
tool_calls_start_token
:
str
=
"<tool_calls>"
self
.
tool_calls_start_token
:
str
=
"<tool_calls>"
...
@@ -157,7 +158,7 @@ class JambaToolParser(ToolParser):
...
@@ -157,7 +158,7 @@ class JambaToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
# parsing on the entire array
try
:
try
:
tool_call_arr
:
L
ist
[
D
ict
]
=
partial_json_parser
.
loads
(
tool_call_arr
:
l
ist
[
d
ict
]
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
)
parsable_arr
,
flags
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
...
@@ -165,7 +166,7 @@ class JambaToolParser(ToolParser):
...
@@ -165,7 +166,7 @@ class JambaToolParser(ToolParser):
# select as the current tool call the one we're on the state at
# select as the current tool call the one we're on the state at
current_tool_call
:
D
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
current_tool_call
:
d
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
if
len
(
tool_call_arr
)
>
0
else
{}
if
len
(
tool_call_arr
)
>
0
else
{}
# case -- if no tokens have been streamed for the tool, e.g.
# case -- if no tokens have been streamed for the tool, e.g.
...
...
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
View file @
cf069aa8
...
@@ -2,8 +2,9 @@
...
@@ -2,8 +2,9 @@
import
json
import
json
import
re
import
re
from
collections.abc
import
Sequence
from
json
import
JSONDecoder
from
json
import
JSONDecoder
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
...
@@ -40,10 +41,10 @@ class Llama3JsonToolParser(ToolParser):
...
@@ -40,10 +41,10 @@ class Llama3JsonToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# initialize properties used for state when parsing tool calls in
# streaming mode
# streaming mode
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_name_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"<|python_tag|>"
self
.
bot_token
=
"<|python_tag|>"
self
.
bot_token_id
=
tokenizer
.
encode
(
self
.
bot_token
,
self
.
bot_token_id
=
tokenizer
.
encode
(
self
.
bot_token
,
...
@@ -78,7 +79,7 @@ class Llama3JsonToolParser(ToolParser):
...
@@ -78,7 +79,7 @@ class Llama3JsonToolParser(ToolParser):
start_idx
+=
end_idx
+
len
(
'; '
)
start_idx
+=
end_idx
+
len
(
'; '
)
function_call_arr
.
append
(
obj
)
function_call_arr
.
append
(
obj
)
tool_calls
:
L
ist
[
ToolCall
]
=
[
tool_calls
:
l
ist
[
ToolCall
]
=
[
ToolCall
(
ToolCall
(
type
=
"function"
,
type
=
"function"
,
function
=
FunctionCall
(
function
=
FunctionCall
(
...
@@ -152,7 +153,7 @@ class Llama3JsonToolParser(ToolParser):
...
@@ -152,7 +153,7 @@ class Llama3JsonToolParser(ToolParser):
return
None
return
None
# select as the current tool call the one we're on the state at
# select as the current tool call the one we're on the state at
current_tool_call
:
D
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
current_tool_call
:
d
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
if
len
(
tool_call_arr
)
>
0
else
{}
if
len
(
tool_call_arr
)
>
0
else
{}
# case -- if no tokens have been streamed for the tool, e.g.
# case -- if no tokens have been streamed for the tool, e.g.
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
cf069aa8
...
@@ -2,9 +2,10 @@
...
@@ -2,9 +2,10 @@
import
json
import
json
import
re
import
re
from
collections.abc
import
Sequence
from
random
import
choices
from
random
import
choices
from
string
import
ascii_letters
,
digits
from
string
import
ascii_letters
,
digits
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
...
@@ -56,10 +57,10 @@ class MistralToolParser(ToolParser):
...
@@ -56,10 +57,10 @@ class MistralToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# initialize properties used for state when parsing tool calls in
# streaming mode
# streaming mode
self
.
prev_tool_call_arr
:
L
ist
[
D
ict
]
=
[]
self
.
prev_tool_call_arr
:
l
ist
[
d
ict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
current_tool_name_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
L
ist
[
str
]
=
[
self
.
streamed_args_for_tool
:
l
ist
[
str
]
=
[
]
# map what has been streamed for each tool so far to a list
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token_id
=
self
.
vocab
.
get
(
self
.
bot_token
)
self
.
bot_token_id
=
self
.
vocab
.
get
(
self
.
bot_token
)
...
@@ -104,7 +105,7 @@ class MistralToolParser(ToolParser):
...
@@ -104,7 +105,7 @@ class MistralToolParser(ToolParser):
function_call_arr
=
json
.
loads
(
raw_tool_call
)
function_call_arr
=
json
.
loads
(
raw_tool_call
)
# Tool Call
# Tool Call
tool_calls
:
L
ist
[
MistralToolCall
]
=
[
tool_calls
:
l
ist
[
MistralToolCall
]
=
[
MistralToolCall
(
MistralToolCall
(
type
=
"function"
,
type
=
"function"
,
function
=
FunctionCall
(
function
=
FunctionCall
(
...
@@ -172,7 +173,7 @@ class MistralToolParser(ToolParser):
...
@@ -172,7 +173,7 @@ class MistralToolParser(ToolParser):
# tool calls are generated in an array, so do partial JSON
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
# parsing on the entire array
try
:
try
:
tool_call_arr
:
L
ist
[
D
ict
]
=
partial_json_parser
.
loads
(
tool_call_arr
:
l
ist
[
d
ict
]
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
)
parsable_arr
,
flags
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
logger
.
debug
(
'not enough tokens to parse into JSON yet'
)
...
@@ -180,7 +181,7 @@ class MistralToolParser(ToolParser):
...
@@ -180,7 +181,7 @@ class MistralToolParser(ToolParser):
# select as the current tool call the one we're on the state at
# select as the current tool call the one we're on the state at
current_tool_call
:
D
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
current_tool_call
:
d
ict
=
tool_call_arr
[
self
.
current_tool_id
]
\
if
len
(
tool_call_arr
)
>
0
else
{}
if
len
(
tool_call_arr
)
>
0
else
{}
# case -- if no tokens have been streamed for the tool, e.g.
# case -- if no tokens have been streamed for the tool, e.g.
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
cf069aa8
...
@@ -3,7 +3,8 @@
...
@@ -3,7 +3,8 @@
import
ast
import
ast
import
json
import
json
import
re
import
re
from
typing
import
Any
,
Sequence
,
Tuple
,
Union
from
collections.abc
import
Sequence
from
typing
import
Any
,
Union
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -204,7 +205,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
...
@@ -204,7 +205,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments
=
json
.
dumps
(
arguments
)))
arguments
=
json
.
dumps
(
arguments
)))
def
_make_valid_python
(
text
:
str
)
->
Union
[
T
uple
[
str
,
str
],
None
]:
def
_make_valid_python
(
text
:
str
)
->
Union
[
t
uple
[
str
,
str
],
None
]:
bracket_stack
=
[]
bracket_stack
=
[]
for
index
,
char
in
enumerate
(
text
):
for
index
,
char
in
enumerate
(
text
):
if
char
in
{
"["
,
"("
,
"{"
}:
if
char
in
{
"["
,
"("
,
"{"
}:
...
...
vllm/entrypoints/openai/tool_parsers/utils.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
json
import
json
from
json
import
JSONDecodeError
,
JSONDecoder
from
json
import
JSONDecodeError
,
JSONDecoder
from
typing
import
Any
,
List
,
Tuple
from
typing
import
Any
import
partial_json_parser
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
...
@@ -82,7 +82,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str:
...
@@ -82,7 +82,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str:
return
diff
return
diff
def
find_all_indices
(
string
:
str
,
substring
:
str
)
->
L
ist
[
int
]:
def
find_all_indices
(
string
:
str
,
substring
:
str
)
->
l
ist
[
int
]:
"""
"""
Find all (starting) indices of a substring in a given string. Useful for
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
tool call extraction
...
@@ -99,7 +99,7 @@ def find_all_indices(string: str, substring: str) -> List[int]:
...
@@ -99,7 +99,7 @@ def find_all_indices(string: str, substring: str) -> List[int]:
# partial_json_parser doesn't support extra data and
# partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON
# JSONDecorder.raw_decode doesn't support partial JSON
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
T
uple
[
Any
,
int
]:
def
partial_json_loads
(
input_str
:
str
,
flags
:
Allow
)
->
t
uple
[
Any
,
int
]:
try
:
try
:
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
return
(
partial_json_parser
.
loads
(
input_str
,
flags
),
len
(
input_str
))
except
JSONDecodeError
as
e
:
except
JSONDecodeError
as
e
:
...
...
vllm/entrypoints/score_utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Union
from
typing
import
Union
from
torch.nn
import
CosineSimilarity
from
torch.nn
import
CosineSimilarity
...
@@ -10,12 +10,12 @@ from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
...
@@ -10,12 +10,12 @@ from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
def
_cosine_similarity
(
def
_cosine_similarity
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
embed_1
:
L
ist
[
PoolingRequestOutput
],
embed_1
:
l
ist
[
PoolingRequestOutput
],
embed_2
:
L
ist
[
PoolingRequestOutput
],
embed_2
:
l
ist
[
PoolingRequestOutput
],
)
->
L
ist
[
PoolingRequestOutput
]:
)
->
l
ist
[
PoolingRequestOutput
]:
scorer
=
CosineSimilarity
(
0
)
scorer
=
CosineSimilarity
(
0
)
scores
:
Union
[
L
ist
[
PoolingRequestOutput
]]
=
[]
scores
:
Union
[
l
ist
[
PoolingRequestOutput
]]
=
[]
for
emb_1
,
emb_2
in
zip
(
embed_1
,
embed_2
):
for
emb_1
,
emb_2
in
zip
(
embed_1
,
embed_2
):
pair_score
=
scorer
(
emb_1
.
outputs
.
data
,
emb_2
.
outputs
.
data
)
pair_score
=
scorer
(
emb_1
.
outputs
.
data
,
emb_2
.
outputs
.
data
)
...
@@ -38,8 +38,8 @@ def _cosine_similarity(
...
@@ -38,8 +38,8 @@ def _cosine_similarity(
def
_validate_score_input_lens
(
def
_validate_score_input_lens
(
texts_1
:
Union
[
L
ist
[
str
],
L
ist
[
dict
]],
texts_1
:
Union
[
l
ist
[
str
],
l
ist
[
dict
]],
texts_2
:
Union
[
L
ist
[
str
],
L
ist
[
dict
]],
texts_2
:
Union
[
l
ist
[
str
],
l
ist
[
dict
]],
):
):
if
len
(
texts_1
)
>
1
and
len
(
texts_1
)
!=
len
(
texts_2
):
if
len
(
texts_1
)
>
1
and
len
(
texts_1
)
!=
len
(
texts_2
):
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
...
...
vllm/envs.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
os
import
os
import
tempfile
import
tempfile
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
VLLM_HOST_IP
:
str
=
""
VLLM_HOST_IP
:
str
=
""
...
@@ -67,12 +67,12 @@ if TYPE_CHECKING:
...
@@ -67,12 +67,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_RPC_TIMEOUT
:
int
=
10000
# ms
VLLM_RPC_TIMEOUT
:
int
=
10000
# ms
VLLM_PLUGINS
:
Optional
[
L
ist
[
str
]]
=
None
VLLM_PLUGINS
:
Optional
[
l
ist
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_DISABLED_KERNELS
:
L
ist
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
l
ist
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
VLLM_USE_V1
:
bool
=
False
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
...
@@ -123,7 +123,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
...
@@ -123,7 +123,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# begin-env-vars-definition
# begin-env-vars-definition
environment_variables
:
D
ict
[
str
,
Callable
[[],
Any
]]
=
{
environment_variables
:
d
ict
[
str
,
Callable
[[],
Any
]]
=
{
# ================== Installation Time Env Vars ==================
# ================== Installation Time Env Vars ==================
...
...
vllm/forward_context.py
View file @
cf069aa8
...
@@ -4,7 +4,7 @@ import time
...
@@ -4,7 +4,7 @@ import time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -28,13 +28,13 @@ batchsize_forward_time: defaultdict = defaultdict(list)
...
@@ -28,13 +28,13 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@
dataclass
@
dataclass
class
ForwardContext
:
class
ForwardContext
:
# copy from vllm_config.compilation_config.static_forward_context
# copy from vllm_config.compilation_config.static_forward_context
attn_layers
:
D
ict
[
str
,
Any
]
attn_layers
:
d
ict
[
str
,
Any
]
# TODO: extend to support per-layer dynamic forward context
# TODO: extend to support per-layer dynamic forward context
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
virtual_engine
:
int
# set dynamically for each forward pass
num_tokens_across_dp
:
Optional
[
num_tokens_across_dp
:
Optional
[
L
ist
[
int
]]
=
None
# set dynamically for each forward pass
l
ist
[
int
]]
=
None
# set dynamically for each forward pass
_forward_context
:
Optional
[
ForwardContext
]
=
None
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
...
vllm/logger.py
View file @
cf069aa8
...
@@ -109,7 +109,7 @@ def _configure_vllm_root_logger() -> None:
...
@@ -109,7 +109,7 @@ def _configure_vllm_root_logger() -> None:
custom_config
=
json
.
loads
(
file
.
read
())
custom_config
=
json
.
loads
(
file
.
read
())
if
not
isinstance
(
custom_config
,
dict
):
if
not
isinstance
(
custom_config
,
dict
):
raise
ValueError
(
"Invalid logging config. Expected
D
ict, got %s."
,
raise
ValueError
(
"Invalid logging config. Expected
d
ict, got %s."
,
type
(
custom_config
).
__name__
)
type
(
custom_config
).
__name__
)
logging_config
=
custom_config
logging_config
=
custom_config
...
...
vllm/logits_process.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Callable
,
List
,
Tuple
,
Union
from
typing
import
Callable
,
Union
import
torch
import
torch
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
LogitsProcessor
=
Union
[
Callable
[[
L
ist
[
int
],
torch
.
Tensor
],
torch
.
Tensor
],
LogitsProcessor
=
Union
[
Callable
[[
l
ist
[
int
],
torch
.
Tensor
],
torch
.
Tensor
],
Callable
[[
L
ist
[
int
],
L
ist
[
int
],
torch
.
Tensor
],
Callable
[[
l
ist
[
int
],
l
ist
[
int
],
torch
.
Tensor
],
torch
.
Tensor
]]
torch
.
Tensor
]]
"""LogitsProcessor is a function that takes a list
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
of previously generated tokens, the logits tensor
...
@@ -17,9 +17,9 @@ to sample from."""
...
@@ -17,9 +17,9 @@ to sample from."""
def
get_bad_words_logits_processors
(
def
get_bad_words_logits_processors
(
bad_words
:
L
ist
[
str
],
bad_words
:
l
ist
[
str
],
tokenizer
:
AnyTokenizer
)
->
L
ist
[
LogitsProcessor
]:
tokenizer
:
AnyTokenizer
)
->
l
ist
[
LogitsProcessor
]:
bad_words_ids
:
L
ist
[
L
ist
[
int
]]
=
list
()
bad_words_ids
:
l
ist
[
l
ist
[
int
]]
=
list
()
for
bad_word
in
bad_words
:
for
bad_word
in
bad_words
:
# To prohibit words both at the beginning
# To prohibit words both at the beginning
...
@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor:
...
@@ -51,13 +51,13 @@ class NoBadWordsLogitsProcessor:
_SMALLEST_LOGIT
=
float
(
"-inf"
)
_SMALLEST_LOGIT
=
float
(
"-inf"
)
_NEUTRAL_LOGIT
=
0.0
_NEUTRAL_LOGIT
=
0.0
def
__init__
(
self
,
bad_words_ids
:
L
ist
[
L
ist
[
int
]]):
def
__init__
(
self
,
bad_words_ids
:
l
ist
[
l
ist
[
int
]]):
self
.
bad_words_ids
=
bad_words_ids
self
.
bad_words_ids
=
bad_words_ids
self
.
word_bias
:
torch
.
FloatTensor
=
None
self
.
word_bias
:
torch
.
FloatTensor
=
None
def
__call__
(
def
__call__
(
self
,
self
,
past_tokens_ids
:
Union
[
L
ist
[
int
],
T
uple
[
int
]],
past_tokens_ids
:
Union
[
l
ist
[
int
],
t
uple
[
int
]],
logits
:
torch
.
FloatTensor
,
logits
:
torch
.
FloatTensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
word_bias
is
None
:
if
self
.
word_bias
is
None
:
...
...
vllm/outputs.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
time
import
time
from
collections.abc
import
MutableSequence
from
collections.abc
import
Sequence
as
GenericSequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
Generic
,
List
,
MutableSequence
,
Optional
from
typing
import
Generic
,
Optional
,
Union
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
import
torch
import
torch
from
typing_extensions
import
TypeVar
,
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
...
@@ -109,14 +109,14 @@ class RequestOutput:
...
@@ -109,14 +109,14 @@ class RequestOutput:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
L
ist
[
int
]],
prompt_token_ids
:
Optional
[
l
ist
[
int
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
L
ist
[
CompletionOutput
],
outputs
:
l
ist
[
CompletionOutput
],
finished
:
bool
,
finished
:
bool
,
metrics
:
Optional
[
RequestMetrics
]
=
None
,
metrics
:
Optional
[
RequestMetrics
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
encoder_prompt
:
Optional
[
str
]
=
None
,
encoder_prompt
:
Optional
[
str
]
=
None
,
encoder_prompt_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
encoder_prompt_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
num_cached_tokens
:
Optional
[
int
]
=
None
,
num_cached_tokens
:
Optional
[
int
]
=
None
,
*
,
*
,
multi_modal_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
=
None
,
multi_modal_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
=
None
,
...
@@ -139,9 +139,9 @@ class RequestOutput:
...
@@ -139,9 +139,9 @@ class RequestOutput:
cls
,
cls
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
L
ist
[
int
]],
prompt_token_ids
:
Optional
[
l
ist
[
int
]],
text
:
str
,
text
:
str
,
token_ids
:
L
ist
[
int
],
token_ids
:
l
ist
[
int
],
logprobs
:
Optional
[
SampleLogprobs
],
logprobs
:
Optional
[
SampleLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
cumulative_logprob
:
Optional
[
float
],
cumulative_logprob
:
Optional
[
float
],
...
@@ -189,7 +189,7 @@ class RequestOutput:
...
@@ -189,7 +189,7 @@ class RequestOutput:
@
classmethod
@
classmethod
def
from_seq_group
(
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
,
use_cache
:
bool
,
cls
,
seq_group
:
SequenceGroup
,
use_cache
:
bool
,
seq_id_to_seq_group
:
D
ict
[
str
,
SequenceGroupBase
]
seq_id_to_seq_group
:
d
ict
[
str
,
SequenceGroupBase
]
)
->
Optional
[
"RequestOutput"
]:
)
->
Optional
[
"RequestOutput"
]:
finished
=
seq_group
.
is_finished
()
finished
=
seq_group
.
is_finished
()
...
@@ -363,12 +363,12 @@ class PoolingRequestOutput(Generic[_O]):
...
@@ -363,12 +363,12 @@ class PoolingRequestOutput(Generic[_O]):
Args:
Args:
request_id (str): A unique identifier for the pooling request.
request_id (str): A unique identifier for the pooling request.
outputs (PoolingOutput): The pooling results for the given input.
outputs (PoolingOutput): The pooling results for the given input.
prompt_token_ids (
L
ist[int]): A list of token IDs used in the prompt.
prompt_token_ids (
l
ist[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the pooling is completed.
finished (bool): A flag indicating whether the pooling is completed.
"""
"""
def
__init__
(
self
,
request_id
:
str
,
outputs
:
_O
,
def
__init__
(
self
,
request_id
:
str
,
outputs
:
_O
,
prompt_token_ids
:
L
ist
[
int
],
finished
:
bool
):
prompt_token_ids
:
l
ist
[
int
],
finished
:
bool
):
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
self
.
finished
=
finished
self
.
finished
=
finished
...
@@ -407,7 +407,7 @@ class RequestOutputFactory:
...
@@ -407,7 +407,7 @@ class RequestOutputFactory:
@
staticmethod
@
staticmethod
def
create
(
seq_group
:
SequenceGroup
,
def
create
(
seq_group
:
SequenceGroup
,
seq_id_to_seq_group
:
D
ict
[
str
,
SequenceGroupBase
],
seq_id_to_seq_group
:
d
ict
[
str
,
SequenceGroupBase
],
use_cache
:
bool
=
False
):
use_cache
:
bool
=
False
):
if
seq_group
.
pooled_data
is
not
None
:
if
seq_group
.
pooled_data
is
not
None
:
return
PoolingRequestOutput
.
from_seq_group
(
seq_group
)
return
PoolingRequestOutput
.
from_seq_group
(
seq_group
)
...
...
vllm/sampling_params.py
View file @
cf069aa8
...
@@ -4,11 +4,10 @@ import copy
...
@@ -4,11 +4,10 @@ import copy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
IntEnum
from
enum
import
Enum
,
IntEnum
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
An
y
,
Dict
,
List
,
Optional
,
Set
,
Union
from
typing
import
An
notated
,
Any
,
Optional
,
Union
import
msgspec
import
msgspec
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
typing_extensions
import
Annotated
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logits_process
import
LogitsProcessor
from
vllm.logits_process
import
LogitsProcessor
...
@@ -29,9 +28,9 @@ class SamplingType(IntEnum):
...
@@ -29,9 +28,9 @@ class SamplingType(IntEnum):
@
dataclass
@
dataclass
class
GuidedDecodingParams
:
class
GuidedDecodingParams
:
"""One of these fields will be used to build a logit processor."""
"""One of these fields will be used to build a logit processor."""
json
:
Optional
[
Union
[
str
,
D
ict
]]
=
None
json
:
Optional
[
Union
[
str
,
d
ict
]]
=
None
regex
:
Optional
[
str
]
=
None
regex
:
Optional
[
str
]
=
None
choice
:
Optional
[
L
ist
[
str
]]
=
None
choice
:
Optional
[
l
ist
[
str
]]
=
None
grammar
:
Optional
[
str
]
=
None
grammar
:
Optional
[
str
]
=
None
json_object
:
Optional
[
bool
]
=
None
json_object
:
Optional
[
bool
]
=
None
"""These are other options that can be set"""
"""These are other options that can be set"""
...
@@ -40,9 +39,9 @@ class GuidedDecodingParams:
...
@@ -40,9 +39,9 @@ class GuidedDecodingParams:
@
staticmethod
@
staticmethod
def
from_optional
(
def
from_optional
(
json
:
Optional
[
Union
[
D
ict
,
BaseModel
,
str
]]
=
None
,
json
:
Optional
[
Union
[
d
ict
,
BaseModel
,
str
]]
=
None
,
regex
:
Optional
[
str
]
=
None
,
regex
:
Optional
[
str
]
=
None
,
choice
:
Optional
[
L
ist
[
str
]]
=
None
,
choice
:
Optional
[
l
ist
[
str
]]
=
None
,
grammar
:
Optional
[
str
]
=
None
,
grammar
:
Optional
[
str
]
=
None
,
json_object
:
Optional
[
bool
]
=
None
,
json_object
:
Optional
[
bool
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
...
@@ -72,7 +71,7 @@ class GuidedDecodingParams:
...
@@ -72,7 +71,7 @@ class GuidedDecodingParams:
"""
"""
return
(
self
.
backend
or
""
).
split
(
":"
)[
0
]
return
(
self
.
backend
or
""
).
split
(
":"
)[
0
]
def
backend_options
(
self
)
->
L
ist
[
str
]:
def
backend_options
(
self
)
->
l
ist
[
str
]:
"""Return the backend options as a list of strings."""
"""Return the backend options as a list of strings."""
if
not
self
.
backend
or
":"
not
in
self
.
backend
:
if
not
self
.
backend
or
":"
not
in
self
.
backend
:
return
[]
return
[]
...
@@ -144,12 +143,12 @@ class SamplingParams(
...
@@ -144,12 +143,12 @@ class SamplingParams(
considered, relative to the probability of the most likely token.
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
seed: Random seed to use for the generation.
stop:
L
ist of strings that stop the generation when they are generated.
stop:
l
ist of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
The returned output will not contain the stop strings.
stop_token_ids:
L
ist of tokens that stop the generation when they are
stop_token_ids:
l
ist of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
generated. The returned output will contain the stop tokens unless
the stop tokens are special tokens.
the stop tokens are special tokens.
bad_words:
L
ist of words that are not allowed to be generated.
bad_words:
l
ist of words that are not allowed to be generated.
More precisely, only the last token of a corresponding
More precisely, only the last token of a corresponding
token sequence is not allowed when the next generated token
token sequence is not allowed when the next generated token
can complete the sequence.
can complete the sequence.
...
@@ -172,7 +171,7 @@ class SamplingParams(
...
@@ -172,7 +171,7 @@ class SamplingParams(
skip_special_tokens: Whether to skip special tokens in the output.
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
tokens in the output. Defaults to True.
logits_processors:
L
ist of functions that modify logits based on
logits_processors:
l
ist of functions that modify logits based on
previously generated tokens, and optionally prompt tokens as
previously generated tokens, and optionally prompt tokens as
a first argument.
a first argument.
truncate_prompt_tokens: If set to an integer k, will use only the last k
truncate_prompt_tokens: If set to an integer k, will use only the last k
...
@@ -198,9 +197,9 @@ class SamplingParams(
...
@@ -198,9 +197,9 @@ class SamplingParams(
top_k
:
int
=
-
1
top_k
:
int
=
-
1
min_p
:
float
=
0.0
min_p
:
float
=
0.0
seed
:
Optional
[
int
]
=
None
seed
:
Optional
[
int
]
=
None
stop
:
Optional
[
Union
[
str
,
L
ist
[
str
]]]
=
None
stop
:
Optional
[
Union
[
str
,
l
ist
[
str
]]]
=
None
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
ignore_eos
:
bool
=
False
ignore_eos
:
bool
=
False
max_tokens
:
Optional
[
int
]
=
16
max_tokens
:
Optional
[
int
]
=
16
min_tokens
:
int
=
0
min_tokens
:
int
=
0
...
@@ -212,8 +211,8 @@ class SamplingParams(
...
@@ -212,8 +211,8 @@ class SamplingParams(
detokenize
:
bool
=
True
detokenize
:
bool
=
True
skip_special_tokens
:
bool
=
True
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
# Optional[
L
ist[LogitsProcessor]] type. We use Any here because
# Optional[
l
ist[LogitsProcessor]] type. We use Any here because
# Optional[
L
ist[LogitsProcessor]] type is not supported by msgspec.
# Optional[
l
ist[LogitsProcessor]] type is not supported by msgspec.
logits_processors
:
Optional
[
Any
]
=
None
logits_processors
:
Optional
[
Any
]
=
None
include_stop_str_in_output
:
bool
=
False
include_stop_str_in_output
:
bool
=
False
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
...
@@ -222,12 +221,12 @@ class SamplingParams(
...
@@ -222,12 +221,12 @@ class SamplingParams(
# The below fields are not supposed to be used as an input.
# The below fields are not supposed to be used as an input.
# They are set in post_init.
# They are set in post_init.
output_text_buffer_length
:
int
=
0
output_text_buffer_length
:
int
=
0
_all_stop_token_ids
:
S
et
[
int
]
=
msgspec
.
field
(
default_factory
=
set
)
_all_stop_token_ids
:
s
et
[
int
]
=
msgspec
.
field
(
default_factory
=
set
)
# Fields used to construct logits processors
# Fields used to construct logits processors
guided_decoding
:
Optional
[
GuidedDecodingParams
]
=
None
guided_decoding
:
Optional
[
GuidedDecodingParams
]
=
None
logit_bias
:
Optional
[
D
ict
[
int
,
float
]]
=
None
logit_bias
:
Optional
[
d
ict
[
int
,
float
]]
=
None
allowed_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
allowed_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
@
staticmethod
@
staticmethod
def
from_optional
(
def
from_optional
(
...
@@ -241,9 +240,9 @@ class SamplingParams(
...
@@ -241,9 +240,9 @@ class SamplingParams(
top_k
:
int
=
-
1
,
top_k
:
int
=
-
1
,
min_p
:
float
=
0.0
,
min_p
:
float
=
0.0
,
seed
:
Optional
[
int
]
=
None
,
seed
:
Optional
[
int
]
=
None
,
stop
:
Optional
[
Union
[
str
,
L
ist
[
str
]]]
=
None
,
stop
:
Optional
[
Union
[
str
,
l
ist
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
bad_words
:
Optional
[
L
ist
[
str
]]
=
None
,
bad_words
:
Optional
[
l
ist
[
str
]]
=
None
,
include_stop_str_in_output
:
bool
=
False
,
include_stop_str_in_output
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
max_tokens
:
Optional
[
int
]
=
16
,
max_tokens
:
Optional
[
int
]
=
16
,
...
@@ -253,13 +252,13 @@ class SamplingParams(
...
@@ -253,13 +252,13 @@ class SamplingParams(
detokenize
:
bool
=
True
,
detokenize
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
L
ist
[
LogitsProcessor
]]
=
None
,
logits_processors
:
Optional
[
l
ist
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
,
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
,
guided_decoding
:
Optional
[
GuidedDecodingParams
]
=
None
,
guided_decoding
:
Optional
[
GuidedDecodingParams
]
=
None
,
logit_bias
:
Optional
[
Union
[
D
ict
[
int
,
float
],
D
ict
[
str
,
float
]]]
=
None
,
logit_bias
:
Optional
[
Union
[
d
ict
[
int
,
float
],
d
ict
[
str
,
float
]]]
=
None
,
allowed_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
allowed_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
)
->
"SamplingParams"
:
)
->
"SamplingParams"
:
if
logit_bias
is
not
None
:
if
logit_bias
is
not
None
:
# Convert token_id to integer
# Convert token_id to integer
...
@@ -435,7 +434,7 @@ class SamplingParams(
...
@@ -435,7 +434,7 @@ class SamplingParams(
def
update_from_generation_config
(
def
update_from_generation_config
(
self
,
self
,
generation_config
:
D
ict
[
str
,
Any
],
generation_config
:
d
ict
[
str
,
Any
],
model_eos_token_id
:
Optional
[
int
]
=
None
)
->
None
:
model_eos_token_id
:
Optional
[
int
]
=
None
)
->
None
:
"""Update if there are non-default values from generation_config"""
"""Update if there are non-default values from generation_config"""
...
@@ -468,7 +467,7 @@ class SamplingParams(
...
@@ -468,7 +467,7 @@ class SamplingParams(
return
SamplingType
.
RANDOM
return
SamplingType
.
RANDOM
@
property
@
property
def
all_stop_token_ids
(
self
)
->
S
et
[
int
]:
def
all_stop_token_ids
(
self
)
->
s
et
[
int
]:
return
self
.
_all_stop_token_ids
return
self
.
_all_stop_token_ids
def
clone
(
self
)
->
"SamplingParams"
:
def
clone
(
self
)
->
"SamplingParams"
:
...
...
vllm/sequence.py
View file @
cf069aa8
...
@@ -5,11 +5,11 @@ import enum
...
@@ -5,11 +5,11 @@ import enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
array
import
array
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Mapping
from
collections.abc
import
Sequence
as
GenericSequence
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
Any
,
Callable
,
DefaultDict
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
import
msgspec
import
msgspec
import
torch
import
torch
...
@@ -50,9 +50,9 @@ class Logprob:
...
@@ -50,9 +50,9 @@ class Logprob:
# {token_id -> logprob} per each sequence group. None if the corresponding
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
# sequence group doesn't require prompt logprob.
PromptLogprobs
=
L
ist
[
Optional
[
D
ict
[
int
,
Logprob
]]]
PromptLogprobs
=
l
ist
[
Optional
[
d
ict
[
int
,
Logprob
]]]
# {token_id -> logprob} for each sequence group.
# {token_id -> logprob} for each sequence group.
SampleLogprobs
=
L
ist
[
D
ict
[
int
,
Logprob
]]
SampleLogprobs
=
l
ist
[
d
ict
[
int
,
Logprob
]]
class
SequenceStatus
(
enum
.
IntEnum
):
class
SequenceStatus
(
enum
.
IntEnum
):
...
@@ -129,7 +129,7 @@ class SequenceDataDelta(
...
@@ -129,7 +129,7 @@ class SequenceDataDelta(
omit_defaults
=
True
):
# type: ignore[call-arg]
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Delta SequenceData to send to workers per step."""
"""Delta SequenceData to send to workers per step."""
# A new token to be appended to existing SequenceData.
# A new token to be appended to existing SequenceData.
new_output_token_ids
:
L
ist
[
int
]
new_output_token_ids
:
l
ist
[
int
]
# Overwriting existing `cumulative_logprob`
# Overwriting existing `cumulative_logprob`
new_cumulative_logprob
:
float
new_cumulative_logprob
:
float
# Overwriting existing `num_computed_tokens`.
# Overwriting existing `num_computed_tokens`.
...
@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct,
...
@@ -152,7 +152,7 @@ class SequenceData(msgspec.Struct,
output_token_ids: The token IDs of the output.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
"""
# NOTE: we cannot use Union[
L
ist, array] because msgspec cannot support
# NOTE: we cannot use Union[
l
ist, array] because msgspec cannot support
# union of 2 list types.
# union of 2 list types.
_prompt_token_ids
:
array
_prompt_token_ids
:
array
_output_token_ids
:
array
=
msgspec
.
field
(
_output_token_ids
:
array
=
msgspec
.
field
(
...
@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct,
...
@@ -160,25 +160,25 @@ class SequenceData(msgspec.Struct,
### The below fields should not be passed as an argument ###
### The below fields should not be passed as an argument ###
_cumulative_logprob
:
float
=
0.0
_cumulative_logprob
:
float
=
0.0
_prompt_token_ids_tuple
:
T
uple
[
int
,
_prompt_token_ids_tuple
:
t
uple
[
int
,
...]
=
msgspec
.
field
(
default_factory
=
tuple
)
...]
=
msgspec
.
field
(
default_factory
=
tuple
)
# The number of tokens that are computed (that run against the model).
# The number of tokens that are computed (that run against the model).
_num_computed_tokens
:
int
=
0
_num_computed_tokens
:
int
=
0
# The number of tokens with prefix cache hit.
# The number of tokens with prefix cache hit.
_num_cached_tokens
:
int
=
0
_num_cached_tokens
:
int
=
0
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
_cached_all_token_ids
:
L
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_cached_all_token_ids
:
l
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
# It is used to get delta input. It is reset when `get_delta_and_reset`
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
# is called.
_new_appended_tokens
:
L
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_new_appended_tokens
:
l
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
# It is used to compute mrope_position_ids.
# It is used to compute mrope_position_ids.
_mrope_position_delta
:
Optional
[
int
]
=
None
_mrope_position_delta
:
Optional
[
int
]
=
None
@
staticmethod
@
staticmethod
def
from_prompt_token_counts
(
def
from_prompt_token_counts
(
*
token_counts
:
T
uple
[
int
,
int
])
->
"SequenceData"
:
*
token_counts
:
t
uple
[
int
,
int
])
->
"SequenceData"
:
"""
"""
Construct a :class:`SequenceData` instance by concatenating
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.
prompt token sequences.
...
@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct,
...
@@ -220,14 +220,14 @@ class SequenceData(msgspec.Struct,
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
assert
self
.
_prompt_token_ids
.
typecode
==
"l"
assert
self
.
_prompt_token_ids
.
typecode
==
"l"
assert
self
.
_output_token_ids
.
typecode
==
"l"
assert
self
.
_output_token_ids
.
typecode
==
"l"
self
.
_prompt_token_ids_tuple
:
T
uple
[
int
,
...]
=
tuple
(
self
.
_prompt_token_ids_tuple
:
t
uple
[
int
,
...]
=
tuple
(
self
.
_prompt_token_ids
)
self
.
_prompt_token_ids
)
self
.
_update_cached_all_tokens
()
self
.
_update_cached_all_tokens
()
def
_update_cached_all_tokens
(
self
):
def
_update_cached_all_tokens
(
self
):
assert
isinstance
(
self
.
_prompt_token_ids
,
array
)
assert
isinstance
(
self
.
_prompt_token_ids
,
array
)
assert
isinstance
(
self
.
_output_token_ids
,
array
)
assert
isinstance
(
self
.
_output_token_ids
,
array
)
self
.
_cached_all_token_ids
:
L
ist
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_cached_all_token_ids
:
l
ist
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_output_token_ids
)
self
.
_output_token_ids
)
@
property
@
property
...
@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct,
...
@@ -235,7 +235,7 @@ class SequenceData(msgspec.Struct,
return
self
.
_cumulative_logprob
return
self
.
_cumulative_logprob
@
property
@
property
def
prompt_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
prompt_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
_prompt_token_ids_tuple
return
self
.
_prompt_token_ids_tuple
@
prompt_token_ids
.
setter
@
prompt_token_ids
.
setter
...
@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct,
...
@@ -252,7 +252,7 @@ class SequenceData(msgspec.Struct,
return
self
.
_prompt_token_ids
return
self
.
_prompt_token_ids
@
property
@
property
def
output_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
output_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
tuple
(
self
.
_output_token_ids
)
return
tuple
(
self
.
_output_token_ids
)
@
output_token_ids
.
setter
@
output_token_ids
.
setter
...
@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct,
...
@@ -295,12 +295,12 @@ class SequenceData(msgspec.Struct,
def
get_output_len
(
self
)
->
int
:
def
get_output_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
return
len
(
self
.
_output_token_ids
)
def
get_token_ids
(
self
)
->
L
ist
[
int
]:
def
get_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
_cached_all_token_ids
return
self
.
_cached_all_token_ids
def
get_prefix_token_ids
(
def
get_prefix_token_ids
(
self
,
num_tokens
:
int
self
,
num_tokens
:
int
)
->
T
uple
[
T
uple
[
int
,
...],
Optional
[
T
uple
[
int
,
...]]]:
)
->
t
uple
[
t
uple
[
int
,
...],
Optional
[
t
uple
[
int
,
...]]]:
"""Get prefix tokens, and make the return value hashable"""
"""Get prefix tokens, and make the return value hashable"""
prompt_length
=
self
.
get_prompt_len
()
prompt_length
=
self
.
get_prompt_len
()
if
num_tokens
>
prompt_length
:
if
num_tokens
>
prompt_length
:
...
@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct,
...
@@ -351,10 +351,10 @@ class SequenceData(msgspec.Struct,
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
def
get_prompt_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_prompt_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
prompt_token_ids
return
self
.
prompt_token_ids
def
get_output_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
output_token_ids
return
self
.
output_token_ids
def
get_delta_and_reset
(
self
)
->
SequenceDataDelta
:
def
get_delta_and_reset
(
self
)
->
SequenceDataDelta
:
...
@@ -432,7 +432,7 @@ class Sequence:
...
@@ -432,7 +432,7 @@ class Sequence:
self
.
prefix_offset
=
0
self
.
prefix_offset
=
0
self
.
read_offset
=
0
self
.
read_offset
=
0
# Input + output tokens
# Input + output tokens
self
.
tokens
:
Optional
[
L
ist
[
str
]]
=
None
self
.
tokens
:
Optional
[
l
ist
[
str
]]
=
None
@
property
@
property
def
n_blocks
(
self
)
->
int
:
def
n_blocks
(
self
)
->
int
:
...
@@ -443,7 +443,7 @@ class Sequence:
...
@@ -443,7 +443,7 @@ class Sequence:
return
self
.
inputs
.
prompt
return
self
.
inputs
.
prompt
@
property
@
property
def
prompt_token_ids
(
self
)
->
L
ist
[
int
]:
def
prompt_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
inputs
.
prompt_token_ids
return
self
.
inputs
.
prompt_token_ids
@
property
@
property
...
@@ -451,7 +451,7 @@ class Sequence:
...
@@ -451,7 +451,7 @@ class Sequence:
return
self
.
inputs
.
prompt_embeds
return
self
.
inputs
.
prompt_embeds
@
property
@
property
def
token_type_ids
(
self
)
->
L
ist
[
int
]:
def
token_type_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
inputs
.
token_type_ids
return
self
.
inputs
.
token_type_ids
@
property
@
property
...
@@ -463,7 +463,7 @@ class Sequence:
...
@@ -463,7 +463,7 @@ class Sequence:
return
self
.
inputs
.
multi_modal_placeholders
return
self
.
inputs
.
multi_modal_placeholders
@
property
@
property
def
mm_processor_kwargs
(
self
)
->
D
ict
[
str
,
Any
]:
def
mm_processor_kwargs
(
self
)
->
d
ict
[
str
,
Any
]:
return
self
.
inputs
.
mm_processor_kwargs
return
self
.
inputs
.
mm_processor_kwargs
@
property
@
property
...
@@ -548,7 +548,7 @@ class Sequence:
...
@@ -548,7 +548,7 @@ class Sequence:
"""Reset the sequence states for recomputation."""
"""Reset the sequence states for recomputation."""
self
.
data
.
reset_state_for_recompute
()
self
.
data
.
reset_state_for_recompute
()
def
append_token_id
(
self
,
token_id
:
int
,
logprobs
:
D
ict
[
int
,
def
append_token_id
(
self
,
token_id
:
int
,
logprobs
:
d
ict
[
int
,
Logprob
])
->
None
:
Logprob
])
->
None
:
assert
token_id
in
logprobs
assert
token_id
in
logprobs
self
.
output_logprobs
.
append
(
logprobs
)
self
.
output_logprobs
.
append
(
logprobs
)
...
@@ -563,16 +563,16 @@ class Sequence:
...
@@ -563,16 +563,16 @@ class Sequence:
def
get_output_len
(
self
)
->
int
:
def
get_output_len
(
self
)
->
int
:
return
self
.
data
.
get_output_len
()
return
self
.
data
.
get_output_len
()
def
get_token_ids
(
self
)
->
L
ist
[
int
]:
def
get_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
data
.
get_token_ids
()
return
self
.
data
.
get_token_ids
()
def
get_prompt_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_prompt_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
data
.
get_prompt_token_ids
()
return
self
.
data
.
get_prompt_token_ids
()
def
get_last_token_id
(
self
)
->
int
:
def
get_last_token_id
(
self
)
->
int
:
return
self
.
data
.
get_last_token_id
()
return
self
.
data
.
get_last_token_id
()
def
get_output_token_ids
(
self
)
->
T
uple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
t
uple
[
int
,
...]:
return
self
.
data
.
get_output_token_ids
()
return
self
.
data
.
get_output_token_ids
()
def
get_cumulative_logprob
(
self
)
->
float
:
def
get_cumulative_logprob
(
self
)
->
float
:
...
@@ -644,7 +644,7 @@ class SequenceGroup:
...
@@ -644,7 +644,7 @@ class SequenceGroup:
def
__init__
(
def
__init__
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
seqs
:
L
ist
[
Sequence
],
seqs
:
l
ist
[
Sequence
],
arrival_time
:
float
,
arrival_time
:
float
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -686,7 +686,7 @@ class SequenceGroup:
...
@@ -686,7 +686,7 @@ class SequenceGroup:
return
self
.
first_seq
.
prompt
return
self
.
first_seq
.
prompt
@
property
@
property
def
prompt_token_ids
(
self
)
->
L
ist
[
int
]:
def
prompt_token_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
first_seq
.
prompt_token_ids
return
self
.
first_seq
.
prompt_token_ids
@
property
@
property
...
@@ -698,7 +698,7 @@ class SequenceGroup:
...
@@ -698,7 +698,7 @@ class SequenceGroup:
if
self
.
encoder_seq
is
not
None
else
None
)
if
self
.
encoder_seq
is
not
None
else
None
)
@
property
@
property
def
encoder_prompt_token_ids
(
self
)
->
Optional
[
L
ist
[
int
]]:
def
encoder_prompt_token_ids
(
self
)
->
Optional
[
l
ist
[
int
]]:
# There are either 0 or 1 encoder sequences
# There are either 0 or 1 encoder sequences
# If one is present, its prompt token ids are
# If one is present, its prompt token ids are
# distinct from the decoder's.
# distinct from the decoder's.
...
@@ -706,7 +706,7 @@ class SequenceGroup:
...
@@ -706,7 +706,7 @@ class SequenceGroup:
if
self
.
encoder_seq
is
not
None
else
None
)
if
self
.
encoder_seq
is
not
None
else
None
)
@
property
@
property
def
token_type_ids
(
self
)
->
Optional
[
L
ist
[
int
]]:
def
token_type_ids
(
self
)
->
Optional
[
l
ist
[
int
]]:
return
self
.
first_seq
.
token_type_ids
return
self
.
first_seq
.
token_type_ids
@
property
@
property
...
@@ -726,7 +726,7 @@ class SequenceGroup:
...
@@ -726,7 +726,7 @@ class SequenceGroup:
return
{}
return
{}
@
property
@
property
def
mm_processor_kwargs
(
self
)
->
D
ict
[
str
,
Any
]:
def
mm_processor_kwargs
(
self
)
->
d
ict
[
str
,
Any
]:
if
self
.
first_seq
.
multi_modal_data
:
if
self
.
first_seq
.
multi_modal_data
:
return
self
.
first_seq
.
mm_processor_kwargs
return
self
.
first_seq
.
mm_processor_kwargs
elif
self
.
encoder_seq
is
not
None
:
elif
self
.
encoder_seq
is
not
None
:
...
@@ -823,7 +823,7 @@ class SequenceGroup:
...
@@ -823,7 +823,7 @@ class SequenceGroup:
def
get_seqs
(
def
get_seqs
(
self
,
self
,
status
:
Optional
[
SequenceStatus
]
=
None
,
status
:
Optional
[
SequenceStatus
]
=
None
,
)
->
L
ist
[
Sequence
]:
)
->
l
ist
[
Sequence
]:
if
status
is
None
:
if
status
is
None
:
return
self
.
seqs
return
self
.
seqs
...
@@ -838,7 +838,7 @@ class SequenceGroup:
...
@@ -838,7 +838,7 @@ class SequenceGroup:
def
get_encoder_seq
(
self
)
->
Optional
[
Sequence
]:
def
get_encoder_seq
(
self
)
->
Optional
[
Sequence
]:
return
self
.
encoder_seq
return
self
.
encoder_seq
def
get_finished_seqs
(
self
)
->
L
ist
[
Sequence
]:
def
get_finished_seqs
(
self
)
->
l
ist
[
Sequence
]:
if
self
.
is_single_seq
:
if
self
.
is_single_seq
:
return
self
.
seqs
if
self
.
first_seq
.
is_finished
()
else
[]
return
self
.
seqs
if
self
.
first_seq
.
is_finished
()
else
[]
...
@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta(
...
@@ -897,13 +897,13 @@ class SequenceGroupMetadataDelta(
After sending the first SequenceGroupMetadata, vLLM scheduler
After sending the first SequenceGroupMetadata, vLLM scheduler
only sends delta to reduce the data payload size.
only sends delta to reduce the data payload size.
"""
"""
seq_data_delta
:
D
ict
[
int
,
SequenceDataDelta
]
seq_data_delta
:
d
ict
[
int
,
SequenceDataDelta
]
request_id
:
str
request_id
:
str
block_tables
:
D
ict
[
int
,
L
ist
[
int
]]
block_tables
:
d
ict
[
int
,
l
ist
[
int
]]
is_prompt
:
bool
is_prompt
:
bool
do_sample
:
bool
=
True
do_sample
:
bool
=
True
token_chunk_size
:
Optional
[
int
]
=
None
token_chunk_size
:
Optional
[
int
]
=
None
computed_block_nums
:
Optional
[
L
ist
[
int
]]
=
None
computed_block_nums
:
Optional
[
l
ist
[
int
]]
=
None
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
default_factory
=
lambda
:
SequenceGroupState
())
default_factory
=
lambda
:
SequenceGroupState
())
...
@@ -947,23 +947,23 @@ class SequenceGroupMetadata(
...
@@ -947,23 +947,23 @@ class SequenceGroupMetadata(
request_id
:
str
request_id
:
str
is_prompt
:
bool
is_prompt
:
bool
seq_data
:
D
ict
[
int
,
SequenceData
]
seq_data
:
d
ict
[
int
,
SequenceData
]
sampling_params
:
Optional
[
SamplingParams
]
sampling_params
:
Optional
[
SamplingParams
]
block_tables
:
D
ict
[
int
,
L
ist
[
int
]]
block_tables
:
d
ict
[
int
,
l
ist
[
int
]]
do_sample
:
bool
=
True
do_sample
:
bool
=
True
pooling_params
:
Optional
[
PoolingParams
]
=
None
pooling_params
:
Optional
[
PoolingParams
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
computed_block_nums
:
Optional
[
L
ist
[
int
]]
=
None
computed_block_nums
:
Optional
[
l
ist
[
int
]]
=
None
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
default_factory
=
lambda
:
SequenceGroupState
())
default_factory
=
lambda
:
SequenceGroupState
())
# "MultiModalDataDict" types. We have to use Any due to msgspec
# "MultiModalDataDict" types. We have to use Any due to msgspec
# doesn't allow to have union of 2 different dicts.
# doesn't allow to have union of 2 different dicts.
token_type_ids
:
Optional
[
L
ist
[
int
]]
=
None
token_type_ids
:
Optional
[
l
ist
[
int
]]
=
None
multi_modal_data
:
Optional
[
Any
]
=
None
multi_modal_data
:
Optional
[
Any
]
=
None
multi_modal_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
=
None
multi_modal_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
=
None
mm_processor_kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
mm_processor_kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
cross_block_table
:
Optional
[
L
ist
[
int
]]
=
None
cross_block_table
:
Optional
[
l
ist
[
int
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
token_chunk_size
:
Optional
[
int
]
=
None
token_chunk_size
:
Optional
[
int
]
=
None
...
@@ -1042,7 +1042,7 @@ class SequenceOutput(
...
@@ -1042,7 +1042,7 @@ class SequenceOutput(
"""
"""
parent_seq_id
:
int
parent_seq_id
:
int
output_token
:
int
output_token
:
int
logprobs
:
D
ict
[
int
,
Logprob
]
logprobs
:
d
ict
[
int
,
Logprob
]
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceOutput(parent_seq_id=
{
self
.
parent_seq_id
}
, "
return
(
f
"SequenceOutput(parent_seq_id=
{
self
.
parent_seq_id
}
, "
...
@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput(
...
@@ -1076,7 +1076,7 @@ class CompletionSequenceGroupOutput(
array_like
=
True
):
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""The model output associated with a completion sequence group."""
"""The model output associated with a completion sequence group."""
__metaclass__
=
SequenceGroupOutput
__metaclass__
=
SequenceGroupOutput
samples
:
L
ist
[
SequenceOutput
]
samples
:
l
ist
[
SequenceOutput
]
# Prompt logprob for each prompt query token.
# Prompt logprob for each prompt query token.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
prompt_logprobs
:
Optional
[
PromptLogprobs
]
...
@@ -1119,7 +1119,7 @@ class IntermediateTensors:
...
@@ -1119,7 +1119,7 @@ class IntermediateTensors:
contains the hidden states and residuals for a request.
contains the hidden states and residuals for a request.
"""
"""
tensors
:
D
ict
[
str
,
torch
.
Tensor
]
tensors
:
d
ict
[
str
,
torch
.
Tensor
]
def
__init__
(
self
,
tensors
):
def
__init__
(
self
,
tensors
):
# manually define this function, so that
# manually define this function, so that
...
@@ -1155,7 +1155,7 @@ class PoolerOutput(
...
@@ -1155,7 +1155,7 @@ class PoolerOutput(
omit_defaults
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
"""The output from a pooling operation in the pooling model."""
outputs
:
L
ist
[
PoolingSequenceGroupOutput
]
outputs
:
l
ist
[
PoolingSequenceGroupOutput
]
def
__getitem__
(
self
,
idx
:
int
)
->
PoolingSequenceGroupOutput
:
def
__getitem__
(
self
,
idx
:
int
)
->
PoolingSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
return
self
.
outputs
[
idx
]
...
@@ -1172,7 +1172,7 @@ class PoolerOutput(
...
@@ -1172,7 +1172,7 @@ class PoolerOutput(
def
get_all_seq_ids
(
def
get_all_seq_ids
(
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
])
->
L
ist
[
int
]:
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
])
->
l
ist
[
int
]:
"""Given a list of SequenceGroupMetadata, create a list of all
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
sequence ids.
"""
"""
...
@@ -1180,13 +1180,13 @@ def get_all_seq_ids(
...
@@ -1180,13 +1180,13 @@ def get_all_seq_ids(
def
get_all_seq_ids_and_request_ids
(
def
get_all_seq_ids_and_request_ids
(
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
]
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
]
)
->
T
uple
[
L
ist
[
int
],
D
ict
[
str
,
S
et
[
int
]]]:
)
->
t
uple
[
l
ist
[
int
],
d
ict
[
str
,
s
et
[
int
]]]:
"""Given a list of SequenceGroupMetadata, create a list of all
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
sequence ids.
"""
"""
seq_ids
:
L
ist
[
int
]
=
[]
seq_ids
:
l
ist
[
int
]
=
[]
request_id_seq_ids_mapping
:
D
efault
D
ict
[
str
,
S
et
[
int
]]
=
defaultdict
(
set
)
request_id_seq_ids_mapping
:
d
efault
d
ict
[
str
,
s
et
[
int
]]
=
defaultdict
(
set
)
for
sg
in
seq_group_metadata_list
:
for
sg
in
seq_group_metadata_list
:
for
seq_id
in
sg
.
seq_data
:
for
seq_id
in
sg
.
seq_data
:
seq_ids
.
append
(
seq_id
)
seq_ids
.
append
(
seq_id
)
...
@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True,
...
@@ -1206,14 +1206,14 @@ class HiddenStates(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it use used for last accepted tokens.
# all tokens, whereas for decode step, it use used for last accepted tokens.
hidden_states
:
torch
.
Tensor
hidden_states
:
torch
.
Tensor
# The sequence group metadata list. Only needed for decode step.
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list
:
Optional
[
L
ist
[
SequenceGroupMetadata
]]
=
None
seq_group_metadata_list
:
Optional
[
l
ist
[
SequenceGroupMetadata
]]
=
None
# Scorer hidden states of the 2nd last token proposed by the proposer (
# Scorer hidden states of the 2nd last token proposed by the proposer (
# irrespective of whether it was accepted or not). Only used for cases when
# irrespective of whether it was accepted or not). Only used for cases when
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# last proposed token is accepted (i.e., in case of bonus tokens). For the
# case of no bonus tokens, these are ignored.
# case of no bonus tokens, these are ignored.
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
_seq_ids
:
L
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
_seq_ids
:
l
ist
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
seq_group_metadata_list
is
not
None
:
if
self
.
seq_group_metadata_list
is
not
None
:
...
@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True,
...
@@ -1221,12 +1221,12 @@ class HiddenStates(msgspec.Struct, array_like=True,
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
@
property
@
property
def
seq_ids
(
self
)
->
L
ist
[
int
]:
def
seq_ids
(
self
)
->
l
ist
[
int
]:
return
self
.
_seq_ids
return
self
.
_seq_ids
def
update
(
self
,
def
update
(
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
],
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
):
second_last_token_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
):
"""Update hidden states from target model invocation. Only used for
"""Update hidden states from target model invocation. Only used for
decode steps"""
decode steps"""
...
@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
...
@@ -1244,7 +1244,7 @@ class HiddenStates(msgspec.Struct, array_like=True,
])
])
def
prune
(
self
,
def
prune
(
self
,
seq_group_metadata_list
:
L
ist
[
SequenceGroupMetadata
])
->
None
:
seq_group_metadata_list
:
l
ist
[
SequenceGroupMetadata
])
->
None
:
"""Prune to provided list of sequence ids. Only used for decode steps.
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
"""
# Currently this prunes all seq_ids not present in
# Currently this prunes all seq_ids not present in
...
@@ -1287,16 +1287,16 @@ class ExecuteModelRequest(
...
@@ -1287,16 +1287,16 @@ class ExecuteModelRequest(
"""The model execution request, containing CPU metadata only. The LLM
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
# The sequence group metadata list.
seq_group_metadata_list
:
L
ist
[
Union
[
SequenceGroupMetadata
,
seq_group_metadata_list
:
l
ist
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]]
SequenceGroupMetadataDelta
]]
# Blocks to swap in. List of CPU -> GPU block number.
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in
:
L
ist
[
T
uple
[
int
,
blocks_to_swap_in
:
l
ist
[
t
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Blocks to swap out. List of GPU -> CPU block number.
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out
:
L
ist
[
T
uple
[
int
,
blocks_to_swap_out
:
l
ist
[
t
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Blocks to copy. Source to dest block.
# Blocks to copy. Source to dest block.
blocks_to_copy
:
L
ist
[
T
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
blocks_to_copy
:
l
ist
[
t
uple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Virtual engine ID for pipeline parallel.
# Virtual engine ID for pipeline parallel.
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
...
@@ -1310,7 +1310,7 @@ class ExecuteModelRequest(
...
@@ -1310,7 +1310,7 @@ class ExecuteModelRequest(
# The step index for spec model input.
# The step index for spec model input.
spec_step_idx
:
Optional
[
int
]
=
None
spec_step_idx
:
Optional
[
int
]
=
None
# Finished request ids since last step.
# Finished request ids since last step.
finished_requests_ids
:
L
ist
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
finished_requests_ids
:
l
ist
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async callback
# Async callback
...
@@ -1344,7 +1344,7 @@ class ExecuteModelRequest(
...
@@ -1344,7 +1344,7 @@ class ExecuteModelRequest(
return
state
.
current_step
return
state
.
current_step
def
clone
(
def
clone
(
self
,
seq_group_metadata_list
:
L
ist
[
Union
[
SequenceGroupMetadata
,
self
,
seq_group_metadata_list
:
l
ist
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]]
SequenceGroupMetadataDelta
]]
)
->
"ExecuteModelRequest"
:
)
->
"ExecuteModelRequest"
:
"""Clone the request with a new sequence group metadata list."""
"""Clone the request with a new sequence group metadata list."""
...
@@ -1371,13 +1371,13 @@ class SequenceGroupBase:
...
@@ -1371,13 +1371,13 @@ class SequenceGroupBase:
assembled_seq_group
:
Optional
[
SequenceGroup
]
=
None
assembled_seq_group
:
Optional
[
SequenceGroup
]
=
None
# seq id to a unique index inside this group
# seq id to a unique index inside this group
seq_id_to_index
:
D
ict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
seq_id_to_index
:
d
ict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
# seq ids to be finished
# seq ids to be finished
to_be_finished
:
D
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
to_be_finished
:
d
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
# seq id to finished sequences
# seq id to finished sequences
finished_reqs
:
D
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
finished_reqs
:
d
ict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
streaming
:
bool
=
False
streaming
:
bool
=
False
...
...
vllm/tracing.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
import
os
from
typing
import
Mapping
,
Optional
from
collections.abc
import
Mapping
from
typing
import
Optional
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
run_once
from
vllm.utils
import
run_once
...
...
vllm/utils.py
View file @
cf069aa8
...
@@ -28,12 +28,12 @@ import warnings
...
@@ -28,12 +28,12 @@ import warnings
import
weakref
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
OrderedDict
,
UserDict
,
defaultdict
from
collections
import
OrderedDict
,
UserDict
,
defaultdict
from
collections.abc
import
Hashable
,
Iterable
,
Mapping
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
Mapping
)
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Literal
,
Optional
,
TypeVar
,
Union
)
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
)
from
uuid
import
uuid4
from
uuid
import
uuid4
import
cloudpickle
import
cloudpickle
...
@@ -400,7 +400,7 @@ def _next_task(iterator: AsyncGenerator[T, None],
...
@@ -400,7 +400,7 @@ def _next_task(iterator: AsyncGenerator[T, None],
async
def
merge_async_iterators
(
async
def
merge_async_iterators
(
*
iterators
:
AsyncGenerator
[
T
,
*
iterators
:
AsyncGenerator
[
T
,
None
],
)
->
AsyncGenerator
[
T
uple
[
int
,
T
],
None
]:
None
],
)
->
AsyncGenerator
[
t
uple
[
int
,
T
],
None
]:
"""Merge multiple asynchronous iterators into a single iterator.
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
This method handle the case where some iterators finish before others.
...
@@ -433,7 +433,7 @@ async def merge_async_iterators(
...
@@ -433,7 +433,7 @@ async def merge_async_iterators(
async
def
collect_from_async_generator
(
async
def
collect_from_async_generator
(
iterator
:
AsyncGenerator
[
T
,
None
])
->
L
ist
[
T
]:
iterator
:
AsyncGenerator
[
T
,
None
])
->
l
ist
[
T
]:
"""Collect all items from an async generator into a list."""
"""Collect all items from an async generator into a list."""
items
=
[]
items
=
[]
async
for
item
in
iterator
:
async
for
item
in
iterator
:
...
@@ -560,7 +560,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]:
...
@@ -560,7 +560,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]:
return
None
return
None
def
update_environment_variables
(
envs
:
D
ict
[
str
,
str
]):
def
update_environment_variables
(
envs
:
d
ict
[
str
,
str
]):
for
k
,
v
in
envs
.
items
():
for
k
,
v
in
envs
.
items
():
if
k
in
os
.
environ
and
os
.
environ
[
k
]
!=
v
:
if
k
in
os
.
environ
and
os
.
environ
[
k
]
!=
v
:
logger
.
warning
(
logger
.
warning
(
...
@@ -569,7 +569,7 @@ def update_environment_variables(envs: Dict[str, str]):
...
@@ -569,7 +569,7 @@ def update_environment_variables(envs: Dict[str, str]):
os
.
environ
[
k
]
=
v
os
.
environ
[
k
]
=
v
def
chunk_list
(
lst
:
L
ist
[
T
],
chunk_size
:
int
):
def
chunk_list
(
lst
:
l
ist
[
T
],
chunk_size
:
int
):
"""Yield successive chunk_size chunks from lst."""
"""Yield successive chunk_size chunks from lst."""
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
):
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
):
yield
lst
[
i
:
i
+
chunk_size
]
yield
lst
[
i
:
i
+
chunk_size
]
...
@@ -642,7 +642,7 @@ def create_kv_caches_with_random_flash(
...
@@ -642,7 +642,7 @@ def create_kv_caches_with_random_flash(
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
T
uple
[
L
ist
[
torch
.
Tensor
],
L
ist
[
torch
.
Tensor
]]:
)
->
t
uple
[
l
ist
[
torch
.
Tensor
],
l
ist
[
torch
.
Tensor
]]:
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
@@ -650,8 +650,8 @@ def create_kv_caches_with_random_flash(
...
@@ -650,8 +650,8 @@ def create_kv_caches_with_random_flash(
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
key_value_cache_shape
=
(
num_blocks
,
2
,
block_size
,
num_heads
,
head_size
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
key_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
key_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
value_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
value_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
key_value_cache
=
torch
.
empty
(
size
=
key_value_cache_shape
,
key_value_cache
=
torch
.
empty
(
size
=
key_value_cache_shape
,
...
@@ -679,7 +679,7 @@ def create_kv_caches_with_random(
...
@@ -679,7 +679,7 @@ def create_kv_caches_with_random(
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
T
uple
[
L
ist
[
torch
.
Tensor
],
L
ist
[
torch
.
Tensor
]]:
)
->
t
uple
[
l
ist
[
torch
.
Tensor
],
l
ist
[
torch
.
Tensor
]]:
if
cache_dtype
==
"fp8"
and
head_size
%
16
:
if
cache_dtype
==
"fp8"
and
head_size
%
16
:
raise
ValueError
(
raise
ValueError
(
...
@@ -693,7 +693,7 @@ def create_kv_caches_with_random(
...
@@ -693,7 +693,7 @@ def create_kv_caches_with_random(
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_dtype
).
element_size
()
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_dtype
).
element_size
()
key_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
//
x
,
block_size
,
x
)
key_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
//
x
,
block_size
,
x
)
key_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
key_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
...
@@ -708,7 +708,7 @@ def create_kv_caches_with_random(
...
@@ -708,7 +708,7 @@ def create_kv_caches_with_random(
key_caches
.
append
(
key_cache
)
key_caches
.
append
(
key_cache
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_cache_shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
value_caches
:
L
ist
[
torch
.
Tensor
]
=
[]
value_caches
:
l
ist
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
...
@@ -754,7 +754,7 @@ class DeviceMemoryProfiler:
...
@@ -754,7 +754,7 @@ class DeviceMemoryProfiler:
def
make_ndarray_with_pad
(
def
make_ndarray_with_pad
(
x
:
L
ist
[
L
ist
[
T
]],
x
:
l
ist
[
l
ist
[
T
]],
pad
:
T
,
pad
:
T
,
dtype
:
npt
.
DTypeLike
,
dtype
:
npt
.
DTypeLike
,
*
,
*
,
...
@@ -779,7 +779,7 @@ def make_ndarray_with_pad(
...
@@ -779,7 +779,7 @@ def make_ndarray_with_pad(
def
make_tensor_with_pad
(
def
make_tensor_with_pad
(
x
:
L
ist
[
L
ist
[
T
]],
x
:
l
ist
[
l
ist
[
T
]],
pad
:
T
,
pad
:
T
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
*
,
*
,
...
@@ -831,7 +831,7 @@ def is_list_of(
...
@@ -831,7 +831,7 @@ def is_list_of(
typ
:
Union
[
type
[
T
],
tuple
[
type
[
T
],
...]],
typ
:
Union
[
type
[
T
],
tuple
[
type
[
T
],
...]],
*
,
*
,
check
:
Literal
[
"first"
,
"all"
]
=
"first"
,
check
:
Literal
[
"first"
,
"all"
]
=
"first"
,
)
->
TypeIs
[
L
ist
[
T
]]:
)
->
TypeIs
[
l
ist
[
T
]]:
if
not
isinstance
(
value
,
list
):
if
not
isinstance
(
value
,
list
):
return
False
return
False
...
@@ -843,8 +843,8 @@ def is_list_of(
...
@@ -843,8 +843,8 @@ def is_list_of(
assert_never
(
check
)
assert_never
(
check
)
JSONTree
=
Union
[
D
ict
[
str
,
"JSONTree[T]"
],
L
ist
[
"JSONTree[T]"
],
JSONTree
=
Union
[
d
ict
[
str
,
"JSONTree[T]"
],
l
ist
[
"JSONTree[T]"
],
T
uple
[
"JSONTree[T]"
,
...],
T
]
t
uple
[
"JSONTree[T]"
,
...],
T
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
"""A nested JSON structure where the leaves need not be JSON-serializable."""
...
@@ -859,7 +859,7 @@ def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
...
@@ -859,7 +859,7 @@ def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
return
func
(
value
)
return
func
(
value
)
def
flatten_2d_lists
(
lists
:
L
ist
[
L
ist
[
T
]])
->
L
ist
[
T
]:
def
flatten_2d_lists
(
lists
:
l
ist
[
l
ist
[
T
]])
->
l
ist
[
T
]:
"""Flatten a list of lists to a single list."""
"""Flatten a list of lists to a single list."""
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
...
@@ -1226,7 +1226,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
...
@@ -1226,7 +1226,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return
value
return
value
def
_pull_args_from_config
(
self
,
args
:
L
ist
[
str
])
->
L
ist
[
str
]:
def
_pull_args_from_config
(
self
,
args
:
l
ist
[
str
])
->
l
ist
[
str
]:
"""Method to pull arguments specified in the config file
"""Method to pull arguments specified in the config file
into the command-line args variable.
into the command-line args variable.
...
@@ -1291,7 +1291,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
...
@@ -1291,7 +1291,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return
args
return
args
def
_load_config_file
(
self
,
file_path
:
str
)
->
L
ist
[
str
]:
def
_load_config_file
(
self
,
file_path
:
str
)
->
l
ist
[
str
]:
"""Loads a yaml file and returns the key value pairs as a
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
flattened list with argparse like pattern
```yaml
```yaml
...
@@ -1313,9 +1313,9 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
...
@@ -1313,9 +1313,9 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
%s supplied"
,
extension
)
%s supplied"
,
extension
)
# only expecting a flat dictionary of atomic types
# only expecting a flat dictionary of atomic types
processed_args
:
L
ist
[
str
]
=
[]
processed_args
:
l
ist
[
str
]
=
[]
config
:
D
ict
[
str
,
Union
[
int
,
str
]]
=
{}
config
:
d
ict
[
str
,
Union
[
int
,
str
]]
=
{}
try
:
try
:
with
open
(
file_path
)
as
config_file
:
with
open
(
file_path
)
as
config_file
:
config
=
yaml
.
safe_load
(
config_file
)
config
=
yaml
.
safe_load
(
config_file
)
...
@@ -1399,7 +1399,7 @@ def resolve_mm_processor_kwargs(
...
@@ -1399,7 +1399,7 @@ def resolve_mm_processor_kwargs(
*
,
*
,
requires_kw_only
:
bool
=
True
,
requires_kw_only
:
bool
=
True
,
allow_var_kwargs
:
bool
=
False
,
allow_var_kwargs
:
bool
=
False
,
)
->
D
ict
[
str
,
Any
]:
)
->
d
ict
[
str
,
Any
]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those who are not explicit keywords to the given callable (of one is
those who are not explicit keywords to the given callable (of one is
given; otherwise no filtering is done), then merges the kwarg dicts,
given; otherwise no filtering is done), then merges the kwarg dicts,
...
@@ -1440,7 +1440,7 @@ def get_allowed_kwarg_only_overrides(
...
@@ -1440,7 +1440,7 @@ def get_allowed_kwarg_only_overrides(
*
,
*
,
requires_kw_only
:
bool
=
True
,
requires_kw_only
:
bool
=
True
,
allow_var_kwargs
:
bool
=
False
,
allow_var_kwargs
:
bool
=
False
,
)
->
D
ict
[
str
,
Any
]:
)
->
d
ict
[
str
,
Any
]:
"""
"""
Given a callable which has one or more keyword only params and a dict
Given a callable which has one or more keyword only params and a dict
mapping param names to values, drop values that can be not be kwarg
mapping param names to values, drop values that can be not be kwarg
...
@@ -1531,9 +1531,9 @@ class AtomicCounter:
...
@@ -1531,9 +1531,9 @@ class AtomicCounter:
# Adapted from: https://stackoverflow.com/a/47212782/5082708
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class
LazyDict
(
Mapping
[
str
,
T
],
Generic
[
T
]):
class
LazyDict
(
Mapping
[
str
,
T
],
Generic
[
T
]):
def
__init__
(
self
,
factory
:
D
ict
[
str
,
Callable
[[],
T
]]):
def
__init__
(
self
,
factory
:
d
ict
[
str
,
Callable
[[],
T
]]):
self
.
_factory
=
factory
self
.
_factory
=
factory
self
.
_dict
:
D
ict
[
str
,
T
]
=
{}
self
.
_dict
:
d
ict
[
str
,
T
]
=
{}
def
__getitem__
(
self
,
key
:
str
)
->
T
:
def
__getitem__
(
self
,
key
:
str
)
->
T
:
if
key
not
in
self
.
_dict
:
if
key
not
in
self
.
_dict
:
...
@@ -1552,9 +1552,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
...
@@ -1552,9 +1552,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
return
len
(
self
.
_factory
)
return
len
(
self
.
_factory
)
class
ClassRegistry
(
UserDict
[
T
ype
[
T
],
_V
]):
class
ClassRegistry
(
UserDict
[
t
ype
[
T
],
_V
]):
def
__getitem__
(
self
,
key
:
T
ype
[
T
])
->
_V
:
def
__getitem__
(
self
,
key
:
t
ype
[
T
])
->
_V
:
for
cls
in
key
.
mro
():
for
cls
in
key
.
mro
():
if
cls
in
self
.
data
:
if
cls
in
self
.
data
:
return
self
.
data
[
cls
]
return
self
.
data
[
cls
]
...
@@ -1584,8 +1584,8 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
...
@@ -1584,8 +1584,8 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
def
weak_ref_tensors
(
def
weak_ref_tensors
(
tensors
:
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
],
T
uple
[
torch
.
Tensor
]]
tensors
:
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
],
t
uple
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
],
T
uple
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
],
t
uple
[
torch
.
Tensor
]]:
"""
"""
Convenience function to create weak references to tensors,
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
for single tensor, list of tensors or tuple of tensors.
...
@@ -1857,7 +1857,7 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
...
@@ -1857,7 +1857,7 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def
direct_register_custom_op
(
def
direct_register_custom_op
(
op_name
:
str
,
op_name
:
str
,
op_func
:
Callable
,
op_func
:
Callable
,
mutates_args
:
L
ist
[
str
],
mutates_args
:
l
ist
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
dispatch_key
:
str
=
"CUDA"
,
dispatch_key
:
str
=
"CUDA"
,
...
@@ -2177,8 +2177,8 @@ def get_mp_context():
...
@@ -2177,8 +2177,8 @@ def get_mp_context():
def
bind_kv_cache
(
def
bind_kv_cache
(
ctx
:
D
ict
[
str
,
Any
],
ctx
:
d
ict
[
str
,
Any
],
kv_cache
:
L
ist
[
L
ist
[
torch
.
Tensor
]],
# [virtual_engine][layer_index]
kv_cache
:
l
ist
[
l
ist
[
torch
.
Tensor
]],
# [virtual_engine][layer_index]
)
->
None
:
)
->
None
:
# Bind the kv_cache tensor to Attention modules, similar to
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
...
@@ -2210,8 +2210,8 @@ def bind_kv_cache(
...
@@ -2210,8 +2210,8 @@ def bind_kv_cache(
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
def
run_method
(
obj
:
Any
,
method
:
Union
[
str
,
bytes
,
Callable
],
args
:
T
uple
[
Any
],
def
run_method
(
obj
:
Any
,
method
:
Union
[
str
,
bytes
,
Callable
],
args
:
t
uple
[
Any
],
kwargs
:
D
ict
[
str
,
Any
])
->
Any
:
kwargs
:
d
ict
[
str
,
Any
])
->
Any
:
"""
"""
Run a method of an object with the given arguments and keyword arguments.
Run a method of an object with the given arguments and keyword arguments.
If the method is string, it will be converted to a method using getattr.
If the method is string, it will be converted to a method using getattr.
...
@@ -2263,7 +2263,7 @@ def import_pynvml():
...
@@ -2263,7 +2263,7 @@ def import_pynvml():
return
pynvml
return
pynvml
def
warn_for_unimplemented_methods
(
cls
:
T
ype
[
T
])
->
T
ype
[
T
]:
def
warn_for_unimplemented_methods
(
cls
:
t
ype
[
T
])
->
t
ype
[
T
]:
"""
"""
A replacement for `abc.ABC`.
A replacement for `abc.ABC`.
When we use `abc.ABC`, subclasses will fail to instantiate
When we use `abc.ABC`, subclasses will fail to instantiate
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -30,7 +30,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -30,7 +30,7 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
@
staticmethod
@
staticmethod
def
get_supported_head_sizes
()
->
L
ist
[
int
]:
def
get_supported_head_sizes
()
->
l
ist
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
@
staticmethod
@
staticmethod
...
@@ -38,15 +38,15 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -38,15 +38,15 @@ class FlashAttentionBackend(AttentionBackend):
return
"FLASH_ATTN_VLLM_V1"
return
"FLASH_ATTN_VLLM_V1"
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
T
ype
[
"FlashAttentionImpl"
]:
def
get_impl_cls
()
->
t
ype
[
"FlashAttentionImpl"
]:
return
FlashAttentionImpl
return
FlashAttentionImpl
@
staticmethod
@
staticmethod
def
get_metadata_cls
()
->
T
ype
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
t
ype
[
"AttentionMetadata"
]:
return
FlashAttentionMetadata
return
FlashAttentionMetadata
@
staticmethod
@
staticmethod
def
get_builder_cls
()
->
T
ype
[
"FlashAttentionMetadataBuilder"
]:
def
get_builder_cls
()
->
t
ype
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
return
FlashAttentionMetadataBuilder
@
staticmethod
@
staticmethod
...
@@ -55,7 +55,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -55,7 +55,7 @@ class FlashAttentionBackend(AttentionBackend):
block_size
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
)
->
T
uple
[
int
,
...]:
)
->
t
uple
[
int
,
...]:
if
block_size
%
16
!=
0
:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
...
@@ -158,10 +158,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -158,10 +158,10 @@ class FlashAttentionImpl(AttentionImpl):
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
L
ist
[
float
]],
alibi_slopes
:
Optional
[
l
ist
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
...
@@ -381,7 +381,7 @@ def cascade_attention(
...
@@ -381,7 +381,7 @@ def cascade_attention(
max_kv_len
:
int
,
max_kv_len
:
int
,
softmax_scale
:
float
,
softmax_scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
sliding_window
:
T
uple
[
int
,
int
],
sliding_window
:
t
uple
[
int
,
int
],
logits_soft_cap
:
float
,
logits_soft_cap
:
float
,
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
cf069aa8
...
@@ -195,8 +195,7 @@ return curr_o @ W_O
...
@@ -195,8 +195,7 @@ return curr_o @ W_O
import
functools
import
functools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
Type
,
TypeVar
)
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
compressed_tensors.quantization
import
QuantizationStrategy
...
@@ -250,11 +249,11 @@ class MLACommonBackend(AttentionBackend):
...
@@ -250,11 +249,11 @@ class MLACommonBackend(AttentionBackend):
return
"TRITON_MLA_VLLM_V1"
return
"TRITON_MLA_VLLM_V1"
@
staticmethod
@
staticmethod
def
get_metadata_cls
()
->
T
ype
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
t
ype
[
"AttentionMetadata"
]:
return
MLACommonMetadata
return
MLACommonMetadata
@
staticmethod
@
staticmethod
def
get_builder_cls
()
->
T
ype
[
"MLACommonMetadataBuilder"
]:
def
get_builder_cls
()
->
t
ype
[
"MLACommonMetadataBuilder"
]:
return
MLACommonMetadataBuilder
return
MLACommonMetadataBuilder
@
staticmethod
@
staticmethod
...
@@ -263,11 +262,11 @@ class MLACommonBackend(AttentionBackend):
...
@@ -263,11 +262,11 @@ class MLACommonBackend(AttentionBackend):
block_size
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
# assumed to be 1 for MLA
num_kv_heads
:
int
,
# assumed to be 1 for MLA
head_size
:
int
,
head_size
:
int
,
)
->
T
uple
[
int
,
...]:
)
->
t
uple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
return
(
num_blocks
,
block_size
,
head_size
)
@
staticmethod
@
staticmethod
def
get_supported_head_sizes
()
->
L
ist
[
int
]:
def
get_supported_head_sizes
()
->
l
ist
[
int
]:
return
[
576
]
return
[
576
]
@
staticmethod
@
staticmethod
...
@@ -317,8 +316,8 @@ class MLACommonMetadata:
...
@@ -317,8 +316,8 @@ class MLACommonMetadata:
has_context
:
bool
=
False
has_context
:
bool
=
False
context_chunk_cu_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_cu_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_starts
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_starts
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_seq_tot
:
Optional
[
L
ist
[
int
]]
=
None
context_chunk_seq_tot
:
Optional
[
l
ist
[
int
]]
=
None
context_chunk_max_seq_lens
:
Optional
[
L
ist
[
int
]]
=
None
context_chunk_max_seq_lens
:
Optional
[
l
ist
[
int
]]
=
None
chunked_prefill_workspace
:
Optional
[
torch
.
Tensor
]
=
None
chunked_prefill_workspace
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
...
@@ -538,10 +537,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -538,10 +537,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
L
ist
[
float
]],
alibi_slopes
:
Optional
[
l
ist
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
D
ict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
d
ict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
# MLA Specific Arguments
# MLA Specific Arguments
...
@@ -634,7 +633,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -634,7 +633,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
#
#
# returns input_group_shape, weight_group_shape
# returns input_group_shape, weight_group_shape
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
T
uple
[
T
uple
[
int
,
int
],
T
uple
[
int
,
int
]]:
t
uple
[
t
uple
[
int
,
int
],
t
uple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
:
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
weight_block_size
=
\
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Optional
import
torch
import
torch
...
@@ -25,21 +25,21 @@ class FlashMLABackend(MLACommonBackend):
...
@@ -25,21 +25,21 @@ class FlashMLABackend(MLACommonBackend):
return
"FLASHMLA_VLLM_V1"
return
"FLASHMLA_VLLM_V1"
@
staticmethod
@
staticmethod
def
get_metadata_cls
()
->
T
ype
[
"FlashMLAMetadata"
]:
def
get_metadata_cls
()
->
t
ype
[
"FlashMLAMetadata"
]:
return
FlashMLAMetadata
return
FlashMLAMetadata
@
staticmethod
@
staticmethod
def
get_builder_cls
()
->
T
ype
[
"FlashMLAMetadataBuilder"
]:
def
get_builder_cls
()
->
t
ype
[
"FlashMLAMetadataBuilder"
]:
return
FlashMLAMetadataBuilder
return
FlashMLAMetadataBuilder
@
staticmethod
@
staticmethod
def
get_impl_cls
()
->
T
ype
[
"FlashMLAImpl"
]:
def
get_impl_cls
()
->
t
ype
[
"FlashMLAImpl"
]:
return
FlashMLAImpl
return
FlashMLAImpl
@
dataclass
@
dataclass
class
FlashMLAMetadata
(
MLACommonMetadata
):
class
FlashMLAMetadata
(
MLACommonMetadata
):
decode_tile_scheduler_metadata
:
Optional
[
T
uple
[
torch
.
Tensor
,
decode_tile_scheduler_metadata
:
Optional
[
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
torch
.
Tensor
]]
=
None
decode_num_splits
:
Optional
[
torch
.
Tensor
]
=
None
decode_num_splits
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -76,10 +76,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -76,10 +76,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
head_size
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
L
ist
[
float
]],
alibi_slopes
:
Optional
[
l
ist
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
D
ict
[
str
,
Any
]],
blocksparse_params
:
Optional
[
d
ict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
# MLA Specific Arguments
# MLA Specific Arguments
...
...
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment