Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1231 additions
and
306 deletions
+1231
-306
vllm/platforms/interface.py
vllm/platforms/interface.py
+16
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+52
-16
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+7
-1
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+5
-1
vllm/reasoning/abs_reasoning_parsers.py
vllm/reasoning/abs_reasoning_parsers.py
+47
-54
vllm/reasoning/deepseek_r1_reasoning_parser.py
vllm/reasoning/deepseek_r1_reasoning_parser.py
+67
-56
vllm/reasoning/granite_reasoning_parser.py
vllm/reasoning/granite_reasoning_parser.py
+362
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+16
-3
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+13
-2
vllm/transformers_utils/configs/skyworkr1v.py
vllm/transformers_utils/configs/skyworkr1v.py
+53
-0
vllm/transformers_utils/processors/deepseek_vl2.py
vllm/transformers_utils/processors/deepseek_vl2.py
+1
-1
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+8
-6
vllm/transformers_utils/utils.py
vllm/transformers_utils/utils.py
+46
-3
vllm/utils.py
vllm/utils.py
+195
-73
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+236
-14
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+34
-41
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+37
-17
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+22
-13
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+12
-4
No files found.
vllm/platforms/interface.py
View file @
fcfc474d
...
...
@@ -12,9 +12,10 @@ import torch
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.utils
import
FlexibleArgumentParser
else
:
ModelConfig
=
None
VllmConfig
=
None
FlexibleArgumentParser
=
None
...
...
@@ -376,6 +377,20 @@ class Platform:
or
parallel_config
.
distributed_executor_backend
==
"external_launcher"
)
@
classmethod
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
return
False
@
classmethod
def
use_custom_allreduce
(
cls
)
->
bool
:
"""
Returns if custom allreduce is supported on the current platform
"""
return
False
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/rocm.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
os
from
functools
import
lru_cache
,
wraps
from
functools
import
cache
,
lru_cache
,
wraps
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
import
torch
...
...
@@ -12,15 +12,17 @@ from vllm.logger import init_logger
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
ModelConfig
=
None
VllmConfig
=
None
logger
=
init_logger
(
__name__
)
try
:
from
amdsmi
import
(
amdsmi_get_gpu_asic_info
,
amdsmi_get_processor_handles
,
amdsmi_init
,
amdsmi_shut_down
,
AmdSmiException
,
amdsmi_topo_get_link_type
)
from
amdsmi
import
(
AmdSmiException
,
amdsmi_get_gpu_asic_info
,
amdsmi_get_processor_handles
,
amdsmi_init
,
amdsmi_shut_down
,
amdsmi_topo_get_link_type
)
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from amdsmi with %r"
,
e
)
...
...
@@ -97,6 +99,26 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return
device_id
@
cache
def
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
max_seq_len
:
int
,
sliding_window
:
int
)
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
ON_NAVI
=
"gfx1"
in
GPU_ARCH
ON_MI250_MI300
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
# rocm custom page attention not support on navi (gfx1*)
return
(
ON_MI250_MI300
and
not
ON_NAVI
and
(
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
device_name
:
str
=
"rocm"
...
...
@@ -135,22 +157,15 @@ class RocmPlatform(Platform):
@
classmethod
@
lru_cache
(
maxsize
=
8
)
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
DeviceCapability
:
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
Optional
[
DeviceCapability
]:
major
,
minor
=
torch
.
cuda
.
get_device_capability
(
device_id
)
return
DeviceCapability
(
major
=
major
,
minor
=
minor
)
@
classmethod
@
with_amdsmi_context
@
lru_cache
(
maxsize
=
8
)
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
handle
=
amdsmi_get_processor_handles
()[
physical_device_id
]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return
torch
.
cuda
.
get_device_name
(
device_id
)
@
staticmethod
def
is_fully_connected_nvlink_or_xgmi
(
physical_device_ids
:
List
[
int
])
->
bool
:
@
with_amdsmi_context
def
is_fully_connected
(
physical_device_ids
:
List
[
int
])
->
bool
:
"""
Query if the set of gpus are fully connected by xgmi (1 hop)
"""
...
...
@@ -172,6 +187,15 @@ class RocmPlatform(Platform):
return
False
return
True
@
classmethod
@
with_amdsmi_context
@
lru_cache
(
maxsize
=
8
)
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
handle
=
amdsmi_get_processor_handles
()[
physical_device_id
]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return
torch
.
cuda
.
get_device_name
(
device_id
)
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
device_props
=
torch
.
cuda
.
get_device_properties
(
device_id
)
...
...
@@ -276,3 +300,15 @@ class RocmPlatform(Platform):
return
torch
.
float8_e4m3fnuz
else
:
return
torch
.
float8_e4m3fn
@
classmethod
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
# V1 support on AMD gpus is experimental
return
True
@
classmethod
def
use_custom_allreduce
(
cls
)
->
bool
:
# We only enable custom allreduce for MI300 series
gcn_arch
=
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
supported_archs
=
[
'gfx94'
]
return
any
(
gfx
in
gcn_arch
for
gfx
in
supported_archs
)
vllm/platforms/tpu.py
View file @
fcfc474d
...
...
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
ModelConfig
=
None
VllmConfig
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -127,3 +128,8 @@ class TpuPlatform(Platform):
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
# V1 support on TPU is experimental
return
True
vllm/
entrypoints/openai/
reasoning
_parsers
/__init__.py
→
vllm/reasoning/__init__.py
View file @
fcfc474d
...
...
@@ -2,7 +2,11 @@
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
__all__
=
[
"ReasoningParser"
,
"ReasoningParserManager"
,
"DeepSeekR1ReasoningParser"
"ReasoningParser"
,
"ReasoningParserManager"
,
"DeepSeekR1ReasoningParser"
,
"GraniteReasoningParser"
,
]
vllm/
entrypoints/openai/
reasoning
_parsers
/abs_reasoning_parsers.py
→
vllm/reasoning/abs_reasoning_parsers.py
View file @
fcfc474d
...
...
@@ -17,7 +17,7 @@ logger = init_logger(__name__)
class
ReasoningParser
:
"""
Abstract reasoning parser class that should not be used directly.
Abstract reasoning parser class that should not be used directly.
Provided and methods should be used in derived classes.
It is used to extract reasoning content from the model output.
...
...
@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab()
return
self
.
model_tokenizer
.
get_vocab
()
@
abstractmethod
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
"""
Check if the reasoning content ends in the input_ids.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
@
abstractmethod
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
@
abstractmethod
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
...
...
@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content.
"""
raise
NotImplementedError
(
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!"
)
@
abstractmethod
def
extract_reasoning_content_streaming
(
self
,
previous_text
:
str
,
...
...
@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise
NotImplementedError
(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!"
)
# TODO: need to rebase by PR #14428
@
abstractmethod
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise
NotImplementedError
(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!"
)
# TODO: need to rebase by PR #14428
@
abstractmethod
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise
NotImplementedError
(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!"
)
class
ReasoningParserManager
:
...
...
@@ -125,14 +115,16 @@ class ReasoningParserManager:
if
name
in
cls
.
reasoning_parsers
:
return
cls
.
reasoning_parsers
[
name
]
raise
KeyError
(
f
"reasoning helper: '
{
name
}
' not found in "
"
reasoning_parsers"
)
raise
KeyError
(
f
"reasoning helper: '
{
name
}
' not found in
reasoning_parsers"
)
@
classmethod
def
_register_module
(
cls
,
module
:
type
,
module_name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
force
:
bool
=
True
)
->
None
:
def
_register_module
(
cls
,
module
:
type
,
module_name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
force
:
bool
=
True
,
)
->
None
:
if
not
issubclass
(
module
,
ReasoningParser
):
raise
TypeError
(
"module must be subclass of ReasoningParser, "
f
"but got
{
type
(
module
)
}
"
)
...
...
@@ -149,13 +141,14 @@ class ReasoningParserManager:
@
classmethod
def
register_module
(
cls
,
name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
force
:
bool
=
True
,
module
:
Union
[
type
,
None
]
=
None
)
->
Union
[
type
,
Callable
]:
cls
,
name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
force
:
bool
=
True
,
module
:
Union
[
type
,
None
]
=
None
,
)
->
Union
[
type
,
Callable
]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
decoder(with module as None) or normal function(with module as not
None).
"""
if
not
isinstance
(
force
,
bool
):
...
...
@@ -183,7 +176,7 @@ class ReasoningParserManager:
@
classmethod
def
import_reasoning_parser
(
cls
,
plugin_path
:
str
)
->
None
:
"""
Import a user-defined reasoning parser by the path
Import a user-defined reasoning parser by the path
of the reasoning parser define file.
"""
module_name
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
plugin_path
))[
0
]
...
...
vllm/
entrypoints/openai/
reasoning
_parsers
/deepseek_r1_reasoning_parser.py
→
vllm/reasoning/deepseek_r1_reasoning_parser.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
re
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
...
...
@@ -8,9 +7,8 @@ from transformers import PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers
import
(
ReasoningParser
,
ReasoningParserManager
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
logger
=
init_logger
(
__name__
)
...
...
@@ -20,43 +18,42 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
Reasoning parser for DeepSeek R1 model.
The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
text. This parser extracts the reasoning content from the model output.
"""
start_token_id
:
int
end_token_id
:
int
start_token
:
str
=
"<think>"
end_token
:
str
=
"</think>"
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
super
().
__init__
(
tokenizer
)
self
.
think_start_token
=
"<think>"
self
.
think_end_token
=
"</think>"
self
.
reasoning_regex
=
re
.
compile
(
rf
"
{
self
.
think_start_token
}
(.*?)
{
self
.
think_end_token
}
"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction."
)
self
.
think_start_token_id
=
self
.
vocab
.
get
(
self
.
think_start_token
)
self
.
think_end_token_id
=
self
.
vocab
.
get
(
self
.
think_end_token
)
if
(
self
.
think_start_token_id
is
None
or
self
.
think_end_token_id
is
None
):
self
.
start_token_id
=
self
.
vocab
.
get
(
self
.
start_token
)
self
.
end_token_id
=
self
.
vocab
.
get
(
self
.
end_token
)
if
self
.
start_token_id
is
None
or
self
.
end_token_id
is
None
:
raise
RuntimeError
(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!"
)
# TODO: need to rebase by PR #14428
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
return
self
.
think_
end_token_id
in
input_ids
return
self
.
end_token_id
in
input_ids
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract the content after the end tokens
"""
if
self
.
think_
end_token_id
not
in
input_ids
[:
-
1
]:
if
self
.
end_token_id
not
in
input_ids
[:
-
1
]:
return
[]
else
:
return
input_ids
[
input_ids
.
index
(
self
.
think_
end_token_id
)
+
1
:]
return
input_ids
[
input_ids
.
index
(
self
.
end_token_id
)
+
1
:]
def
extract_reasoning_content_streaming
(
self
,
...
...
@@ -77,22 +74,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if
len
(
delta_token_ids
)
==
1
and
(
delta_token_ids
[
0
]
in
[
self
.
think_
start_token_id
,
self
.
think_
end_token_id
self
.
start_token_id
,
self
.
end_token_id
]):
return
None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if
self
.
think_
start_token_id
in
previous_token_ids
:
if
self
.
think_
end_token_id
in
delta_token_ids
:
if
self
.
start_token_id
in
previous_token_ids
:
if
self
.
end_token_id
in
delta_token_ids
:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index
=
delta_text
.
find
(
self
.
think_
end_token
)
end_index
=
delta_text
.
find
(
self
.
end_token
)
reasoning_content
=
delta_text
[:
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
elif
self
.
think_end_token_id
in
previous_token_ids
:
content
=
delta_text
[
end_index
+
len
(
self
.
end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
,
)
elif
self
.
end_token_id
in
previous_token_ids
:
# <think> in previous, </think> in previous,
# reasoning content continues
return
DeltaMessage
(
content
=
delta_text
)
...
...
@@ -100,17 +99,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
elif
self
.
think_
start_token_id
in
delta_token_ids
:
if
self
.
think_
end_token_id
in
delta_token_ids
:
elif
self
.
start_token_id
in
delta_token_ids
:
if
self
.
end_token_id
in
delta_token_ids
:
# <think> in delta, </think> in delta, extract reasoning content
start_index
=
delta_text
.
find
(
self
.
think_
start_token
)
end_index
=
delta_text
.
find
(
self
.
think_
end_token
)
start_index
=
delta_text
.
find
(
self
.
start_token
)
end_index
=
delta_text
.
find
(
self
.
end_token
)
reasoning_content
=
delta_text
[
start_index
+
len
(
self
.
think_start_token
):
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
len
(
self
.
start_token
):
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
,
)
else
:
# <think> in delta, no </think> in delta,
# reasoning content continues
...
...
@@ -119,15 +119,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_
end_token_id
in
delta_token_ids
:
if
self
.
end_token_id
in
delta_token_ids
:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index
=
delta_text
.
find
(
self
.
think_
end_token
)
end_index
=
delta_text
.
find
(
self
.
end_token
)
reasoning_content
=
delta_text
[:
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
elif
self
.
think_end_token_id
in
previous_token_ids
:
content
=
delta_text
[
end_index
+
len
(
self
.
end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
,
)
elif
self
.
end_token_id
in
previous_token_ids
:
# </think> in previous, thinking content ends
return
DeltaMessage
(
content
=
delta_text
)
else
:
...
...
@@ -137,25 +139,34 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
"""
Extract reasoning content from the model output.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
# Check if the start token is present in the model output, remove it
# if it is present.
model_output_parts
=
model_output
.
partition
(
self
.
start_token
)
model_output
=
model_output_parts
[
2
]
if
model_output_parts
[
1
]
else
model_output_parts
[
0
]
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_
end_token
not
in
model_output
:
if
self
.
end_token
not
in
model_output
:
return
model_output
,
None
else
:
# Add a start token if it's missing to keep compatibility.
if
self
.
think_start_token
not
in
model_output
:
model_output
=
f
"
{
self
.
think_start_token
}{
model_output
}
"
# Use a regex to find the reasoning content
reasoning_content
=
self
.
reasoning_regex
.
findall
(
model_output
)[
0
]
end_index
=
len
(
f
"
{
self
.
think_start_token
}{
reasoning_content
}{
self
.
think_end_token
}
"
)
final_output
=
model_output
[
end_index
:]
if
len
(
final_output
)
==
0
:
return
reasoning_content
,
None
return
reasoning_content
,
final_output
reasoning_content
,
_
,
content
=
model_output
.
partition
(
self
.
end_token
)
# If the end token is not found, return the model output as is.
# It should not happen since we already checked for the presence
# of the end token.
# If generation stops right after end-of-think, return null content
final_content
=
content
or
None
return
reasoning_content
,
final_content
vllm/reasoning/granite_reasoning_parser.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
re
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
logger
=
init_logger
(
__name__
)
@
ReasoningParserManager
.
register_module
(
"granite"
)
class
GraniteReasoningParser
(
ReasoningParser
):
"""
Reasoning parser for IBM Granite.
IBM granite models currently use "Here is my thought process:"
and "Here is my response:" to separate its thinking / response outputs.
"""
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
super
().
__init__
(
tokenizer
)
# NOTE: There have been some observed occurrences of quantized
# instances of the current models using "Here's" instead of "Here is",
# so to be safe, we match on both.
self
.
think_start_expr
=
r
"(?:Here's|Here is) my thought process:"
self
.
response_start_expr
=
r
"(?:Here's|Here is) my response:"
self
.
reasoning_regex
=
re
.
compile
(
rf
"
{
self
.
think_start_expr
}
(.*?)
{
self
.
response_start_expr
}
(.*)"
,
re
.
DOTALL
)
self
.
valid_think_starts
=
[
"Here's my thought process:"
,
"Here is my thought process:"
]
self
.
valid_response_starts
=
[
"Here's my response:"
,
"Here is my response:"
]
# Substrings to match for sequence boundaries on raw text
self
.
seq_boundary_end
=
":"
self
.
seq_boundary_start
=
"Here"
# The longest any thinking / start of response message can be
self
.
longest_think_start
=
max
(
len
(
think_start
)
for
think_start
in
self
.
valid_think_starts
)
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
"""Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionReqest): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
reasoning content and non-reasoning content.
"""
re_match
=
self
.
reasoning_regex
.
findall
(
model_output
)
if
not
re_match
:
return
None
,
model_output
reasoning_content
,
response_content
=
re_match
[
0
]
if
not
response_content
:
return
reasoning_content
,
None
return
reasoning_content
,
response_content
def
extract_reasoning_content_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
Union
[
DeltaMessage
,
None
]:
"""Extract the reasoning content / content emitted by granite models;
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
NOTE: Granite models do not use a special token to start their reasoning
and response sections; instead they have token sequences, e.g.,
Here is my thought process: Foo Here is my response: Bar
This increases the complexity of correctly handling streams, since we
need to watch for specific sequences and correctly parse them without
dropping content that is potentially overlapping & spanning multiple
delta messages.
Args:
previous_text (str): Previous text outside of this delta message.
current_text (str): Previous text + delta text.
delta_text (str): Text to consider and parse content from.
previous_token_ids (Sequence[int]): Token IDs of previous_text.
current_token_ids (Sequence[int]): Token IDs of current_text.
delta_token_ids (Sequence[int]): Token IDs of delta_text.
Returns:
Union[DeltaMessage, None]
DeltaMessage with either reasoning content or content, or None.
"""
reasoning_content
,
resp_seq_len
,
content
=
self
.
_get_content_sections
(
current_text
)
# Either we haven't finished the start of the reasoning sequence,
# or the model is generating something unexpected.
if
not
reasoning_content
:
delta_message
=
self
.
_get_delta_message_with_no_reasoning_bounds
(
current_text
,
delta_text
)
# We have a start of reasoning message, but have not yet finished
# the start of response sequence.
elif
not
content
:
delta_message
=
self
.
_get_delta_message_with_no_response_bounds
(
current_text
,
reasoning_content
,
delta_text
)
# We've finished both the start of reasoning and start of response seq.
else
:
# This should never happen since we matched on the response
assert
resp_seq_len
is
not
None
delta_message
=
self
.
_get_delta_message_with_both_bounds
(
delta_text
,
reasoning_content
,
content
,
current_text
,
resp_seq_len
)
if
not
delta_message
.
content
and
not
delta_message
.
reasoning_content
:
return
None
return
delta_message
#### Implementation details of stream parsing for granite models
def
_is_reasoning_start_substr
(
self
,
text
:
str
)
->
bool
:
"""Check if a text matches one of the possible start reasoning seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible reasoning start seqs match.
"""
return
any
(
think_start
.
startswith
(
text
)
for
think_start
in
self
.
valid_think_starts
)
def
_is_response_start_substr
(
self
,
text
:
str
)
->
bool
:
"""Check if a text matches one of the possible start response seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible response start seqs match.
"""
return
any
(
response_start
.
startswith
(
text
)
for
response_start
in
self
.
valid_response_starts
)
def
_get_delta_message_with_no_reasoning_bounds
(
self
,
current_text
:
str
,
delta_text
:
str
,
)
->
DeltaMessage
:
"""Parse the delta message when the current text has not yet completed
its start of reasoning sequence.
Args:
current_text (str): The full previous + delta text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
prev_longest_length
=
len
(
current_text
)
-
len
(
delta_text
)
is_substr
=
self
.
_is_reasoning_start_substr
(
current_text
)
was_substr
=
self
.
_is_reasoning_start_substr
(
current_text
[:
prev_longest_length
])
# Check if we just generated something NOT in the special token seq;
# if so, add everything that we previously skipped with this delta
# message and append everything to content in the future.
if
was_substr
and
not
is_substr
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
current_text
,
)
if
is_substr
:
# Might still be in the special token sequence; return nothing
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
None
)
# Otherwise the sequence has already been broken and we already
# corrected; just return the delta text as normal content.
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
delta_text
)
def
_get_delta_message_with_no_response_bounds
(
self
,
current_text
:
str
,
reasoning_content
:
str
,
delta_text
:
str
,
)
->
DeltaMessage
:
"""Parse the delta message when the current text has both reasoning
content with no (response) content. NOTE that we may have overlapping
tokens with the start of reasoning / start of response sequences on
either side of the delta text.
Args:
current_text (str): The full previous + delta text.
reasoning_content (str): reasoning content from current_text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# If we have no reasoning content or explicitly end with the start of
# response sequence, we are in transition to the response; need to be
# careful here, since the final token (:) will match the reasoning
# content and fully parse it out; we should not pass the : back.
ends_with_start_response_seq
=
any
(
current_text
.
endswith
(
response_start
)
for
response_start
in
self
.
valid_response_starts
)
if
reasoning_content
is
None
or
ends_with_start_response_seq
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
None
)
# Consider previous / current text only within context of the reasoning
previous_text
=
reasoning_content
[:
-
len
(
delta_text
)]
current_text
=
reasoning_content
# We need to be careful about adding unfinished response sequences;
# Find the place at which we MIGHT be starting a response sequence
prev_idx
=
previous_text
.
rfind
(
self
.
seq_boundary_start
)
delta_idx
=
delta_text
.
rfind
(
self
.
seq_boundary_start
)
# Check the state of potential start of response substring matches.
prev_was_substr
=
self
.
_is_response_start_substr
(
previous_text
[
prev_idx
:])
if
prev_idx
>=
0
else
False
delta_continues_substr
=
self
.
_is_response_start_substr
(
current_text
[
prev_idx
:])
if
prev_idx
>=
0
else
False
delta_new_substr
=
self
.
_is_response_start_substr
(
delta_text
[
delta_idx
:])
if
delta_idx
>=
0
else
False
# Delta only contains potential continued response sequence text.
if
delta_continues_substr
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
None
)
if
not
prev_was_substr
:
# Delta may be starting a new response seq but has other text too.
if
delta_new_substr
:
return
DeltaMessage
(
reasoning_content
=
delta_text
[:
delta_idx
],
content
=
None
)
# Normal case for most reasoning text (no potential special seqs).
return
DeltaMessage
(
reasoning_content
=
delta_text
,
content
=
None
)
# The substring that previously seemed to be a potential response
# seq wasn't one; we need to add the content to the delta message,
# and also slice off the potential response sequence
elif
delta_new_substr
:
reasoning_content
=
previous_text
[
prev_idx
:]
+
delta_text
[:
delta_idx
]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
None
)
# No new substring yet, and we broke our old one; take the whole delta
return
DeltaMessage
(
reasoning_content
=
previous_text
[
prev_idx
:]
+
delta_text
,
content
=
None
,
)
def
_get_delta_message_with_both_bounds
(
self
,
delta_text
:
str
,
reasoning_content
:
str
,
response_content
:
str
,
current_text
:
str
,
response_seq_len
:
int
,
)
->
DeltaMessage
:
"""Parse the delta message when the current text has both reasoning
content and normal (response) content.
Args:
delta_text (str): Text to consider and parse content from.
reasoning_content (str): reasoning content from current_text.
response_content (str): response content from current_text.
current_text (str): The full previous + delta text.
response_seq_len(str): Len of the complete response sequence used.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# Always have content; take length to the end
delta_content
=
delta_text
[
-
len
(
response_content
):]
reasoning_end_idx
=
len
(
delta_text
)
-
(
len
(
response_content
)
+
response_seq_len
)
if
reasoning_end_idx
<
0
:
delta_reasoning_content
=
None
else
:
# Get the starting offset
start_reasoning_content_idx
=
len
(
reasoning_content
)
+
response_seq_len
+
len
(
response_content
)
-
1
delta_offset
=
len
(
current_text
)
-
len
(
delta_text
)
start_offset
=
start_reasoning_content_idx
-
delta_offset
if
start_offset
<
0
:
start_offset
=
0
delta_reasoning_content
=
delta_text
[
start_offset
:
reasoning_end_idx
]
return
DeltaMessage
(
reasoning_content
=
delta_reasoning_content
,
content
=
delta_content
,
)
def
_get_content_sections
(
self
,
current_text
:
str
)
->
tuple
[
Optional
[
str
],
Optional
[
int
],
Optional
[
str
]]:
"""Parse the text to extract the reasoning content / content
if we have them.
Args:
current_text (str): The full previous + delta text.
Returns:
tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
containing the reasoning content, the length of the response seq
(if there is one) and the non-reasoning content.
"""
current_chunk_start
=
0
start_reasoning_content
=
None
parsed_content
=
False
delimiter_idxs
=
[
idx
for
idx
,
char
in
enumerate
(
current_text
)
if
char
==
self
.
seq_boundary_end
]
for
current_chunk_end
in
delimiter_idxs
:
current_chunk
=
current_text
[
current_chunk_start
:
current_chunk_end
]
# Check to see if the start of reasoning seq if complete
if
start_reasoning_content
is
None
:
for
think_start
in
self
.
valid_think_starts
:
if
current_chunk
==
think_start
[:
-
1
]:
start_reasoning_content
=
current_chunk_end
+
1
current_chunk_start
=
current_chunk_end
+
1
break
# Check to see if the start of response seq if complete
elif
not
parsed_content
:
for
response_start
in
self
.
valid_response_starts
:
if
current_chunk
[
-
len
(
response_start
)
+
1
:]
==
response_start
[:
-
1
]:
# Mark end of reasoning and start response content
# after the start of response sequence.
end_reasoning_content
=
current_chunk_end
-
len
(
response_start
)
reasoning_content
=
current_text
[
start_reasoning_content
:
end_reasoning_content
]
response_content
=
current_text
[
current_chunk_end
+
1
:]
return
reasoning_content
,
len
(
response_start
),
response_content
if
start_reasoning_content
and
not
parsed_content
:
return
current_text
[
start_reasoning_content
:],
None
,
None
return
None
,
None
,
None
vllm/transformers_utils/config.py
View file @
fcfc474d
...
...
@@ -37,8 +37,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MLPSpeculatorConfig
,
MPTConfig
,
NemotronConfig
,
NVLM_D_Config
,
Olmo2Config
,
RWConfig
,
S
olarConfig
,
Telechat2
Config
,
UltravoxConfig
)
S
kyworkR1VChatConfig
,
Solar
Config
,
Telechat2Config
,
UltravoxConfig
)
# yapf: enable
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
resolve_obj_by_qualname
...
...
@@ -76,6 +76,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"NVLM_D"
:
NVLM_D_Config
,
"olmo2"
:
Olmo2Config
,
"solar"
:
SolarConfig
,
"skywork_chat"
:
SkyworkR1VChatConfig
,
"telechat"
:
Telechat2Config
,
"ultravox"
:
UltravoxConfig
,
**
_CONFIG_REGISTRY_OVERRIDE_HF
...
...
@@ -261,6 +262,11 @@ def get_config(
MISTRAL_CONFIG_NAME
,
revision
=
revision
):
config_format
=
ConfigFormat
.
MISTRAL
else
:
raise
ValueError
(
"Could not detect config format for no config file found. "
"Ensure your model has either config.json (HF format) "
"or params.json (Mistral format)."
)
except
Exception
as
e
:
error_message
=
(
...
...
@@ -323,7 +329,14 @@ def get_config(
elif
config_format
==
ConfigFormat
.
MISTRAL
:
config
=
load_params_config
(
model
,
revision
,
token
=
HF_TOKEN
,
**
kwargs
)
else
:
raise
ValueError
(
f
"Unsupported config format:
{
config_format
}
"
)
supported_formats
=
[
fmt
.
value
for
fmt
in
ConfigFormat
if
fmt
!=
ConfigFormat
.
AUTO
]
raise
ValueError
(
f
"Unsupported config format:
{
config_format
}
. "
f
"Supported formats are:
{
', '
.
join
(
supported_formats
)
}
. "
f
"Ensure your model uses one of these configuration formats "
f
"or specify the correct format explicitly."
)
# Special architecture mapping check for GGUF models
if
is_gguf
:
...
...
vllm/transformers_utils/configs/__init__.py
View file @
fcfc474d
...
...
@@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from
vllm.transformers_utils.configs.nemotron
import
NemotronConfig
from
vllm.transformers_utils.configs.nvlm_d
import
NVLM_D_Config
from
vllm.transformers_utils.configs.olmo2
import
Olmo2Config
from
vllm.transformers_utils.configs.skyworkr1v
import
SkyworkR1VChatConfig
from
vllm.transformers_utils.configs.solar
import
SolarConfig
from
vllm.transformers_utils.configs.telechat2
import
Telechat2Config
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
...
...
@@ -42,6 +43,7 @@ __all__ = [
"NemotronConfig"
,
"NVLM_D_Config"
,
"Olmo2Config"
,
"SkyworkR1VChatConfig"
,
"SolarConfig"
,
"Telechat2Config"
,
"UltravoxConfig"
,
...
...
vllm/transformers_utils/configs/eagle.py
View file @
fcfc474d
...
...
@@ -5,6 +5,8 @@ from typing import Optional, Union
from
transformers
import
AutoConfig
,
PretrainedConfig
from
vllm.transformers_utils.configs.deepseek_vl2
import
DeepseekV2Config
class
EAGLEConfig
(
PretrainedConfig
):
model_type
=
"eagle"
...
...
@@ -14,8 +16,17 @@ class EAGLEConfig(PretrainedConfig):
truncated_vocab_size
:
Optional
[
int
]
=
None
,
**
kwargs
):
model_config
=
None
if
model
is
None
else
(
AutoConfig
.
for_model
(
**
model
)
if
isinstance
(
model
,
dict
)
else
model
)
model_config
:
Union
[
PretrainedConfig
,
DeepseekV2Config
,
None
]
if
isinstance
(
model
,
dict
):
archs
=
model
.
get
(
"architectures"
,
[])
target_archs
=
[
"DeepseekV2ForCausalLM"
,
"DeepseekV3ForCausalLM"
]
if
any
(
target_arch
in
archs
for
target_arch
in
target_archs
):
# AutoConfig does not support DeepSeek MoE models yet
model_config
=
DeepseekV2Config
(
**
model
)
else
:
model_config
=
AutoConfig
.
for_model
(
**
model
)
else
:
model_config
=
model
for
k
,
v
in
kwargs
.
items
():
if
k
!=
"architectures"
and
k
!=
"model_type"
and
hasattr
(
...
...
vllm/transformers_utils/configs/skyworkr1v.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
transformers.configuration_utils
import
PretrainedConfig
class
SkyworkR1VChatConfig
(
PretrainedConfig
):
model_type
=
'internvl_chat'
is_composition
=
True
def
__init__
(
self
,
vision_config
=
None
,
llm_config
=
None
,
use_backbone_lora
=
0
,
use_llm_lora
=
0
,
select_layer
=-
1
,
force_image_size
=
None
,
downsample_ratio
=
0.5
,
template
=
None
,
dynamic_image_size
=
False
,
use_thumbnail
=
False
,
ps_version
=
'v1'
,
min_dynamic_patch
=
1
,
max_dynamic_patch
=
6
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
if
vision_config
is
None
:
vision_config
=
{}
if
llm_config
is
None
:
llm_config
=
{}
self
.
vision_config
=
PretrainedConfig
(
**
vision_config
)
self
.
text_config
=
PretrainedConfig
(
**
llm_config
)
self
.
use_backbone_lora
=
use_backbone_lora
self
.
use_llm_lora
=
use_llm_lora
self
.
select_layer
=
select_layer
self
.
force_image_size
=
force_image_size
self
.
downsample_ratio
=
downsample_ratio
self
.
template
=
template
self
.
dynamic_image_size
=
dynamic_image_size
self
.
use_thumbnail
=
use_thumbnail
self
.
ps_version
=
ps_version
# pixel shuffle version
self
.
min_dynamic_patch
=
min_dynamic_patch
self
.
max_dynamic_patch
=
max_dynamic_patch
vllm/transformers_utils/processors/deepseek_vl2.py
View file @
fcfc474d
...
...
@@ -226,7 +226,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
input_ids
[
input_ids
<
0
]
=
self
.
pad_id
if
inference_mode
:
#
去掉结尾的
eos token
#
Remove the ending
eos token
assert
input_ids
[
-
1
]
==
self
.
eos_id
input_ids
=
input_ids
[:
-
1
]
target_ids
=
target_ids
[:
-
1
]
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
fcfc474d
...
...
@@ -124,13 +124,15 @@ def find_tokenizer_file(files: List[str]):
matched_files
=
[
file
for
file
in
files
if
file_pattern
.
match
(
file
)]
if
len
(
matched_files
)
>
1
:
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern:
{
file_pattern
}
. Make sure only one Mistral "
f
"tokenizer is present in
{
files
}
."
)
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern: `
{
file_pattern
.
pattern
}
`. Make sure only one Mistral "
f
"tokenizer is present in
{
files
}
."
)
elif
len
(
matched_files
)
==
0
:
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern:
{
file_pattern
}
. Make sure that a Mistral "
f
"tokenizer is present in
{
files
}
."
)
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern: `
{
file_pattern
.
pattern
}
`. Make sure that a Mistral "
f
"tokenizer is present in
{
files
}
."
)
return
matched_files
[
0
]
...
...
vllm/transformers_utils/utils.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
functools
import
cache
from
os
import
PathLike
from
pathlib
import
Path
from
typing
import
List
,
Optional
,
Union
from
vllm.envs
import
VLLM_MODEL_REDIRECT_PATH
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
is_s3
(
model_or_path
:
str
)
->
bool
:
return
model_or_path
.
lower
().
startswith
(
's3://'
)
...
...
@@ -17,9 +23,14 @@ def check_gguf_file(model: Union[str, PathLike]) -> bool:
elif
model
.
suffix
==
".gguf"
:
return
True
with
open
(
model
,
"rb"
)
as
f
:
header
=
f
.
read
(
4
)
return
header
==
b
"GGUF"
try
:
with
model
.
open
(
"rb"
)
as
f
:
header
=
f
.
read
(
4
)
return
header
==
b
"GGUF"
except
Exception
as
e
:
logger
.
debug
(
"Error reading file %s: %s"
,
model
,
e
)
return
False
def
modelscope_list_repo_files
(
...
...
@@ -38,3 +49,35 @@ def modelscope_list_repo_files(
if
file
[
'Type'
]
==
'blob'
]
return
files
@
cache
def
maybe_model_redirect
(
model
:
str
)
->
str
:
"""
Use model_redirect to redirect the model name to a local folder.
:param model: hf model name
:return: maybe redirect to a local folder
"""
model_redirect_path
=
VLLM_MODEL_REDIRECT_PATH
if
not
model_redirect_path
:
return
model
if
not
Path
(
model_redirect_path
).
exists
():
return
model
with
open
(
model_redirect_path
)
as
f
:
for
line
in
f
.
readlines
():
try
:
model_name
,
redirect_name
=
line
.
split
(
"
\t
"
)
if
model
==
model_name
:
redirect_name
=
redirect_name
.
strip
()
logger
.
info
(
"model redirect: [ %s ] -> [ %s ]"
,
model
,
redirect_name
)
return
redirect_name
except
Exception
:
pass
return
model
vllm/utils.py
View file @
fcfc474d
...
...
@@ -10,6 +10,7 @@ import datetime
import
enum
import
gc
import
getpass
import
hashlib
import
importlib
import
importlib.metadata
import
importlib.util
...
...
@@ -17,6 +18,7 @@ import inspect
import
ipaddress
import
multiprocessing
import
os
import
pickle
import
re
import
signal
import
socket
...
...
@@ -31,15 +33,17 @@ import uuid
import
warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
OrderedDict
,
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
Mapping
)
Iterable
,
Iterator
,
KeysView
,
Mapping
)
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
types
import
MappingProxyType
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Optional
,
Type
,
TypeVar
,
Union
)
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
overload
)
from
uuid
import
uuid4
import
cachetools
import
cloudpickle
import
numpy
as
np
import
numpy.typing
as
npt
...
...
@@ -57,7 +61,7 @@ import vllm.envs as envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
logger
=
init_logger
(
__name__
)
import
json
...
...
@@ -172,6 +176,7 @@ U = TypeVar("U")
_K
=
TypeVar
(
"_K"
,
bound
=
Hashable
)
_V
=
TypeVar
(
"_V"
)
_T
=
TypeVar
(
"_T"
)
class
_Sentinel
:
...
...
@@ -205,6 +210,19 @@ class Counter:
self
.
counter
=
0
class
_MappingOrderCacheView
(
UserDict
[
_K
,
_V
]):
def
__init__
(
self
,
data
:
Mapping
[
_K
,
_V
],
ordered_keys
:
Mapping
[
_K
,
None
]):
super
().
__init__
(
data
)
self
.
ordered_keys
=
ordered_keys
def
__iter__
(
self
)
->
Iterator
[
_K
]:
return
iter
(
self
.
ordered_keys
)
def
keys
(
self
)
->
KeysView
[
_K
]:
return
KeysView
(
self
.
ordered_keys
)
class
CacheInfo
(
NamedTuple
):
hits
:
int
total
:
int
...
...
@@ -217,45 +235,62 @@ class CacheInfo(NamedTuple):
return
self
.
hits
/
self
.
total
class
LRUCache
(
Generic
[
_K
,
_V
]):
"""Note: This class is not thread safe!"""
class
LRUCache
(
cachetools
.
LRUCache
[
_K
,
_V
],
Generic
[
_K
,
_V
]):
def
__init__
(
self
,
capacity
:
int
)
->
None
:
self
.
cache
=
OrderedDict
[
_K
,
_V
]()
def
__init__
(
self
,
capacity
:
float
,
getsizeof
:
Optional
[
Callable
[[
_V
],
float
]]
=
None
):
super
().
__init__
(
capacity
,
getsizeof
)
self
.
pinned_items
=
set
[
_K
]()
self
.
capacity
=
capacity
self
.
_hits
=
0
self
.
_total
=
0
def
__contains__
(
self
,
key
:
_K
)
->
bool
:
return
key
in
self
.
cache
def
__len__
(
self
)
->
int
:
return
len
(
self
.
cache
)
def
__getitem__
(
self
,
key
:
_K
)
->
_V
:
value
=
self
.
cache
[
key
]
# Raise KeyError if not exists
self
.
cache
.
move_to_end
(
key
)
return
value
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
run_on_remove
=
key
in
self
value
=
self
.
__getitem__
(
key
)
super
().
__delitem__
(
key
)
if
key
in
self
.
pinned_items
:
# Todo: add warning to inform that del pinned item
self
.
_unpin
(
key
)
if
run_on_remove
:
self
.
_on_remove
(
key
,
value
)
def
__setitem__
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
put
(
key
,
value
)
@
property
def
cache
(
self
)
->
Mapping
[
_K
,
_V
]:
"""Return the internal cache dictionary in order (read-only)."""
return
_MappingOrderCacheView
(
self
.
_Cache__data
,
# type: ignore
self
.
order
)
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
self
.
pop
(
key
)
@
property
def
order
(
self
)
->
Mapping
[
_K
,
None
]:
"""Return the internal order dictionary (read-only)."""
return
MappingProxyType
(
self
.
_LRUCache__order
)
# type: ignore
def
stat
(
self
)
->
CacheInfo
:
return
CacheInfo
(
hits
=
self
.
_hits
,
total
=
self
.
_total
)
def
touch
(
self
,
key
:
_K
)
->
None
:
self
.
cache
.
move_to_end
(
key
)
self
.
_LRUCache__update
(
key
)
# type: ignore
@
overload
def
get
(
self
,
key
:
_K
,
/
)
->
Optional
[
_V
]:
...
def
get
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
value
:
Optional
[
_V
]
if
key
in
self
.
cache
:
value
=
self
.
cache
[
key
]
self
.
cache
.
move_to_end
(
key
)
@
overload
def
get
(
self
,
key
:
_K
,
/
,
default
:
Union
[
_V
,
_T
])
->
Union
[
_V
,
_T
]:
...
def
get
(
self
,
key
:
_K
,
/
,
default
:
Optional
[
Union
[
_V
,
_T
]]
=
None
)
->
Optional
[
Union
[
_V
,
_T
]]:
value
:
Optional
[
Union
[
_V
,
_T
]]
if
key
in
self
:
value
=
self
.
__getitem__
(
key
)
self
.
_hits
+=
1
else
:
...
...
@@ -264,60 +299,76 @@ class LRUCache(Generic[_K, _V]):
self
.
_total
+=
1
return
value
@
overload
def
pop
(
self
,
key
:
_K
)
->
_V
:
...
@
overload
def
pop
(
self
,
key
:
_K
,
default
:
Union
[
_V
,
_T
])
->
Union
[
_V
,
_T
]:
...
def
pop
(
self
,
key
:
_K
,
default
:
Optional
[
Union
[
_V
,
_T
]]
=
None
)
->
Optional
[
Union
[
_V
,
_T
]]:
value
:
Optional
[
Union
[
_V
,
_T
]]
if
key
not
in
self
:
return
default
value
=
self
[
key
]
del
self
[
key
]
return
value
def
put
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
self
.
_remove_old_if_needed
()
self
.
__setitem__
(
key
,
value
)
def
pin
(
self
,
key
:
_K
)
->
None
:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
"""
if
key
not
in
self
.
cache
:
if
key
not
in
self
:
raise
ValueError
(
f
"Cannot pin key:
{
key
}
not in cache."
)
self
.
pinned_items
.
add
(
key
)
def
_unpin
(
self
,
key
:
_K
)
->
None
:
"""
Unpins a key in the cache allowing it to be
evicted in the LRU order.
"""
self
.
pinned_items
.
remove
(
key
)
def
_on_remove
(
self
,
key
:
_K
,
value
:
Optional
[
_V
])
->
None
:
pass
def
remove_oldest
(
self
,
*
,
remove_pinned
:
bool
=
False
)
->
None
:
if
not
self
.
cache
:
if
len
(
self
)
==
0
:
return
self
.
popitem
(
remove_pinned
=
remove_pinned
)
def
_remove_old_if_needed
(
self
)
->
None
:
while
self
.
currsize
>
self
.
capacity
:
self
.
remove_oldest
()
def
clear
(
self
)
->
None
:
while
len
(
self
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
def
popitem
(
self
,
remove_pinned
:
bool
=
False
):
"""Remove and return the `(key, value)` pair least recently used."""
if
not
remove_pinned
:
# pop the oldest item in the cache that is not pinned
lru_key
=
next
(
(
key
for
key
in
self
.
cache
if
key
not
in
self
.
pinned_items
),
(
key
for
key
in
self
.
order
if
key
not
in
self
.
pinned_items
),
ALL_PINNED_SENTINEL
)
if
lru_key
is
ALL_PINNED_SENTINEL
:
raise
RuntimeError
(
"All items are pinned, "
"cannot remove oldest from the cache."
)
else
:
lru_key
=
next
(
iter
(
self
.
cache
))
self
.
pop
(
lru_key
)
# type: ignore
def
_remove_old_if_needed
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
self
.
remove_oldest
()
def
pop
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
run_on_remove
=
key
in
self
.
cache
value
=
self
.
cache
.
pop
(
key
,
default
)
# remove from pinned items
if
key
in
self
.
pinned_items
:
self
.
_unpin
(
key
)
if
run_on_remove
:
self
.
_on_remove
(
key
,
value
)
return
value
def
clear
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
self
.
cache
.
clear
()
lru_key
=
next
(
iter
(
self
.
order
))
value
=
self
.
pop
(
cast
(
_K
,
lru_key
))
return
(
lru_key
,
value
)
class
PyObjectCache
:
...
...
@@ -528,7 +579,7 @@ def get_open_port() -> int:
dp_port
=
envs
.
VLLM_DP_MASTER_PORT
while
True
:
port
=
_get_open_port
()
if
port
>=
dp_port
and
port
<
dp_port
+
10
:
if
dp_port
<=
port
<
dp_port
+
10
:
continue
return
port
return
_get_open_port
()
...
...
@@ -745,6 +796,14 @@ def is_pin_memory_available() -> bool:
return
current_platform
.
is_pin_memory_available
()
@
cache
def
is_uva_available
()
->
bool
:
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed.
return
is_pin_memory_available
()
class
DeviceMemoryProfiler
:
def
__init__
(
self
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
):
...
...
@@ -1162,7 +1221,7 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'."
)
class
SortedHelpFormatter
(
argparse
.
HelpFormatter
):
class
SortedHelpFormatter
(
argparse
.
ArgumentDefaults
HelpFormatter
):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def
add_arguments
(
self
,
actions
):
...
...
@@ -1183,6 +1242,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
if
args
is
None
:
args
=
sys
.
argv
[
1
:]
# Check for --model in command line arguments first
if
args
and
args
[
0
]
==
"serve"
:
model_in_cli_args
=
any
(
arg
==
'--model'
for
arg
in
args
)
if
model_in_cli_args
:
raise
ValueError
(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option."
)
if
'--config'
in
args
:
args
=
self
.
_pull_args_from_config
(
args
)
...
...
@@ -1266,19 +1335,29 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
config_args
=
self
.
_load_config_file
(
file_path
)
# 0th index is for {serve,chat,complete}
# followed by model_tag (only for serve)
#
optionally
followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if
args
[
0
]
==
"serve"
:
if
index
==
1
:
model_in_cli
=
len
(
args
)
>
1
and
not
args
[
1
].
startswith
(
'-'
)
model_in_config
=
any
(
arg
==
'--model'
for
arg
in
config_args
)
if
not
model_in_cli
and
not
model_in_config
:
raise
ValueError
(
"No model_tag specified! Please check your command-line"
" arguments."
)
args
=
[
args
[
0
]]
+
[
args
[
1
]
]
+
config_args
+
args
[
2
:
index
]
+
args
[
index
+
2
:]
"No model specified! Please specify model either "
"as a positional argument or in a config file."
)
if
model_in_cli
:
# Model specified as positional arg, keep CLI version
args
=
[
args
[
0
]]
+
[
args
[
1
]
]
+
config_args
+
args
[
2
:
index
]
+
args
[
index
+
2
:]
else
:
# No model in CLI, use config if available
args
=
[
args
[
0
]
]
+
config_args
+
args
[
1
:
index
]
+
args
[
index
+
2
:]
else
:
args
=
[
args
[
0
]]
+
config_args
+
args
[
1
:
index
]
+
args
[
index
+
2
:]
...
...
@@ -1296,9 +1375,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension
:
str
=
file_path
.
split
(
'.'
)[
-
1
]
if
extension
not
in
(
'yaml'
,
'yml'
):
raise
ValueError
(
...
...
@@ -1691,18 +1768,21 @@ class ClassRegistry(UserDict[Type[T], _V]):
return
any
(
cls
in
self
.
data
for
cls
in
key
.
mro
())
def
weak_ref_tensor
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
weak_ref_tensor
(
tensor
:
Any
)
->
Any
:
"""
Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive.
"""
return
torch
.
ops
.
_C
.
weak_ref_tensor
(
tensor
)
if
isinstance
(
tensor
,
torch
.
Tensor
):
return
torch
.
ops
.
_C
.
weak_ref_tensor
(
tensor
)
else
:
return
tensor
def
weak_ref_tensors
(
tensors
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
]
]:
)
->
Union
[
torch
.
Tensor
,
list
[
Any
],
tuple
[
Any
],
Any
]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
...
...
@@ -1716,6 +1796,14 @@ def weak_ref_tensors(
raise
ValueError
(
"Invalid type for tensors"
)
def
get_cuda_view_from_cpu_tensor
(
cpu_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
assert
cpu_tensor
.
is_pinned
(),
"CPU tensor must be pinned"
return
torch
.
ops
.
_C
.
get_cuda_view_from_cpu_tensor
(
cpu_tensor
)
def
is_in_doc_build
()
->
bool
:
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
...
...
@@ -2247,11 +2335,11 @@ def make_zmq_socket(
if
socket_type
==
zmq
.
constants
.
PULL
:
socket
.
setsockopt
(
zmq
.
constants
.
RCVHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
RCVBUF
,
buf_size
)
socket
.
connect
(
path
)
socket
.
bind
(
path
)
elif
socket_type
==
zmq
.
constants
.
PUSH
:
socket
.
setsockopt
(
zmq
.
constants
.
SNDHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
SNDBUF
,
buf_size
)
socket
.
bind
(
path
)
socket
.
connect
(
path
)
else
:
raise
ValueError
(
f
"Unknown Socket Type:
{
socket_type
}
"
)
...
...
@@ -2259,7 +2347,11 @@ def make_zmq_socket(
@
contextlib
.
contextmanager
def
zmq_socket_ctx
(
path
:
str
,
socket_type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
def
zmq_socket_ctx
(
path
:
str
,
socket_type
:
Any
,
linger
:
int
=
0
,
)
->
Iterator
[
zmq
.
Socket
]:
"""Context manager for a ZMQ socket"""
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
...
...
@@ -2270,7 +2362,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
logger
.
debug
(
"Got Keyboard Interrupt."
)
finally
:
ctx
.
destroy
(
linger
=
0
)
ctx
.
destroy
(
linger
=
linger
)
def
is_in_ray_actor
():
...
...
@@ -2567,3 +2659,33 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
return
wrapper
return
decorator
# Only relevant for models using ALiBi (e.g, MPT)
def
check_use_alibi
(
model_config
:
ModelConfig
)
->
bool
:
return
(
getattr
(
model_config
.
hf_text_config
,
"alibi"
,
False
)
# Falcon
or
(
"BloomForCausalLM"
in
getattr
(
model_config
.
hf_config
,
"architectures"
,
[]))
# Bloom
or
getattr
(
model_config
.
hf_text_config
,
"position_encoding_type"
,
""
)
==
"alibi"
# codellm_1b_alibi
or
(
hasattr
(
model_config
.
hf_text_config
,
"attn_config"
)
# MPT
and
model_config
.
hf_text_config
.
attn_config
.
get
(
"alibi"
,
False
)))
def
sha256
(
input
)
->
int
:
"""Hash any picklable Python object using SHA-256.
The input is serialized using pickle before hashing, which allows
arbitrary Python objects to be used. Note that this function does
not use a hash seed—if you need one, prepend it explicitly to the input.
Args:
input: Any picklable Python object.
Returns:
An integer representing the SHA-256 hash of the serialized input.
"""
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
return
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
byteorder
=
"big"
)
vllm/v1/attention/backends/flash_attn.py
View file @
fcfc474d
...
...
@@ -96,6 +96,183 @@ class FlashAttentionMetadata:
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
local_query_start_loc
:
torch
.
Tensor
local_seqused_k
:
torch
.
Tensor
local_block_table
:
torch
.
Tensor
local_max_query_len
:
int
local_max_seq_len
:
int
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if are performing a chunked prefill a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def
make_local_attention_virtual_batches
(
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
tensor
,
page_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
tensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
# Handle if we are starting in the middle of a local attention block,
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# new_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block
=
np
.
minimum
(
attn_chunk_size
-
((
seq_lens_np
-
q_seqlens
)
%
attn_chunk_size
),
q_seqlens
).
astype
(
np
.
int32
)
tokens_in_last_block
=
attn_chunk_size
+
(
seq_lens_np
%
-
attn_chunk_size
)
local_blocks
=
1
+
cdiv
(
q_seqlens
-
q_tokens_in_first_block
,
attn_chunk_size
)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: max a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks
=
np
.
cumsum
(
local_blocks
)
virtual_batches
=
cu_num_blocks
[
-
1
]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
local_blocks
,
local_blocks
)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange
=
np
.
arange
(
virtual_batches
,
dtype
=
np
.
int32
)
-
block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange
=
np
.
repeat
(
local_blocks
,
local_blocks
)
-
arange
-
1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local
=
\
np
.
repeat
(
q_seqlens
-
q_tokens_in_first_block
,
local_blocks
)
# set the first block since this may be a partial block
seqlens_q_local
[
arange
==
0
]
=
q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local
[
arange
>
0
]
=
np
.
minimum
(
seqlens_q_local
-
attn_chunk_size
*
(
arange
-
1
),
attn_chunk_size
)[
arange
>
0
]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local
=
np
.
pad
(
np
.
cumsum
(
seqlens_q_local
),
(
1
,
0
))
\
.
astype
(
np
.
int32
)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local
=
np
.
full
(
cu_num_blocks
[
-
1
],
attn_chunk_size
,
dtype
=
np
.
int32
)
seqlens_k_local
[
cu_num_blocks
-
1
]
=
tokens_in_last_block
k_seqstarts_absolute
=
np
.
repeat
(
seq_lens_np
,
local_blocks
)
-
\
(
rarange
*
attn_chunk_size
+
\
np
.
repeat
(
tokens_in_last_block
,
local_blocks
))
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts
=
k_seqstarts_absolute
//
page_size
assert
attn_chunk_size
%
page_size
==
0
,
\
f
"attn_chunk_size
{
attn_chunk_size
}
is not "
\
f
"divisible by page_size
{
page_size
}
"
pages_per_local_batch
=
attn_chunk_size
//
page_size
# Create a block_table for the local attention blocks
# For out example if we have a block-table like (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices
=
np
.
broadcast_to
(
np
.
arange
(
pages_per_local_batch
,
dtype
=
np
.
int32
),
(
virtual_batches
,
pages_per_local_batch
))
\
+
np
.
expand_dims
(
block_starts
,
axis
=
1
)
block_indices
=
block_indices
.
flatten
()
batch_indices
=
np
.
repeat
(
np
.
arange
(
actual_batch_size
,
dtype
=
np
.
int32
),
local_blocks
*
pages_per_local_batch
)
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
.
view
(
virtual_batches
,
-
1
)
return
seqlens_q_local
,
cu_seqlens_q_local
,
seqlens_k_local
,
\
block_table_local
class
FlashAttentionMetadataBuilder
:
...
...
@@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
query_start_loc_cpu
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
]
query_start_loc
=
query_start_loc_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
# for local attention
local_attn_metadata
=
None
if
self
.
runner
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table
=
make_local_attention_virtual_batches
(
self
.
runner
.
attention_chunk_size
,
self
.
runner
.
query_start_loc_np
[:
num_reqs
+
1
],
self
.
runner
.
seq_lens_np
[:
num_reqs
],
block_table
,
self
.
runner
.
block_size
,
)
local_attn_metadata
=
FlashAttentionMetadata
.
LocalAttentionMetadata
(
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_block_table
=
virt_block_table
,
local_max_query_len
=
seqlens_q_local_np
.
max
(),
local_max_seq_len
=
virt_k_seqlens_np
.
max
(),
)
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
...
...
@@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
)
return
attn_metadata
...
...
@@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
...
...
@@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
self
.
use_irope
=
use_irope
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
\
and
not
flash_attn_supports_fp8
():
...
...
@@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_v_scale
,
)
descale_shape
=
(
attn_metadata
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
...
...
@@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
if
not
attn_metadata
.
use_cascade
:
# Regular attention (common case).
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
not
attn_metadata
.
use_cascade
or
use_local_attn
:
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
seqused_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
max_query_
len
,
seqused_k
=
attn_metadata
.
seq_lens
,
max_seqlen_k
=
attn_metadata
.
max_seq
_
len
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seq
len
_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen
_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
...
...
@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
)
return
output
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
cascade_attention
(
output
[:
num_actual_tokens
],
...
...
vllm/v1/attention/backends/pallas.py
View file @
fcfc474d
...
...
@@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK
=
32
NUM_KV_PAGES_PER_BLOCK
=
128
class
PallasAttentionBackend
(
AttentionBackend
):
...
...
@@ -41,7 +37,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
num_kv_heads
*
head_size
)
return
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
@
staticmethod
def
swap_blocks
(
...
...
@@ -92,6 +88,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
sliding_window
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -99,15 +97,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise
NotImplementedError
(
"Head size must be a multiple of 128."
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
if
sliding_window
is
not
None
:
raise
NotImplementedError
(
"Sliding window is not supported."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
logits_soft_cap
is
not
None
:
raise
NotImplementedError
(
"Attention logits soft-capping is not supported."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
...
...
@@ -118,13 +111,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
tpu_version
=
torch_xla
.
tpu
.
version
()
if
tpu_version
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
if
tpu_version
==
4
:
self
.
vmem_limit_bytes
=
16
*
1024
*
1024
else
:
self
.
vmem_limit_bytes
=
64
*
1024
*
1024
def
forward
(
self
,
...
...
@@ -132,7 +118,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
PallasMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -142,14 +128,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
[num_blocks, block_size, num_kv_heads * head_size])
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# For determine_available_memory case.
if
kv_cache
[
0
]
.
numel
()
==
0
:
if
kv_cache
.
numel
()
==
0
:
if
output
is
None
:
output
=
torch
.
ones_like
(
query
)
return
output
...
...
@@ -158,24 +143,28 @@ class PallasAttentionBackendImpl(AttentionImpl):
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
key_cache
,
value_cache
=
kv_cache
if
kv_cache
[
0
].
numel
()
>
0
:
if
kv_cache
.
numel
()
>
0
:
slot_mapping
=
attn_metadata
.
slot_mapping
write_to_kv_cache
(
key
,
value
,
k
ey_cache
,
value
_cache
,
slot_mapping
)
write_to_kv_cache
(
key
,
value
,
k
v
_cache
,
slot_mapping
)
output
=
torch
.
ops
.
xla
.
ragged_paged_attention
(
query
,
key_cache
,
value_cache
,
kv_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
attn_metadata
.
query_start_loc
,
attn_metadata
.
num_seqs
,
num_kv_pages_per_block
=
NUM_KV_PAGES_PER_BLOCK
,
num_queries_per_block
=
NUM_QUERIES_PER_BLOCK
,
vmem_limit_bytes
=
self
.
vmem_limit_bytes
,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block
=
None
,
num_queries_per_block
=
None
,
vmem_limit_bytes
=
None
,
use_kernel
=
True
,
sm_scale
=
self
.
scale
)
sm_scale
=
self
.
scale
,
sliding_window
=
self
.
sliding_window
,
soft_cap
=
self
.
logits_soft_cap
,
)
return
output
.
reshape
(
num_tokens
,
hidden_size
)
...
...
@@ -183,23 +172,27 @@ class PallasAttentionBackendImpl(AttentionImpl):
def
write_to_kv_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
k_cache = [num_blocks, block_size, num_kv_heads * head_size]
v_cache = [num_blocks, block_size, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
"""
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
value_cache
,
True
)
_
,
_
,
num_combined_kv_heads
,
head_size
=
kv_cache
.
shape
num_kv_heads
=
num_combined_kv_heads
//
2
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
kv
=
torch
.
cat
([
key
,
value
],
axis
=-
1
).
reshape
(
-
1
,
num_combined_kv_heads
,
head_size
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
kv_cache
,
True
)
key_cache
=
key_cache
.
flatten
(
0
,
1
)
value_cache
=
value_cache
.
flatten
(
0
,
1
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
kv_cache
=
kv_cache
.
flatten
(
0
,
1
)
kv_cache
.
index_copy_
(
0
,
slot_mapping
,
kv
)
vllm/v1/attention/backends/triton_attn.py
View file @
fcfc474d
...
...
@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
...
...
@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
use_irope
=
use_irope
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -156,23 +158,41 @@ class TritonAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
)
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
sequesd_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
sequesd_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode
(
query
=
quer
y
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
]
,
k
v
_cache
_dtype
=
self
.
kv_cache_dtyp
e
,
key
_cache
=
key
_cache
,
value_cache
=
value_cach
e
,
block_table
=
attn_metadata
.
block_table
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq
_
len
s
,
max_query_len
=
attn_metadata
.
max_query_
len
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
[
0
],
sm_scale
=
self
.
scale
)
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
key
=
ke
y
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
k
ey
_cache
=
key_cach
e
,
value
_cache
=
value
_cache
,
block_table
=
block_tabl
e
,
query_start_loc
=
cu_seqlens_q
,
seq_lens
=
sequesd_k
,
max_seq_len
=
max_
seqlen
_k
,
max_query_len
=
max_seq
len
_q
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
alibi_slopes
=
self
.
alibi_slopes
,
sliding_window
=
self
.
sliding_window
[
0
],
sm_scale
=
self
.
scale
)
return
output
vllm/v1/core/block_pool.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
collections
import
defaultdict
from
collections.abc
import
Iterable
from
typing
import
Optional
from
typing
import
Callable
,
Optional
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
...
...
@@ -15,10 +15,10 @@ logger = init_logger(__name__)
class
BlockPool
:
"""BlockPool that manages KVCacheBlocks.
It provides methods to allocate, free and cache the kv cache blocks. The
free_block_queue stores the free blocks in eviction order to enable
allocation, free, and cache eviction. The cached_block_hash_to_block
maps between block hash and cached block to support finding cached blocks
It provides methods to allocate, free and cache the kv cache blocks. The
free_block_queue stores the free blocks in eviction order to enable
allocation, free, and cache eviction. The cached_block_hash_to_block
maps between block hash and cached block to support finding cached blocks
by their block hash.
Args:
...
...
@@ -27,6 +27,7 @@ class BlockPool:
"""
def
__init__
(
self
,
num_gpu_blocks
:
int
,
enable_caching
:
bool
):
assert
isinstance
(
num_gpu_blocks
,
int
)
and
num_gpu_blocks
>
0
self
.
num_gpu_blocks
=
num_gpu_blocks
self
.
enable_caching
=
enable_caching
# All kv-cache blocks.
...
...
@@ -50,6 +51,11 @@ class BlockPool:
self
.
cached_block_hash_to_block
:
dict
[
BlockHashType
,
dict
[
int
,
KVCacheBlock
]]
=
defaultdict
(
dict
)
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self
.
null_block
=
self
.
free_block_queue
.
popleft
()
def
get_cached_block
(
self
,
block_hash
:
BlockHashType
)
->
Optional
[
KVCacheBlock
]:
"""Get a cached block by the block hash, or None if cache miss.
...
...
@@ -75,11 +81,12 @@ class BlockPool:
num_cached_blocks
:
int
,
num_full_blocks
:
int
,
block_size
:
int
,
hash_fn
:
Callable
,
)
->
None
:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it computes the
block hashes for the blocks starting from `num_cached_blocks` to
block hashes for the blocks starting from `num_cached_blocks` to
`num_full_blocks`, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
...
...
@@ -87,12 +94,13 @@ class BlockPool:
request: The request to cache the blocks.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
hash_fn: The hash function to use for block hashes.
"""
if
num_cached_blocks
==
num_full_blocks
:
return
...
...
@@ -138,7 +146,7 @@ class BlockPool:
request
,
start_token_idx
,
end_token_idx
,
-
1
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_hash
=
hash_block_tokens
(
hash_fn
,
prev_block_hash_value
,
block_tokens
,
extra_keys
)
block_hashes
.
append
(
block_hash
)
...
...
@@ -212,7 +220,7 @@ class BlockPool:
for
block
in
blocks
:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if
block
.
ref_cnt
==
0
:
if
block
.
ref_cnt
==
0
and
block
!=
self
.
null_block
:
self
.
free_block_queue
.
remove
(
block
)
block
.
incr_ref
()
...
...
@@ -226,7 +234,8 @@ class BlockPool:
"""
for
block
in
ordered_blocks
:
block
.
decr_ref
()
if
block
.
ref_cnt
==
0
:
# null_block should not be added to the free list.
if
block
.
ref_cnt
==
0
and
block
!=
self
.
null_block
:
self
.
free_block_queue
.
append
(
block
)
def
reset_prefix_cache
(
self
)
->
bool
:
...
...
@@ -239,10 +248,10 @@ class BlockPool:
False otherwise.
"""
num_used_blocks
=
(
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
())
if
num_used_blocks
>
0
:
if
num_used_blocks
!=
1
:
# The null block is always marked as used
logger
.
warning
(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet"
,
num_used_blocks
)
"blocks (%d) are not freed yet"
,
num_used_blocks
-
1
)
return
False
# Remove all hashes so that no new blocks will hit.
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
fcfc474d
...
...
@@ -3,7 +3,7 @@
from
typing
import
TYPE_CHECKING
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
M
ULTIMODAL_REGISTRY
from
vllm.multimodal
import
M
ultiModalRegistry
from
vllm.v1.request
import
Request
if
TYPE_CHECKING
:
...
...
@@ -67,6 +67,7 @@ class EncoderCacheManager:
def
compute_encoder_budget
(
model_config
:
"ModelConfig"
,
scheduler_config
:
"SchedulerConfig"
,
mm_registry
:
MultiModalRegistry
,
)
->
tuple
[
int
,
int
]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
...
...
@@ -74,6 +75,7 @@ def compute_encoder_budget(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
...
...
@@ -89,7 +91,11 @@ def compute_encoder_budget(
(
encoder_compute_budget
,
encoder_cache_size
,
)
=
_compute_encoder_budget_multimodal
(
model_config
,
scheduler_config
)
)
=
_compute_encoder_budget_multimodal
(
model_config
,
scheduler_config
,
mm_registry
,
)
return
encoder_compute_budget
,
encoder_cache_size
...
...
@@ -97,6 +103,7 @@ def compute_encoder_budget(
def
_compute_encoder_budget_multimodal
(
model_config
:
"ModelConfig"
,
scheduler_config
:
"SchedulerConfig"
,
mm_registry
:
MultiModalRegistry
,
)
->
tuple
[
int
,
int
]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
...
...
@@ -104,6 +111,7 @@ def _compute_encoder_budget_multimodal(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
...
...
@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal(
in the input sequence.
"""
max_tokens_by_modality_dict
=
MULTIMODAL_REGISTRY
.
get_max_tokens_per_item_by_nonzero_modality
(
# noqa: E501
model_config
)
max_tokens_by_modality_dict
=
mm_registry
\
.
get_max_tokens_per_item_by_nonzero_modality
(
model_config
)
if
not
max_tokens_by_modality_dict
:
logger
.
warning
(
...
...
Prev
1
…
19
20
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment