Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31330101
Commit
31330101
authored
Apr 16, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-dev
parents
e8933c34
dc1b4a6f
Changes
346
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
239 additions
and
91 deletions
+239
-91
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+3
-3
vllm/platforms/hpu.py
vllm/platforms/hpu.py
+3
-3
vllm/platforms/interface.py
vllm/platforms/interface.py
+16
-2
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+27
-1
vllm/pooling_params.py
vllm/pooling_params.py
+21
-2
vllm/reasoning/granite_reasoning_parser.py
vllm/reasoning/granite_reasoning_parser.py
+1
-1
vllm/sampling_params.py
vllm/sampling_params.py
+5
-4
vllm/third_party/pynvml.py
vllm/third_party/pynvml.py
+1
-1
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+5
-1
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+4
-1
vllm/transformers_utils/tokenizers/__init__.py
vllm/transformers_utils/tokenizers/__init__.py
+3
-2
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+7
-0
vllm/transformers_utils/utils.py
vllm/transformers_utils/utils.py
+26
-11
vllm/utils.py
vllm/utils.py
+48
-15
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-4
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+3
-3
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+8
-0
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+5
-5
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+8
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+41
-32
No files found.
vllm/platforms/cpu.py
View file @
31330101
...
...
@@ -69,12 +69,12 @@ class CpuPlatform(Platform):
cache_config
=
vllm_config
.
cache_config
ipex_ava
l
iable
=
find_spec
(
"intel_extension_for_pytorch"
)
is
not
None
ipex_avai
l
able
=
find_spec
(
"intel_extension_for_pytorch"
)
is
not
None
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
128
if
ipex_ava
l
iable
else
16
cache_config
.
block_size
=
128
if
ipex_avai
l
able
else
16
if
not
ipex_ava
l
iable
and
cache_config
.
block_size
!=
16
:
if
not
ipex_avai
l
able
and
cache_config
.
block_size
!=
16
:
raise
RuntimeError
(
f
"--block-size=
{
cache_config
.
block_size
}
requires"
" intel_extension_for_pytorch"
)
...
...
vllm/platforms/hpu.py
View file @
31330101
...
...
@@ -46,15 +46,15 @@ class HpuPlatform(Platform):
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
scheduler_config
=
vllm_config
.
scheduler_config
parallel_config
=
vllm_config
.
parallel_config
if
scheduler_config
.
is_multi_step
:
raise
NotImplementedError
(
"
M
ulti
-
step
execution is not implemented for HPU"
)
parallel_config
.
worker_cls
=
\
"
vllm.worker.m
ulti
_
step
_hpu_worker.MultiStepHPUWorker"
if
vllm_config
.
speculative_config
is
not
None
:
raise
NotImplementedError
(
"Speculative decoding is not implemented for HPU"
)
parallel_config
=
vllm_config
.
parallel_config
if
parallel_config
.
worker_cls
==
"auto"
:
parallel_config
.
worker_cls
=
"vllm.worker.hpu_worker.HPUWorker"
...
...
vllm/platforms/interface.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
import
enum
import
platform
import
random
...
...
@@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import
numpy
as
np
import
torch
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
FlexibleArgumentParser
else
:
ModelConfig
=
None
VllmConfig
=
None
LoRARequest
=
None
PoolingParams
=
None
SamplingParams
=
None
FlexibleArgumentParser
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -231,7 +237,7 @@ class Platform:
parser
:
Optional
[
FlexibleArgumentParser
]
=
None
)
->
None
:
"""
Do some pre-regist
e
ration or update action for the current platform.
Do some pre-registration or update action for the current platform.
This function is called before global VllmConfig is initialized or cli
arguments are parsed. It's used for out-of-tree platforms to register or
...
...
@@ -386,6 +392,14 @@ class Platform:
"""
return
False
@
classmethod
def
validate_request
(
cls
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
None
:
"""Raises if this request is unsupported on this platform"""
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/tpu.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
import
vllm.envs
as
envs
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.pooling_params
import
PoolingParams
else
:
ModelConfig
=
None
VllmConfig
=
None
PoolingParams
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -116,6 +120,13 @@ class TpuPlatform(Platform):
assert
not
vllm_config
.
speculative_config
,
(
"Speculative decoding is not yet supported for TPU backend"
)
if
scheduler_config
.
is_multimodal_model
and
not
\
scheduler_config
.
disable_chunked_mm_input
:
logger
.
warning
(
"TPU does not support running Multimodal models"
\
" without setting `--disable_chunked_mm_input`. "
\
"Forcing --disable_chunked_mm_input."
)
scheduler_config
.
disable_chunked_mm_input
=
True
@
classmethod
def
is_pin_memory_available
(
cls
):
logger
.
warning
(
"Pin memory is not supported on TPU."
)
...
...
@@ -133,3 +144,18 @@ class TpuPlatform(Platform):
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
# V1 support on TPU is experimental
return
True
@
classmethod
def
validate_request
(
cls
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
None
:
"""Raises if this request is unsupported on this platform"""
if
isinstance
(
params
,
SamplingParams
):
if
params
.
guided_decoding
is
not
None
:
raise
ValueError
(
"Structured output is not supported on "
f
"
{
cls
.
device_name
}
."
)
if
params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
raise
ValueError
(
"Torch XLA does not support per-request seed."
)
vllm/pooling_params.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
class
PoolingParams
(
msgspec
.
Struct
,
...
...
@@ -12,14 +15,30 @@ class PoolingParams(
"""API parameters for pooling models. This is currently a placeholder.
Attributes:
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
dimensions
:
Optional
[
int
]
=
None
additional_data
:
Optional
[
Any
]
=
None
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
def
verify
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
if
self
.
dimensions
is
not
None
:
if
not
model_config
.
is_matryoshka
:
raise
ValueError
(
f
'Model "
{
model_config
.
served_model_name
}
" does not '
f
'support matryoshka representation, '
f
'changing output dimensions will lead to poor results.'
)
if
self
.
dimensions
<
1
:
raise
ValueError
(
"Dimensions must be greater than 0"
)
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
f
"dimensions=
{
self
.
dimensions
}
, "
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
vllm/reasoning/granite_reasoning_parser.py
View file @
31330101
...
...
@@ -60,7 +60,7 @@ class GraniteReasoningParser(ReasoningParser):
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionReqest): Request being processed.
request (ChatCompletionReq
u
est): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
...
...
vllm/sampling_params.py
View file @
31330101
...
...
@@ -101,7 +101,7 @@ class RequestOutputKind(Enum):
CUMULATIVE
=
0
# Return only deltas in each RequestOutput
DELTA
=
1
# Do not return intermediate RequestOuput
s
# Do not return intermediate RequestOu
t
put
FINAL_ONLY
=
2
...
...
@@ -385,9 +385,10 @@ class SamplingParams(
if
not
-
2.0
<=
self
.
frequency_penalty
<=
2.0
:
raise
ValueError
(
"frequency_penalty must be in [-2, 2], got "
f
"
{
self
.
frequency_penalty
}
."
)
if
not
0.0
<
self
.
repetition_penalty
<=
2.0
:
raise
ValueError
(
"repetition_penalty must be in (0, 2], got "
f
"
{
self
.
repetition_penalty
}
."
)
if
self
.
repetition_penalty
<=
0.0
:
raise
ValueError
(
"repetition_penalty must be greater than zero, got "
f
"
{
self
.
repetition_penalty
}
."
)
if
self
.
temperature
<
0.0
:
raise
ValueError
(
f
"temperature must be non-negative, got
{
self
.
temperature
}
."
)
...
...
vllm/third_party/pynvml.py
View file @
31330101
...
...
@@ -1119,7 +1119,7 @@ class _PrintableStructure(Structure):
e.g. class that has _field_ 'hex_value', c_uint could be formatted with
_fmt_ = {"hex_value" : "%08X"}
to produce nicer output.
Default fo
m
ratting string for all fields can be set with key "<default>" like:
Default for
m
atting string for all fields can be set with key "<default>" like:
_fmt_ = {"<default>" : "%d MHz"} # e.g all values are numbers in MHz.
If not set it's assumed to be just "%s"
...
...
vllm/transformers_utils/config.py
View file @
31330101
...
...
@@ -712,6 +712,7 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
def
get_hf_image_processor_config
(
model
:
Union
[
str
,
Path
],
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
revision
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
Dict
[
str
,
Any
]:
...
...
@@ -721,7 +722,10 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models
if
check_gguf_file
(
model
):
model
=
Path
(
model
).
parent
return
get_image_processor_config
(
model
,
revision
=
revision
,
**
kwargs
)
return
get_image_processor_config
(
model
,
token
=
hf_token
,
revision
=
revision
,
**
kwargs
)
def
get_hf_text_config
(
config
:
PretrainedConfig
):
...
...
vllm/transformers_utils/configs/eagle.py
View file @
31330101
...
...
@@ -5,6 +5,7 @@ from typing import Optional, Union
from
transformers
import
AutoConfig
,
PretrainedConfig
import
vllm.envs
as
envs
from
vllm.transformers_utils.configs.deepseek_vl2
import
DeepseekV2Config
...
...
@@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig):
self
.
truncated_vocab_size
=
self
.
model
.
vocab_size
if
\
truncated_vocab_size
is
None
else
truncated_vocab_size
if
"architectures"
not
in
kwargs
:
if
not
envs
.
VLLM_USE_V1
:
kwargs
[
"architectures"
]
=
[
"EAGLEModel"
]
else
:
kwargs
[
"architectures"
]
=
[
"EagleLlamaForCausalLM"
]
super
().
__init__
(
**
kwargs
)
...
...
vllm/transformers_utils/tokenizers/__init__.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
from
.mistral
import
(
MistralTokenizer
,
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
truncate_tool_call_ids
,
validate_request_params
)
__all__
=
[
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
,
"validate_request_params"
]
vllm/transformers_utils/tokenizers/mistral.py
View file @
31330101
...
...
@@ -98,6 +98,13 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request
.
messages
[
i
][
"tool_call_id"
]
=
tool_call_id
def
validate_request_params
(
request
:
"ChatCompletionRequest"
):
if
(
request
.
skip_special_tokens
is
not
None
and
not
request
.
skip_special_tokens
):
raise
ValueError
(
"skip_special_tokens=False is not supported "
"for Mistral tokenizers."
)
def
list_local_repo_files
(
repo_id
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
repo_cache
=
os
.
path
.
join
(
huggingface_hub
.
constants
.
HF_HUB_CACHE
,
...
...
vllm/transformers_utils/utils.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
import
json
from
functools
import
cache
from
os
import
PathLike
from
pathlib
import
Path
...
...
@@ -51,6 +52,26 @@ def modelscope_list_repo_files(
return
files
def
_maybe_json_dict
(
path
:
Union
[
str
,
PathLike
])
->
dict
[
str
,
str
]:
with
open
(
path
)
as
f
:
try
:
return
json
.
loads
(
f
.
read
())
except
Exception
:
return
dict
[
str
,
str
]()
def
_maybe_space_split_dict
(
path
:
Union
[
str
,
PathLike
])
->
dict
[
str
,
str
]:
parsed_dict
=
dict
[
str
,
str
]()
with
open
(
path
)
as
f
:
for
line
in
f
.
readlines
():
try
:
model_name
,
redirect_name
=
line
.
strip
().
split
()
parsed_dict
[
model_name
]
=
redirect_name
except
Exception
:
pass
return
parsed_dict
@
cache
def
maybe_model_redirect
(
model
:
str
)
->
str
:
"""
...
...
@@ -68,16 +89,10 @@ def maybe_model_redirect(model: str) -> str:
if
not
Path
(
model_redirect_path
).
exists
():
return
model
with
open
(
model_redirect_path
)
as
f
:
for
line
in
f
.
readlines
():
try
:
model_name
,
redirect_name
=
line
.
split
(
"
\t
"
)
if
model
==
model_name
:
redirect_name
=
redirect_name
.
strip
()
logger
.
info
(
"model redirect: [ %s ] -> [ %s ]"
,
model
,
redirect_name
)
return
redirect_name
except
Exception
:
pass
redirect_dict
=
(
_maybe_json_dict
(
model_redirect_path
)
or
_maybe_space_split_dict
(
model_redirect_path
))
if
(
redirect_model
:
=
redirect_dict
.
get
(
model
)):
logger
.
info
(
"model redirect: [ %s ] -> [ %s ]"
,
model
,
redirect_model
)
return
redirect_model
return
model
vllm/utils.py
View file @
31330101
...
...
@@ -2,7 +2,6 @@
from
__future__
import
annotations
import
argparse
import
asyncio
import
concurrent
import
contextlib
...
...
@@ -25,6 +24,7 @@ import socket
import
subprocess
import
sys
import
tempfile
import
textwrap
import
threading
import
time
import
traceback
...
...
@@ -32,6 +32,8 @@ import types
import
uuid
import
warnings
import
weakref
from
argparse
import
(
Action
,
ArgumentDefaultsHelpFormatter
,
ArgumentParser
,
ArgumentTypeError
)
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
...
...
@@ -40,7 +42,7 @@ from dataclasses import dataclass, field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
types
import
MappingProxyType
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
overload
)
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
cast
,
overload
)
from
uuid
import
uuid4
import
cachetools
...
...
@@ -53,6 +55,7 @@ import torch.types
import
yaml
import
zmq
import
zmq.asyncio
from
packaging
import
version
from
packaging.version
import
Version
from
torch.library
import
Library
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
...
...
@@ -1209,7 +1212,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
return
wrapper
class
StoreBoolean
(
argparse
.
Action
):
class
StoreBoolean
(
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
if
values
.
lower
()
==
"true"
:
...
...
@@ -1221,15 +1224,28 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'."
)
class
SortedHelpFormatter
(
argparse
.
ArgumentDefaultsHelpFormatter
):
class
SortedHelpFormatter
(
ArgumentDefaultsHelpFormatter
):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def
_split_lines
(
self
,
text
,
width
):
"""
1. Sentences split across lines have their single newlines removed.
2. Paragraphs and explicit newlines are split into separate lines.
3. Each line is wrapped to the specified width (width of terminal).
"""
# The patterns also include whitespace after the newline
single_newline
=
re
.
compile
(
r
"(?<!\n)\n(?!\n)\s*"
)
multiple_newlines
=
re
.
compile
(
r
"\n{2,}\s*"
)
text
=
single_newline
.
sub
(
' '
,
text
)
lines
=
re
.
split
(
multiple_newlines
,
text
)
return
sum
([
textwrap
.
wrap
(
line
,
width
)
for
line
in
lines
],
[])
def
add_arguments
(
self
,
actions
):
actions
=
sorted
(
actions
,
key
=
lambda
x
:
x
.
option_strings
)
super
().
add_arguments
(
actions
)
class
FlexibleArgumentParser
(
argparse
.
ArgumentParser
):
class
FlexibleArgumentParser
(
ArgumentParser
):
"""ArgumentParser that allows both underscore and dash in names."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -1280,11 +1296,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
value
=
int
(
value
)
except
ValueError
:
msg
=
"Port must be an integer"
raise
argparse
.
ArgumentTypeError
(
msg
)
from
None
raise
ArgumentTypeError
(
msg
)
from
None
if
not
(
1024
<=
value
<=
65535
):
raise
argparse
.
ArgumentTypeError
(
"Port must be between 1024 and 65535"
)
raise
ArgumentTypeError
(
"Port must be between 1024 and 65535"
)
return
value
...
...
@@ -2060,12 +2075,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def
direct_register_custom_op
(
op_name
:
str
,
op_func
:
Callable
,
mutates_args
:
list
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
dispatch_key
:
str
=
"CUDA"
,
op_name
:
str
,
op_func
:
Callable
,
mutates_args
:
list
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
dispatch_key
:
str
=
"CUDA"
,
tags
:
Tuple
[
torch
.
Tag
,
...]
=
(),
):
"""
`torch.library.custom_op` can have significant overhead because it
...
...
@@ -2104,7 +2120,7 @@ def direct_register_custom_op(
import
torch._custom_op.impl
schema_str
=
torch
.
_custom_op
.
impl
.
infer_schema
(
op_func
,
mutates_args
)
my_lib
=
target_lib
or
vllm_lib
my_lib
.
define
(
op_name
+
schema_str
)
my_lib
.
define
(
op_name
+
schema_str
,
tags
=
tags
)
my_lib
.
impl
(
op_name
,
op_func
,
dispatch_key
=
dispatch_key
)
if
fake_impl
is
not
None
:
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
...
...
@@ -2689,3 +2705,20 @@ def sha256(input) -> int:
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
return
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
byteorder
=
"big"
)
def
is_torch_equal_or_newer
(
target
:
str
)
->
bool
:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try
:
torch_version
=
version
.
parse
(
str
(
torch
.
__version__
))
return
torch_version
>=
version
.
parse
(
target
)
except
Exception
:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return
Version
(
importlib
.
metadata
.
version
(
'torch'
))
>=
Version
(
target
)
vllm/v1/attention/backends/flash_attn.py
100755 → 100644
View file @
31330101
...
...
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.ops.
triton_
merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
...
...
@@ -164,9 +164,9 @@ def make_local_attention_virtual_batches(
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
t
ensor
,
block_table
:
torch
.
T
ensor
,
page_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
t
ensor
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
T
ensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
...
...
@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches(
np
.
arange
(
pages_per_local_batch
,
dtype
=
np
.
int32
),
(
virtual_batches
,
pages_per_local_batch
))
\
+
np
.
expand_dims
(
block_starts
,
axis
=
1
)
block_indices
=
block_indices
.
flatten
()
block_indices
=
block_indices
.
flatten
()
.
clip
(
max
=
block_table
.
shape
[
1
]
-
1
)
batch_indices
=
np
.
repeat
(
np
.
arange
(
actual_batch_size
,
dtype
=
np
.
int32
),
local_blocks
*
pages_per_local_batch
)
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
...
...
vllm/v1/attention/backends/mla/common.py
View file @
31330101
...
...
@@ -83,8 +83,8 @@ spda_o = scaled_dot_product_attention(
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatnated per head
`q_b_proj` is [W_UQ; W_QR] concatnated per head
`kv_b_proj` is [W_UK; W_UV] concat
e
nated per head
`q_b_proj` is [W_UQ; W_QR] concat
e
nated per head
`out_proj` is W_O
...
...
@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
)
from
vllm.attention.ops.
triton_
merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
...
...
vllm/v1/attention/backends/pallas.py
View file @
31330101
...
...
@@ -10,6 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
PallasAttentionBackend
(
AttentionBackend
):
...
...
@@ -80,7 +83,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
use_irope
:
logger
.
warning_once
(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"Paged attention Pallas kernel does "
"not support block-sparse attention."
)
...
...
vllm/v1/core/block_pool.py
View file @
31330101
...
...
@@ -67,11 +67,11 @@ class BlockPool:
Returns:
The cached block if it exists, or None.
"""
if
block_hash
in
self
.
cached_block_hash_to_block
:
first_block_id
=
list
(
self
.
cached_block_hash_to_block
[
block_hash
].
keys
())[
0
]
return
self
.
cached_block_hash_to_block
[
block_hash
][
first
_block
_id
]
return
None
cached_blocks
=
self
.
cached_block_hash_to_block
.
get
(
block_hash
)
if
not
cached_blocks
:
return
None
first_block_id
=
next
(
iter
(
cached
_block
s
))
return
cached_blocks
[
first_block_id
]
def
cache_full_blocks
(
self
,
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
31330101
...
...
@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal(
_
,
max_tokens_per_mm_item
=
max
(
max_tokens_by_modality_dict
.
items
(),
key
=
lambda
item
:
item
[
1
])
if
(
scheduler_config
.
disable_chunked_mm_input
and
max_tokens_per_mm_item
>
scheduler_config
.
max_num_batched_tokens
):
raise
ValueError
(
"Chunked MM input disabled but max_tokens_per_mm_item "
f
"(
{
max_tokens_per_mm_item
}
) is larger than max_num_batched_tokens"
f
" (
{
scheduler_config
.
max_num_batched_tokens
}
). Please increase "
"max_num_batched_tokens."
)
encoder_compute_budget
=
max
(
scheduler_config
.
max_num_encoder_input_tokens
,
max_tokens_per_mm_item
)
encoder_cache_size
=
max
(
scheduler_config
.
encoder_cache_size
,
...
...
vllm/v1/core/kv_cache_manager.py
View file @
31330101
...
...
@@ -126,44 +126,46 @@ class KVCacheManager:
self
.
req_to_block_hashes
[
request
.
request_id
]
=
block_hashes
self
.
prefix_cache_stats
.
requests
+=
1
if
request
.
sampling_params
.
prompt_logprobs
is
None
:
if
len
(
block_hashes
)
*
self
.
block_size
==
request
.
num_tokens
:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash
=
block_hashes
.
pop
()
else
:
last_block_hash
=
None
# When the request requires prompt logprobs, we skip prefix caching.
if
request
.
sampling_params
.
prompt_logprobs
is
not
None
:
return
[],
0
computed_blocks
=
(
self
.
specialized_manager
.
find_longest_cache_hit
(
block_hashes
))
if
len
(
block_hashes
)
*
self
.
block_size
==
request
.
num_tokens
:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash
=
block_hashes
.
pop
()
else
:
last_block_hash
=
None
if
last_block_hash
is
not
None
:
# Add back the last block hash if it was removed.
block_hashes
.
append
(
last_block_hash
)
computed_blocks
=
(
self
.
specialized_manager
.
find_longest_cache_hit
(
block_hashes
))
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
self
.
prefix_cache_stats
.
queries
+=
len
(
block_hashes
)
self
.
prefix_cache_stats
.
hits
+=
len
(
computed_blocks
)
if
last_block_hash
is
not
None
:
# Add back the last block hash if it was removed.
# NOTE: Because block_hashes is cached in req_to_block_hashes,
# we shouldn't modify it directly.
block_hashes
.
append
(
last_block_hash
)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
else
:
# Skip cache hits for prompt logprobs
return
[],
0
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
return
computed_blocks
,
num_computed_tokens
def
allocate_slots
(
self
,
request
:
Request
,
num_tokens
:
int
,
new_computed_blocks
:
Optional
[
list
[
KVCacheBlock
]]
=
None
new_computed_blocks
:
Optional
[
list
[
KVCacheBlock
]]
=
None
,
num_lookahead_tokens
:
int
=
0
,
)
->
Optional
[
list
[
KVCacheBlock
]]:
"""Add slots for a request with new tokens to append.
...
...
@@ -173,6 +175,9 @@ class KVCacheManager:
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout:
-----------------------------------------------------------------------
...
...
@@ -210,8 +215,9 @@ class KVCacheManager:
# the new prefix caching hits
num_computed_tokens
=
(
request
.
num_computed_tokens
+
len
(
new_computed_blocks
)
*
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_computed_tokens
+
num_tokens
,
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_computed_tokens
+
num_tokens
+
num_lookahead_tokens
,
self
.
block_size
)
num_new_blocks
=
(
num_required_blocks
-
len
(
req_blocks
)
-
len
(
new_computed_blocks
))
...
...
@@ -245,8 +251,11 @@ class KVCacheManager:
else
:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_preallocate_blocks
=
max
(
0
,
self
.
num_preallocate_blocks
-
num_lookahead_tokens
//
self
.
block_size
)
num_new_blocks
=
min
(
num_new_blocks
+
self
.
num_preallocate_blocks
,
num_new_blocks
+
num_preallocate_blocks
,
self
.
block_pool
.
get_num_free_blocks
(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
...
...
Prev
1
…
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment