Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
290 additions
and
104 deletions
+290
-104
vllm/pooling_params.py
vllm/pooling_params.py
+21
-2
vllm/reasoning/granite_reasoning_parser.py
vllm/reasoning/granite_reasoning_parser.py
+1
-1
vllm/sampling_params.py
vllm/sampling_params.py
+5
-4
vllm/third_party/pynvml.py
vllm/third_party/pynvml.py
+1
-1
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+5
-1
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+4
-1
vllm/transformers_utils/tokenizers/__init__.py
vllm/transformers_utils/tokenizers/__init__.py
+3
-2
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+7
-0
vllm/transformers_utils/utils.py
vllm/transformers_utils/utils.py
+26
-11
vllm/utils.py
vllm/utils.py
+48
-15
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-4
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+3
-3
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+8
-0
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+5
-5
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+8
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+41
-32
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+64
-8
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+26
-6
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+2
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+8
-7
No files found.
vllm/pooling_params.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
class
PoolingParams
(
msgspec
.
Struct
,
...
...
@@ -12,14 +15,30 @@ class PoolingParams(
"""API parameters for pooling models. This is currently a placeholder.
Attributes:
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
dimensions
:
Optional
[
int
]
=
None
additional_data
:
Optional
[
Any
]
=
None
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
def
verify
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
if
self
.
dimensions
is
not
None
:
if
not
model_config
.
is_matryoshka
:
raise
ValueError
(
f
'Model "
{
model_config
.
served_model_name
}
" does not '
f
'support matryoshka representation, '
f
'changing output dimensions will lead to poor results.'
)
if
self
.
dimensions
<
1
:
raise
ValueError
(
"Dimensions must be greater than 0"
)
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
f
"dimensions=
{
self
.
dimensions
}
, "
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
vllm/reasoning/granite_reasoning_parser.py
View file @
9c4ecf15
...
...
@@ -60,7 +60,7 @@ class GraniteReasoningParser(ReasoningParser):
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionReqest): Request being processed.
request (ChatCompletionReq
u
est): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
...
...
vllm/sampling_params.py
View file @
9c4ecf15
...
...
@@ -101,7 +101,7 @@ class RequestOutputKind(Enum):
CUMULATIVE
=
0
# Return only deltas in each RequestOutput
DELTA
=
1
# Do not return intermediate RequestOuput
s
# Do not return intermediate RequestOu
t
put
FINAL_ONLY
=
2
...
...
@@ -385,9 +385,10 @@ class SamplingParams(
if
not
-
2.0
<=
self
.
frequency_penalty
<=
2.0
:
raise
ValueError
(
"frequency_penalty must be in [-2, 2], got "
f
"
{
self
.
frequency_penalty
}
."
)
if
not
0.0
<
self
.
repetition_penalty
<=
2.0
:
raise
ValueError
(
"repetition_penalty must be in (0, 2], got "
f
"
{
self
.
repetition_penalty
}
."
)
if
self
.
repetition_penalty
<=
0.0
:
raise
ValueError
(
"repetition_penalty must be greater than zero, got "
f
"
{
self
.
repetition_penalty
}
."
)
if
self
.
temperature
<
0.0
:
raise
ValueError
(
f
"temperature must be non-negative, got
{
self
.
temperature
}
."
)
...
...
vllm/third_party/pynvml.py
View file @
9c4ecf15
...
...
@@ -1119,7 +1119,7 @@ class _PrintableStructure(Structure):
e.g. class that has _field_ 'hex_value', c_uint could be formatted with
_fmt_ = {"hex_value" : "%08X"}
to produce nicer output.
Default fo
m
ratting string for all fields can be set with key "<default>" like:
Default for
m
atting string for all fields can be set with key "<default>" like:
_fmt_ = {"<default>" : "%d MHz"} # e.g all values are numbers in MHz.
If not set it's assumed to be just "%s"
...
...
vllm/transformers_utils/config.py
View file @
9c4ecf15
...
...
@@ -712,6 +712,7 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
def
get_hf_image_processor_config
(
model
:
Union
[
str
,
Path
],
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
revision
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
Dict
[
str
,
Any
]:
...
...
@@ -721,7 +722,10 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models
if
check_gguf_file
(
model
):
model
=
Path
(
model
).
parent
return
get_image_processor_config
(
model
,
revision
=
revision
,
**
kwargs
)
return
get_image_processor_config
(
model
,
token
=
hf_token
,
revision
=
revision
,
**
kwargs
)
def
get_hf_text_config
(
config
:
PretrainedConfig
):
...
...
vllm/transformers_utils/configs/eagle.py
View file @
9c4ecf15
...
...
@@ -5,6 +5,7 @@ from typing import Optional, Union
from
transformers
import
AutoConfig
,
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.transformers_utils.configs.deepseek_vl2
import
DeepseekV2Config
...
...
@@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig):
self
.
truncated_vocab_size
=
self
.
model
.
vocab_size
if
\
truncated_vocab_size
is
None
else
truncated_vocab_size
if
"architectures"
not
in
kwargs
:
if
not
envs
.
VLLM_USE_V1
:
kwargs
[
"architectures"
]
=
[
"EAGLEModel"
]
else
:
kwargs
[
"architectures"
]
=
[
"EagleLlamaForCausalLM"
]
super
().
__init__
(
**
kwargs
)
...
...
vllm/transformers_utils/tokenizers/__init__.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
from
.mistral
import
(
MistralTokenizer
,
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
truncate_tool_call_ids
,
validate_request_params
)
__all__
=
[
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
,
"validate_request_params"
]
vllm/transformers_utils/tokenizers/mistral.py
View file @
9c4ecf15
...
...
@@ -98,6 +98,13 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request
.
messages
[
i
][
"tool_call_id"
]
=
tool_call_id
def
validate_request_params
(
request
:
"ChatCompletionRequest"
):
if
(
request
.
skip_special_tokens
is
not
None
and
not
request
.
skip_special_tokens
):
raise
ValueError
(
"skip_special_tokens=False is not supported "
"for Mistral tokenizers."
)
def
list_local_repo_files
(
repo_id
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
repo_cache
=
os
.
path
.
join
(
huggingface_hub
.
constants
.
HF_HUB_CACHE
,
...
...
vllm/transformers_utils/utils.py
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
import
json
from
functools
import
cache
from
os
import
PathLike
from
pathlib
import
Path
...
...
@@ -51,6 +52,26 @@ def modelscope_list_repo_files(
return
files
def
_maybe_json_dict
(
path
:
Union
[
str
,
PathLike
])
->
dict
[
str
,
str
]:
with
open
(
path
)
as
f
:
try
:
return
json
.
loads
(
f
.
read
())
except
Exception
:
return
dict
[
str
,
str
]()
def
_maybe_space_split_dict
(
path
:
Union
[
str
,
PathLike
])
->
dict
[
str
,
str
]:
parsed_dict
=
dict
[
str
,
str
]()
with
open
(
path
)
as
f
:
for
line
in
f
.
readlines
():
try
:
model_name
,
redirect_name
=
line
.
strip
().
split
()
parsed_dict
[
model_name
]
=
redirect_name
except
Exception
:
pass
return
parsed_dict
@
cache
def
maybe_model_redirect
(
model
:
str
)
->
str
:
"""
...
...
@@ -68,16 +89,10 @@ def maybe_model_redirect(model: str) -> str:
if
not
Path
(
model_redirect_path
).
exists
():
return
model
with
open
(
model_redirect_path
)
as
f
:
for
line
in
f
.
readlines
():
try
:
model_name
,
redirect_name
=
line
.
split
(
"
\t
"
)
if
model
==
model_name
:
redirect_name
=
redirect_name
.
strip
()
logger
.
info
(
"model redirect: [ %s ] -> [ %s ]"
,
model
,
redirect_name
)
return
redirect_name
except
Exception
:
pass
redirect_dict
=
(
_maybe_json_dict
(
model_redirect_path
)
or
_maybe_space_split_dict
(
model_redirect_path
))
if
(
redirect_model
:
=
redirect_dict
.
get
(
model
)):
logger
.
info
(
"model redirect: [ %s ] -> [ %s ]"
,
model
,
redirect_model
)
return
redirect_model
return
model
vllm/utils.py
View file @
9c4ecf15
...
...
@@ -2,7 +2,6 @@
from
__future__
import
annotations
import
argparse
import
asyncio
import
concurrent
import
contextlib
...
...
@@ -25,6 +24,7 @@ import socket
import
subprocess
import
sys
import
tempfile
import
textwrap
import
threading
import
time
import
traceback
...
...
@@ -32,6 +32,8 @@ import types
import
uuid
import
warnings
import
weakref
from
argparse
import
(
Action
,
ArgumentDefaultsHelpFormatter
,
ArgumentParser
,
ArgumentTypeError
)
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
...
...
@@ -40,7 +42,7 @@ from dataclasses import dataclass, field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
types
import
MappingProxyType
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
overload
)
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
cast
,
overload
)
from
uuid
import
uuid4
import
cachetools
...
...
@@ -53,6 +55,7 @@ import torch.types
import
yaml
import
zmq
import
zmq.asyncio
from
packaging
import
version
from
packaging.version
import
Version
from
torch.library
import
Library
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
...
...
@@ -1208,7 +1211,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
return
wrapper
class
StoreBoolean
(
argparse
.
Action
):
class
StoreBoolean
(
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
if
values
.
lower
()
==
"true"
:
...
...
@@ -1220,15 +1223,28 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'."
)
class
SortedHelpFormatter
(
argparse
.
ArgumentDefaultsHelpFormatter
):
class
SortedHelpFormatter
(
ArgumentDefaultsHelpFormatter
):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def
_split_lines
(
self
,
text
,
width
):
"""
1. Sentences split across lines have their single newlines removed.
2. Paragraphs and explicit newlines are split into separate lines.
3. Each line is wrapped to the specified width (width of terminal).
"""
# The patterns also include whitespace after the newline
single_newline
=
re
.
compile
(
r
"(?<!\n)\n(?!\n)\s*"
)
multiple_newlines
=
re
.
compile
(
r
"\n{2,}\s*"
)
text
=
single_newline
.
sub
(
' '
,
text
)
lines
=
re
.
split
(
multiple_newlines
,
text
)
return
sum
([
textwrap
.
wrap
(
line
,
width
)
for
line
in
lines
],
[])
def
add_arguments
(
self
,
actions
):
actions
=
sorted
(
actions
,
key
=
lambda
x
:
x
.
option_strings
)
super
().
add_arguments
(
actions
)
class
FlexibleArgumentParser
(
argparse
.
ArgumentParser
):
class
FlexibleArgumentParser
(
ArgumentParser
):
"""ArgumentParser that allows both underscore and dash in names."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -1279,11 +1295,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
value
=
int
(
value
)
except
ValueError
:
msg
=
"Port must be an integer"
raise
argparse
.
ArgumentTypeError
(
msg
)
from
None
raise
ArgumentTypeError
(
msg
)
from
None
if
not
(
1024
<=
value
<=
65535
):
raise
argparse
.
ArgumentTypeError
(
"Port must be between 1024 and 65535"
)
raise
ArgumentTypeError
(
"Port must be between 1024 and 65535"
)
return
value
...
...
@@ -1935,12 +1950,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def
direct_register_custom_op
(
op_name
:
str
,
op_func
:
Callable
,
mutates_args
:
list
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
dispatch_key
:
str
=
"CUDA"
,
op_name
:
str
,
op_func
:
Callable
,
mutates_args
:
list
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
dispatch_key
:
str
=
"CUDA"
,
tags
:
Tuple
[
torch
.
Tag
,
...]
=
(),
):
"""
`torch.library.custom_op` can have significant overhead because it
...
...
@@ -1979,7 +1995,7 @@ def direct_register_custom_op(
import
torch._custom_op.impl
schema_str
=
torch
.
_custom_op
.
impl
.
infer_schema
(
op_func
,
mutates_args
)
my_lib
=
target_lib
or
vllm_lib
my_lib
.
define
(
op_name
+
schema_str
)
my_lib
.
define
(
op_name
+
schema_str
,
tags
=
tags
)
my_lib
.
impl
(
op_name
,
op_func
,
dispatch_key
=
dispatch_key
)
if
fake_impl
is
not
None
:
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
...
...
@@ -2564,3 +2580,20 @@ def sha256(input) -> int:
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
return
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
byteorder
=
"big"
)
def
is_torch_equal_or_newer
(
target
:
str
)
->
bool
:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try
:
torch_version
=
version
.
parse
(
str
(
torch
.
__version__
))
return
torch_version
>=
version
.
parse
(
target
)
except
Exception
:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
target
)
vllm/v1/attention/backends/flash_attn.py
100755 → 100644
View file @
9c4ecf15
...
...
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.ops.
triton_
merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
...
...
@@ -164,9 +164,9 @@ def make_local_attention_virtual_batches(
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
t
ensor
,
block_table
:
torch
.
T
ensor
,
page_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
t
ensor
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
T
ensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
...
...
@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches(
np
.
arange
(
pages_per_local_batch
,
dtype
=
np
.
int32
),
(
virtual_batches
,
pages_per_local_batch
))
\
+
np
.
expand_dims
(
block_starts
,
axis
=
1
)
block_indices
=
block_indices
.
flatten
()
block_indices
=
block_indices
.
flatten
()
.
clip
(
max
=
block_table
.
shape
[
1
]
-
1
)
batch_indices
=
np
.
repeat
(
np
.
arange
(
actual_batch_size
,
dtype
=
np
.
int32
),
local_blocks
*
pages_per_local_batch
)
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
...
...
vllm/v1/attention/backends/mla/common.py
View file @
9c4ecf15
...
...
@@ -83,8 +83,8 @@ spda_o = scaled_dot_product_attention(
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatnated per head
`q_b_proj` is [W_UQ; W_QR] concatnated per head
`kv_b_proj` is [W_UK; W_UV] concat
e
nated per head
`q_b_proj` is [W_UQ; W_QR] concat
e
nated per head
`out_proj` is W_O
...
...
@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
)
from
vllm.attention.ops.
triton_
merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
...
...
vllm/v1/attention/backends/pallas.py
View file @
9c4ecf15
...
...
@@ -10,6 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
PallasAttentionBackend
(
AttentionBackend
):
...
...
@@ -80,7 +83,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
use_irope
:
logger
.
warning_once
(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"Paged attention Pallas kernel does "
"not support block-sparse attention."
)
...
...
vllm/v1/core/block_pool.py
View file @
9c4ecf15
...
...
@@ -67,11 +67,11 @@ class BlockPool:
Returns:
The cached block if it exists, or None.
"""
if
block_hash
in
self
.
cached_block_hash_to_block
:
first_block_id
=
list
(
self
.
cached_block_hash_to_block
[
block_hash
].
keys
())[
0
]
return
self
.
cached_block_hash_to_block
[
block_hash
][
first
_block
_id
]
return
None
cached_blocks
=
self
.
cached_block_hash_to_block
.
get
(
block_hash
)
if
not
cached_blocks
:
return
None
first_block_id
=
next
(
iter
(
cached
_block
s
))
return
cached_blocks
[
first_block_id
]
def
cache_full_blocks
(
self
,
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
9c4ecf15
...
...
@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal(
_
,
max_tokens_per_mm_item
=
max
(
max_tokens_by_modality_dict
.
items
(),
key
=
lambda
item
:
item
[
1
])
if
(
scheduler_config
.
disable_chunked_mm_input
and
max_tokens_per_mm_item
>
scheduler_config
.
max_num_batched_tokens
):
raise
ValueError
(
"Chunked MM input disabled but max_tokens_per_mm_item "
f
"(
{
max_tokens_per_mm_item
}
) is larger than max_num_batched_tokens"
f
" (
{
scheduler_config
.
max_num_batched_tokens
}
). Please increase "
"max_num_batched_tokens."
)
encoder_compute_budget
=
max
(
scheduler_config
.
max_num_encoder_input_tokens
,
max_tokens_per_mm_item
)
encoder_cache_size
=
max
(
scheduler_config
.
encoder_cache_size
,
...
...
vllm/v1/core/kv_cache_manager.py
View file @
9c4ecf15
...
...
@@ -126,44 +126,46 @@ class KVCacheManager:
self
.
req_to_block_hashes
[
request
.
request_id
]
=
block_hashes
self
.
prefix_cache_stats
.
requests
+=
1
if
request
.
sampling_params
.
prompt_logprobs
is
None
:
if
len
(
block_hashes
)
*
self
.
block_size
==
request
.
num_tokens
:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash
=
block_hashes
.
pop
()
else
:
last_block_hash
=
None
# When the request requires prompt logprobs, we skip prefix caching.
if
request
.
sampling_params
.
prompt_logprobs
is
not
None
:
return
[],
0
computed_blocks
=
(
self
.
specialized_manager
.
find_longest_cache_hit
(
block_hashes
))
if
len
(
block_hashes
)
*
self
.
block_size
==
request
.
num_tokens
:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash
=
block_hashes
.
pop
()
else
:
last_block_hash
=
None
if
last_block_hash
is
not
None
:
# Add back the last block hash if it was removed.
block_hashes
.
append
(
last_block_hash
)
computed_blocks
=
(
self
.
specialized_manager
.
find_longest_cache_hit
(
block_hashes
))
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
if
last_block_hash
is
not
None
:
# Add back the last block hash if it was removed.
# NOTE: Because block_hashes is cached in req_to_block_hashes,
# we shouldn't modify it directly.
block_hashes
.
append
(
last_block_hash
)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
else
:
# Skip cache hits for prompt logprobs
return
[],
0
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
def
allocate_slots
(
self
,
request
:
Request
,
num_tokens
:
int
,
new_computed_blocks
:
Optional
[
list
[
KVCacheBlock
]]
=
None
new_computed_blocks
:
Optional
[
list
[
KVCacheBlock
]]
=
None
,
num_lookahead_tokens
:
int
=
0
,
)
->
Optional
[
list
[
KVCacheBlock
]]:
"""Add slots for a request with new tokens to append.
...
...
@@ -173,6 +175,9 @@ class KVCacheManager:
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout:
-----------------------------------------------------------------------
...
...
@@ -210,8 +215,9 @@ class KVCacheManager:
# the new prefix caching hits
num_computed_tokens
=
(
request
.
num_computed_tokens
+
len
(
new_computed_blocks
)
*
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_computed_tokens
+
num_tokens
,
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_computed_tokens
+
num_tokens
+
num_lookahead_tokens
,
self
.
block_size
)
num_new_blocks
=
(
num_required_blocks
-
len
(
req_blocks
)
-
len
(
new_computed_blocks
))
...
...
@@ -245,8 +251,11 @@ class KVCacheManager:
else
:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_preallocate_blocks
=
max
(
0
,
self
.
num_preallocate_blocks
-
num_lookahead_tokens
//
self
.
block_size
)
num_new_blocks
=
min
(
num_new_blocks
+
self
.
num_preallocate_blocks
,
num_new_blocks
+
num_preallocate_blocks
,
self
.
block_pool
.
get_num_free_blocks
(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
...
...
vllm/v1/core/kv_cache_utils.py
View file @
9c4ecf15
...
...
@@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
sha256
from
vllm.utils
import
GiB_bytes
,
sha256
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
...
...
@@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if
mm_positions
[
-
1
][
"offset"
]
+
mm_positions
[
-
1
][
"length"
]
<
start_token_idx
:
if
mm_positions
[
-
1
].
offset
+
mm_positions
[
-
1
].
length
<
start_token_idx
:
return
extra_keys
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
...
...
@@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
curr_mm_idx
=
start_mm_idx
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
offset
=
mm_positions
[
curr_mm_idx
]
[
"
offset
"
]
length
=
mm_positions
[
curr_mm_idx
]
[
"
length
"
]
offset
=
mm_positions
[
curr_mm_idx
]
.
offset
length
=
mm_positions
[
curr_mm_idx
]
.
length
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
# This block has passed the current mm input.
...
...
@@ -460,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
return
ret
def
estimate_max_model_len
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
int
:
"""
Estimates the maximum model length that can fit in the available memory
using binary search.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The estimated maximum model length that can fit in the available memory.
"""
# Define a function to check if a given model length fits in memory
def
fits_in_memory
(
model_len
:
int
)
->
bool
:
# Modify the max_model_len for this calculation
vllm_config
.
model_config
.
max_model_len
=
model_len
# Calculate memory needed for the given model length
memory_needed
=
sum
(
(
layer_spec
.
max_memory_usage_bytes
(
vllm_config
)
for
layer_spec
in
kv_cache_spec
.
values
()),
start
=
0
,
)
return
memory_needed
<=
available_memory
# Binary search for the maximum model length
current_max
=
vllm_config
.
model_config
.
max_model_len
left
,
right
=
1
,
current_max
# If even the smallest model length doesn't fit, return 0
if
not
fits_in_memory
(
left
):
return
0
# Binary search for the maximum model length that fits
result
=
1
while
left
<=
right
:
mid
=
(
left
+
right
)
//
2
if
fits_in_memory
(
mid
):
result
=
mid
left
=
mid
+
1
else
:
right
=
mid
-
1
return
result
def
check_enough_kv_cache_memory
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
):
...
...
@@ -487,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
needed_memory
+=
layer_spec
.
max_memory_usage_bytes
(
vllm_config
)
if
needed_memory
>
available_memory
:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len
=
estimate_max_model_len
(
vllm_config
,
kv_cache_spec
,
available_memory
)
estimated_msg
=
""
if
estimated_max_len
>
0
:
estimated_msg
=
" Based on the available memory,"
f
" the estimated maximum model length is
{
estimated_max_len
}
."
raise
ValueError
(
f
"To serve at least one request with the models's max seq len "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
1024
/
1024
/
1024
:.
2
f
}
GiB KV "
f
"(
{
max_model_len
}
), (
{
needed_memory
/
GiB_bytes
:.
2
f
}
GiB KV "
f
"cache is needed, which is larger than the available KV cache "
f
"memory (
{
available_memory
/
1024
/
1024
/
1024
:.
2
f
}
GiB). Try "
f
"increasing `gpu_memory_utilization` or decreasing "
f
"memory (
{
available_memory
/
GiB_bytes
:.
2
f
}
GiB)."
f
"
{
estimated_msg
}
"
f
" Try increasing `gpu_memory_utilization` or decreasing "
f
"`max_model_len` when initializing the engine."
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
9c4ecf15
...
...
@@ -7,7 +7,8 @@ from collections import deque
from
collections.abc
import
Iterable
from
typing
import
Optional
,
Union
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
(
CacheConfig
,
LoRAConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
...
...
@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_config
:
KVCacheConfig
,
structured_output_manager
:
StructuredOutputManager
,
speculative_config
:
SpeculativeConfig
=
None
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
include_finished_set
:
bool
=
False
,
log_stats
:
bool
=
False
,
...
...
@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
self
.
num_lookahead_tokens
=
0
if
speculative_config
and
speculative_config
.
method
==
"eagle"
:
self
.
num_lookahead_tokens
=
\
speculative_config
.
num_speculative_tokens
def
schedule
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
...
...
@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
while
True
:
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
num_new_tokens
)
request
,
num_new_tokens
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
)
if
new_blocks
is
None
:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
...
...
@@ -505,8 +514,8 @@ class Scheduler(SchedulerInterface):
assert
mm_positions
is
not
None
assert
len
(
mm_positions
)
>
0
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
[
"
length
"
]
start_pos
=
pos_info
.
offset
num_encoder_tokens
=
pos_info
.
length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
...
...
@@ -522,6 +531,17 @@ class Scheduler(SchedulerInterface):
if
self
.
encoder_cache_manager
.
has_cache
(
request
,
i
):
# The encoder input is already computed and cached.
continue
# If no encoder input chunking is allowed, we do not want to
# partially schedule a multimodal item. If the scheduled range would
# only cover part of the mm input, roll back to before the mm item.
if
(
self
.
scheduler_config
.
disable_chunked_mm_input
and
num_computed_tokens
<
start_pos
and
(
num_computed_tokens
+
num_new_tokens
)
<
(
start_pos
+
num_encoder_tokens
)):
num_new_tokens
=
start_pos
-
num_computed_tokens
break
if
(
not
self
.
encoder_cache_manager
.
can_allocate
(
request
,
i
)
or
num_encoder_tokens
>
encoder_budget
):
# The encoder cache is full or the encoder budget is exhausted.
...
...
@@ -596,8 +616,8 @@ class Scheduler(SchedulerInterface):
if
cached_encoder_input_ids
:
for
input_id
in
list
(
cached_encoder_input_ids
):
mm_positions
=
request
.
mm_positions
[
input_id
]
start_pos
=
mm_positions
[
"
offset
"
]
num_tokens
=
mm_positions
[
"
length
"
]
start_pos
=
mm_positions
.
offset
num_tokens
=
mm_positions
.
length
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
# The encoder output is already processed and stored
# in the decoder's KV cache.
...
...
vllm/v1/engine/__init__.py
View file @
9c4ecf15
...
...
@@ -2,6 +2,7 @@
import
enum
import
time
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
,
Union
import
msgspec
...
...
@@ -52,7 +53,7 @@ class EngineCoreRequest(
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt
:
Optional
[
str
]
prompt_token_ids
:
list
[
int
]
mm_inputs
:
Optional
[
list
[
MultiModalKwargs
]]
mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]
]
mm_hashes
:
Optional
[
list
[
str
]]
mm_placeholders
:
Optional
[
list
[
PlaceholderRange
]]
sampling_params
:
SamplingParams
...
...
vllm/v1/engine/core.py
View file @
9c4ecf15
...
...
@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.mm_input_cache
import
M
MInputCacheServer
from
vllm.v1.engine.mm_input_cache
import
M
irroredProcessingCache
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -98,6 +98,7 @@ class EngineCore:
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
kv_cache_config
=
kv_cache_config
,
speculative_config
=
vllm_config
.
speculative_config
,
structured_output_manager
=
self
.
structured_output_manager
,
include_finished_set
=
vllm_config
.
parallel_config
.
data_parallel_size
>
1
,
...
...
@@ -105,7 +106,7 @@ class EngineCore:
)
# Setup MM Input Mapper.
self
.
mm_input_cache_server
=
M
MInputCacheServer
(
self
.
mm_input_cache_server
=
M
irroredProcessingCache
(
vllm_config
.
model_config
)
# Setup batch queue for pipeline parallelism.
...
...
@@ -173,7 +174,7 @@ class EngineCore:
# anything that has a hash must have a HIT cache entry here
# as well.
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
=
self
.
mm_input_cache_server
.
get_and_update
(
request
.
mm_inputs
=
self
.
mm_input_cache_server
.
get_and_update
_p1
(
request
.
mm_inputs
,
request
.
mm_hashes
)
req
=
Request
.
from_engine_core_request
(
request
)
...
...
@@ -486,14 +487,14 @@ class EngineCoreProc(EngineCore):
with
zmq_socket_ctx
(
input_path
,
zmq
.
constants
.
PULL
)
as
socket
:
while
True
:
# (RequestType, RequestData)
type_frame
,
data_frame
=
socket
.
recv_multipart
(
copy
=
False
)
type_frame
,
*
data_frame
s
=
socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frame
.
buffer
)
request
=
decoder
.
decode
(
data_frame
s
)
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
...
...
@@ -510,8 +511,8 @@ class EngineCoreProc(EngineCore):
while
True
:
outputs
=
self
.
output_queue
.
get
()
outputs
.
engine_index
=
engine_index
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send
(
buffer
,
copy
=
False
)
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send
_multipart
(
buffer
s
,
copy
=
False
)
ENGINE_PAUSED_OUTPUTS
=
EngineCoreOutputs
(
engine_paused
=
True
)
...
...
Prev
1
…
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment