Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1231 additions
and
306 deletions
+1231
-306
vllm/platforms/interface.py
vllm/platforms/interface.py
+16
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+52
-16
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+7
-1
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+5
-1
vllm/reasoning/abs_reasoning_parsers.py
vllm/reasoning/abs_reasoning_parsers.py
+47
-54
vllm/reasoning/deepseek_r1_reasoning_parser.py
vllm/reasoning/deepseek_r1_reasoning_parser.py
+67
-56
vllm/reasoning/granite_reasoning_parser.py
vllm/reasoning/granite_reasoning_parser.py
+362
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+16
-3
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/eagle.py
vllm/transformers_utils/configs/eagle.py
+13
-2
vllm/transformers_utils/configs/skyworkr1v.py
vllm/transformers_utils/configs/skyworkr1v.py
+53
-0
vllm/transformers_utils/processors/deepseek_vl2.py
vllm/transformers_utils/processors/deepseek_vl2.py
+1
-1
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+8
-6
vllm/transformers_utils/utils.py
vllm/transformers_utils/utils.py
+46
-3
vllm/utils.py
vllm/utils.py
+195
-73
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+236
-14
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+34
-41
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+37
-17
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+22
-13
vllm/v1/core/encoder_cache_manager.py
vllm/v1/core/encoder_cache_manager.py
+12
-4
No files found.
vllm/platforms/interface.py
View file @
fcfc474d
...
...
@@ -12,9 +12,10 @@ import torch
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.utils
import
FlexibleArgumentParser
else
:
ModelConfig
=
None
VllmConfig
=
None
FlexibleArgumentParser
=
None
...
...
@@ -376,6 +377,20 @@ class Platform:
or
parallel_config
.
distributed_executor_backend
==
"external_launcher"
)
@
classmethod
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
return
False
@
classmethod
def
use_custom_allreduce
(
cls
)
->
bool
:
"""
Returns whether custom allreduce is supported on the current platform.
"""
return
False
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/rocm.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
os
from
functools
import
lru_cache
,
wraps
from
functools
import
cache
,
lru_cache
,
wraps
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
import
torch
...
...
@@ -12,15 +12,17 @@ from vllm.logger import init_logger
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
ModelConfig
=
None
VllmConfig
=
None
logger
=
init_logger
(
__name__
)
try
:
from
amdsmi
import
(
amdsmi_get_gpu_asic_info
,
amdsmi_get_processor_handles
,
amdsmi_init
,
amdsmi_shut_down
,
AmdSmiException
,
amdsmi_topo_get_link_type
)
from
amdsmi
import
(
AmdSmiException
,
amdsmi_get_gpu_asic_info
,
amdsmi_get_processor_handles
,
amdsmi_init
,
amdsmi_shut_down
,
amdsmi_topo_get_link_type
)
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from amdsmi with %r"
,
e
)
...
...
@@ -97,6 +99,26 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return
device_id
@
cache
def
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
max_seq_len
:
int
,
sliding_window
:
int
)
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
ON_NAVI
=
"gfx1"
in
GPU_ARCH
ON_MI250_MI300
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
# ROCm custom paged attention is not supported on Navi (gfx1*)
return
(
ON_MI250_MI300
and
not
ON_NAVI
and
(
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
device_name
:
str
=
"rocm"
...
...
@@ -135,22 +157,15 @@ class RocmPlatform(Platform):
@
classmethod
@
lru_cache
(
maxsize
=
8
)
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
DeviceCapability
:
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
Optional
[
DeviceCapability
]:
major
,
minor
=
torch
.
cuda
.
get_device_capability
(
device_id
)
return
DeviceCapability
(
major
=
major
,
minor
=
minor
)
@
classmethod
@
with_amdsmi_context
@
lru_cache
(
maxsize
=
8
)
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
handle
=
amdsmi_get_processor_handles
()[
physical_device_id
]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return
torch
.
cuda
.
get_device_name
(
device_id
)
@
staticmethod
def
is_fully_connected_nvlink_or_xgmi
(
physical_device_ids
:
List
[
int
])
->
bool
:
@
with_amdsmi_context
def
is_fully_connected
(
physical_device_ids
:
List
[
int
])
->
bool
:
"""
Query if the set of gpus are fully connected by xgmi (1 hop)
"""
...
...
@@ -172,6 +187,15 @@ class RocmPlatform(Platform):
return
False
return
True
@
classmethod
@
with_amdsmi_context
@
lru_cache
(
maxsize
=
8
)
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
physical_device_id
=
device_id_to_physical_device_id
(
device_id
)
handle
=
amdsmi_get_processor_handles
()[
physical_device_id
]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return
torch
.
cuda
.
get_device_name
(
device_id
)
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
device_props
=
torch
.
cuda
.
get_device_properties
(
device_id
)
...
...
@@ -276,3 +300,15 @@ class RocmPlatform(Platform):
return
torch
.
float8_e4m3fnuz
else
:
return
torch
.
float8_e4m3fn
@
classmethod
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
# V1 support on AMD gpus is experimental
return
True
@
classmethod
def
use_custom_allreduce
(
cls
)
->
bool
:
# We only enable custom allreduce for MI300 series
gcn_arch
=
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
supported_archs
=
[
'gfx94'
]
return
any
(
gfx
in
gcn_arch
for
gfx
in
supported_archs
)
vllm/platforms/tpu.py
View file @
fcfc474d
...
...
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
else
:
ModelConfig
=
None
VllmConfig
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -127,3 +128,8 @@ class TpuPlatform(Platform):
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
# V1 support on TPU is experimental
return
True
vllm/
entrypoints/openai/
reasoning
_parsers
/__init__.py
→
vllm/reasoning/__init__.py
View file @
fcfc474d
...
...
@@ -2,7 +2,11 @@
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
__all__
=
[
"ReasoningParser"
,
"ReasoningParserManager"
,
"DeepSeekR1ReasoningParser"
"ReasoningParser"
,
"ReasoningParserManager"
,
"DeepSeekR1ReasoningParser"
,
"GraniteReasoningParser"
,
]
vllm/
entrypoints/openai/
reasoning
_parsers
/abs_reasoning_parsers.py
→
vllm/reasoning/abs_reasoning_parsers.py
View file @
fcfc474d
...
...
@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab()
return
self
.
model_tokenizer
.
get_vocab
()
@
abstractmethod
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
"""
Check if the reasoning content ends in the input_ids.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
@
abstractmethod
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
@
abstractmethod
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
...
...
@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content.
"""
raise
NotImplementedError
(
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!"
)
@
abstractmethod
def
extract_reasoning_content_streaming
(
self
,
previous_text
:
str
,
...
...
@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise
NotImplementedError
(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!"
)
# TODO: need to rebase by PR #14428
@
abstractmethod
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise
NotImplementedError
(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!"
)
# TODO: need to rebase by PR #14428
@
abstractmethod
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise
NotImplementedError
(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!"
)
class
ReasoningParserManager
:
...
...
@@ -125,14 +115,16 @@ class ReasoningParserManager:
if
name
in
cls
.
reasoning_parsers
:
return
cls
.
reasoning_parsers
[
name
]
raise
KeyError
(
f
"reasoning helper: '
{
name
}
' not found in "
"
reasoning_parsers"
)
raise
KeyError
(
f
"reasoning helper: '
{
name
}
' not found in
reasoning_parsers"
)
@
classmethod
def
_register_module
(
cls
,
def
_register_module
(
cls
,
module
:
type
,
module_name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
force
:
bool
=
True
)
->
None
:
force
:
bool
=
True
,
)
->
None
:
if
not
issubclass
(
module
,
ReasoningParser
):
raise
TypeError
(
"module must be subclass of ReasoningParser, "
f
"but got
{
type
(
module
)
}
"
)
...
...
@@ -152,7 +144,8 @@ class ReasoningParserManager:
cls
,
name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
,
force
:
bool
=
True
,
module
:
Union
[
type
,
None
]
=
None
)
->
Union
[
type
,
Callable
]:
module
:
Union
[
type
,
None
]
=
None
,
)
->
Union
[
type
,
Callable
]:
"""
Register module with the given name or name list. It can be used as a
decorator (with module as None) or normal function (with module as not
...
...
vllm/
entrypoints/openai/
reasoning
_parsers
/deepseek_r1_reasoning_parser.py
→
vllm/reasoning/deepseek_r1_reasoning_parser.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
re
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
...
...
@@ -8,9 +7,8 @@ from transformers import PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers
import
(
ReasoningParser
,
ReasoningParserManager
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
logger
=
init_logger
(
__name__
)
...
...
@@ -24,39 +22,38 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
start_token_id
:
int
end_token_id
:
int
start_token
:
str
=
"<think>"
end_token
:
str
=
"</think>"
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
super
().
__init__
(
tokenizer
)
self
.
think_start_token
=
"<think>"
self
.
think_end_token
=
"</think>"
self
.
reasoning_regex
=
re
.
compile
(
rf
"
{
self
.
think_start_token
}
(.*?)
{
self
.
think_end_token
}
"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction."
)
self
.
think_start_token_id
=
self
.
vocab
.
get
(
self
.
think_start_token
)
self
.
think_end_token_id
=
self
.
vocab
.
get
(
self
.
think_end_token
)
if
(
self
.
think_start_token_id
is
None
or
self
.
think_end_token_id
is
None
):
self
.
start_token_id
=
self
.
vocab
.
get
(
self
.
start_token
)
self
.
end_token_id
=
self
.
vocab
.
get
(
self
.
end_token
)
if
self
.
start_token_id
is
None
or
self
.
end_token_id
is
None
:
raise
RuntimeError
(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!"
)
# TODO: need to rebase by PR #14428
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
return
self
.
think_
end_token_id
in
input_ids
return
self
.
end_token_id
in
input_ids
def
extract_content_ids
(
self
,
input_ids
:
list
[
int
])
->
list
[
int
]:
"""
Extract the content after the end tokens
"""
if
self
.
think_
end_token_id
not
in
input_ids
[:
-
1
]:
if
self
.
end_token_id
not
in
input_ids
[:
-
1
]:
return
[]
else
:
return
input_ids
[
input_ids
.
index
(
self
.
think_
end_token_id
)
+
1
:]
return
input_ids
[
input_ids
.
index
(
self
.
end_token_id
)
+
1
:]
def
extract_reasoning_content_streaming
(
self
,
...
...
@@ -77,22 +74,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if
len
(
delta_token_ids
)
==
1
and
(
delta_token_ids
[
0
]
in
[
self
.
think_
start_token_id
,
self
.
think_
end_token_id
self
.
start_token_id
,
self
.
end_token_id
]):
return
None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if
self
.
think_
start_token_id
in
previous_token_ids
:
if
self
.
think_
end_token_id
in
delta_token_ids
:
if
self
.
start_token_id
in
previous_token_ids
:
if
self
.
end_token_id
in
delta_token_ids
:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index
=
delta_text
.
find
(
self
.
think_
end_token
)
end_index
=
delta_text
.
find
(
self
.
end_token
)
reasoning_content
=
delta_text
[:
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
elif
self
.
think_end_token_id
in
previous_token_ids
:
content
=
delta_text
[
end_index
+
len
(
self
.
end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
,
)
elif
self
.
end_token_id
in
previous_token_ids
:
# <think> in previous, </think> in previous,
# reasoning has ended, so the delta is regular content
return
DeltaMessage
(
content
=
delta_text
)
...
...
@@ -100,17 +99,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
elif
self
.
think_
start_token_id
in
delta_token_ids
:
if
self
.
think_
end_token_id
in
delta_token_ids
:
elif
self
.
start_token_id
in
delta_token_ids
:
if
self
.
end_token_id
in
delta_token_ids
:
# <think> in delta, </think> in delta, extract reasoning content
start_index
=
delta_text
.
find
(
self
.
think_
start_token
)
end_index
=
delta_text
.
find
(
self
.
think_
end_token
)
start_index
=
delta_text
.
find
(
self
.
start_token
)
end_index
=
delta_text
.
find
(
self
.
end_token
)
reasoning_content
=
delta_text
[
start_index
+
len
(
self
.
think_start_token
):
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
len
(
self
.
start_token
):
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
,
)
else
:
# <think> in delta, no </think> in delta,
# reasoning content continues
...
...
@@ -119,15 +119,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_
end_token_id
in
delta_token_ids
:
if
self
.
end_token_id
in
delta_token_ids
:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index
=
delta_text
.
find
(
self
.
think_
end_token
)
end_index
=
delta_text
.
find
(
self
.
end_token
)
reasoning_content
=
delta_text
[:
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
elif
self
.
think_end_token_id
in
previous_token_ids
:
content
=
delta_text
[
end_index
+
len
(
self
.
end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
,
)
elif
self
.
end_token_id
in
previous_token_ids
:
# </think> in previous, thinking content ends
return
DeltaMessage
(
content
=
delta_text
)
else
:
...
...
@@ -137,25 +139,34 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
"""
Extract reasoning content from the model output.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
# Check if the start token is present in the model output, remove it
# if it is present.
model_output_parts
=
model_output
.
partition
(
self
.
start_token
)
model_output
=
model_output_parts
[
2
]
if
model_output_parts
[
1
]
else
model_output_parts
[
0
]
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_
end_token
not
in
model_output
:
if
self
.
end_token
not
in
model_output
:
return
model_output
,
None
else
:
# Add a start token if it's missing to keep compatibility.
if
self
.
think_start_token
not
in
model_output
:
model_output
=
f
"
{
self
.
think_start_token
}{
model_output
}
"
# Use a regex to find the reasoning content
reasoning_content
=
self
.
reasoning_regex
.
findall
(
model_output
)[
0
]
end_index
=
len
(
f
"
{
self
.
think_start_token
}{
reasoning_content
}{
self
.
think_end_token
}
"
)
final_output
=
model_output
[
end_index
:]
if
len
(
final_output
)
==
0
:
return
reasoning_content
,
None
return
reasoning_content
,
final_output
reasoning_content
,
_
,
content
=
model_output
.
partition
(
self
.
end_token
)
# If the end token is not found, return the model output as is.
# It should not happen since we already checked for the presence
# of the end token.
# If generation stops right after end-of-think, return null content
final_content
=
content
or
None
return
reasoning_content
,
final_content
vllm/reasoning/granite_reasoning_parser.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
re
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
logger
=
init_logger
(
__name__
)
@
ReasoningParserManager
.
register_module
(
"granite"
)
class
GraniteReasoningParser
(
ReasoningParser
):
"""
Reasoning parser for IBM Granite.
IBM Granite models currently use "Here is my thought process:"
and "Here is my response:" to separate their thinking / response outputs.
"""
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
super
().
__init__
(
tokenizer
)
# NOTE: There have been some observed occurrences of quantized
# instances of the current models using "Here's" instead of "Here is",
# so to be safe, we match on both.
self
.
think_start_expr
=
r
"(?:Here's|Here is) my thought process:"
self
.
response_start_expr
=
r
"(?:Here's|Here is) my response:"
self
.
reasoning_regex
=
re
.
compile
(
rf
"
{
self
.
think_start_expr
}
(.*?)
{
self
.
response_start_expr
}
(.*)"
,
re
.
DOTALL
)
self
.
valid_think_starts
=
[
"Here's my thought process:"
,
"Here is my thought process:"
]
self
.
valid_response_starts
=
[
"Here's my response:"
,
"Here is my response:"
]
# Substrings to match for sequence boundaries on raw text
self
.
seq_boundary_end
=
":"
self
.
seq_boundary_start
=
"Here"
# The longest any thinking / start of response message can be
self
.
longest_think_start
=
max
(
len
(
think_start
)
for
think_start
in
self
.
valid_think_starts
)
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
"""Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionRequest): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
reasoning content and non-reasoning content.
"""
re_match
=
self
.
reasoning_regex
.
findall
(
model_output
)
if
not
re_match
:
return
None
,
model_output
reasoning_content
,
response_content
=
re_match
[
0
]
if
not
response_content
:
return
reasoning_content
,
None
return
reasoning_content
,
response_content
def
extract_reasoning_content_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
Union
[
DeltaMessage
,
None
]:
"""Extract the reasoning content / content emitted by granite models;
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
NOTE: Granite models do not use a special token to start their reasoning
and response sections; instead they have token sequences, e.g.,
Here is my thought process: Foo Here is my response: Bar
This increases the complexity of correctly handling streams, since we
need to watch for specific sequences and correctly parse them without
dropping content that is potentially overlapping & spanning multiple
delta messages.
Args:
previous_text (str): Previous text outside of this delta message.
current_text (str): Previous text + delta text.
delta_text (str): Text to consider and parse content from.
previous_token_ids (Sequence[int]): Token IDs of previous_text.
current_token_ids (Sequence[int]): Token IDs of current_text.
delta_token_ids (Sequence[int]): Token IDs of delta_text.
Returns:
Union[DeltaMessage, None]
DeltaMessage with either reasoning content or content, or None.
"""
reasoning_content
,
resp_seq_len
,
content
=
self
.
_get_content_sections
(
current_text
)
# Either we haven't finished the start of the reasoning sequence,
# or the model is generating something unexpected.
if
not
reasoning_content
:
delta_message
=
self
.
_get_delta_message_with_no_reasoning_bounds
(
current_text
,
delta_text
)
# We have a start of reasoning message, but have not yet finished
# the start of response sequence.
elif
not
content
:
delta_message
=
self
.
_get_delta_message_with_no_response_bounds
(
current_text
,
reasoning_content
,
delta_text
)
# We've finished both the start of reasoning and start of response seq.
else
:
# This should never happen since we matched on the response
assert
resp_seq_len
is
not
None
delta_message
=
self
.
_get_delta_message_with_both_bounds
(
delta_text
,
reasoning_content
,
content
,
current_text
,
resp_seq_len
)
if
not
delta_message
.
content
and
not
delta_message
.
reasoning_content
:
return
None
return
delta_message
#### Implementation details of stream parsing for granite models
def
_is_reasoning_start_substr
(
self
,
text
:
str
)
->
bool
:
"""Check if a text matches one of the possible start reasoning seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible reasoning start seqs match.
"""
return
any
(
think_start
.
startswith
(
text
)
for
think_start
in
self
.
valid_think_starts
)
def
_is_response_start_substr
(
self
,
text
:
str
)
->
bool
:
"""Check if a text matches one of the possible start response seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible response start seqs match.
"""
return
any
(
response_start
.
startswith
(
text
)
for
response_start
in
self
.
valid_response_starts
)
def
_get_delta_message_with_no_reasoning_bounds
(
self
,
current_text
:
str
,
delta_text
:
str
,
)
->
DeltaMessage
:
"""Parse the delta message when the current text has not yet completed
its start of reasoning sequence.
Args:
current_text (str): The full previous + delta text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
prev_longest_length
=
len
(
current_text
)
-
len
(
delta_text
)
is_substr
=
self
.
_is_reasoning_start_substr
(
current_text
)
was_substr
=
self
.
_is_reasoning_start_substr
(
current_text
[:
prev_longest_length
])
# Check if we just generated something NOT in the special token seq;
# if so, add everything that we previously skipped with this delta
# message and append everything to content in the future.
if
was_substr
and
not
is_substr
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
current_text
,
)
if
is_substr
:
# Might still be in the special token sequence; return nothing
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
None
)
# Otherwise the sequence has already been broken and we already
# corrected; just return the delta text as normal content.
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
delta_text
)
def
_get_delta_message_with_no_response_bounds
(
self
,
current_text
:
str
,
reasoning_content
:
str
,
delta_text
:
str
,
)
->
DeltaMessage
:
"""Parse the delta message when the current text has both reasoning
content with no (response) content. NOTE that we may have overlapping
tokens with the start of reasoning / start of response sequences on
either side of the delta text.
Args:
current_text (str): The full previous + delta text.
reasoning_content (str): reasoning content from current_text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# If we have no reasoning content or explicitly end with the start of
# response sequence, we are in transition to the response; need to be
# careful here, since the final token (:) will match the reasoning
# content and fully parse it out; we should not pass the : back.
ends_with_start_response_seq
=
any
(
current_text
.
endswith
(
response_start
)
for
response_start
in
self
.
valid_response_starts
)
if
reasoning_content
is
None
or
ends_with_start_response_seq
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
None
)
# Consider previous / current text only within context of the reasoning
previous_text
=
reasoning_content
[:
-
len
(
delta_text
)]
current_text
=
reasoning_content
# We need to be careful about adding unfinished response sequences;
# Find the place at which we MIGHT be starting a response sequence
prev_idx
=
previous_text
.
rfind
(
self
.
seq_boundary_start
)
delta_idx
=
delta_text
.
rfind
(
self
.
seq_boundary_start
)
# Check the state of potential start of response substring matches.
prev_was_substr
=
self
.
_is_response_start_substr
(
previous_text
[
prev_idx
:])
if
prev_idx
>=
0
else
False
delta_continues_substr
=
self
.
_is_response_start_substr
(
current_text
[
prev_idx
:])
if
prev_idx
>=
0
else
False
delta_new_substr
=
self
.
_is_response_start_substr
(
delta_text
[
delta_idx
:])
if
delta_idx
>=
0
else
False
# Delta only contains potential continued response sequence text.
if
delta_continues_substr
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
None
)
if
not
prev_was_substr
:
# Delta may be starting a new response seq but has other text too.
if
delta_new_substr
:
return
DeltaMessage
(
reasoning_content
=
delta_text
[:
delta_idx
],
content
=
None
)
# Normal case for most reasoning text (no potential special seqs).
return
DeltaMessage
(
reasoning_content
=
delta_text
,
content
=
None
)
# The substring that previously seemed to be a potential response
# seq wasn't one; we need to add the content to the delta message,
# and also slice off the potential response sequence
elif
delta_new_substr
:
reasoning_content
=
previous_text
[
prev_idx
:]
+
delta_text
[:
delta_idx
]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
None
)
# No new substring yet, and we broke our old one; take the whole delta
return
DeltaMessage
(
reasoning_content
=
previous_text
[
prev_idx
:]
+
delta_text
,
content
=
None
,
)
def
_get_delta_message_with_both_bounds
(
self
,
delta_text
:
str
,
reasoning_content
:
str
,
response_content
:
str
,
current_text
:
str
,
response_seq_len
:
int
,
)
->
DeltaMessage
:
"""Parse the delta message when the current text has both reasoning
content and normal (response) content.
Args:
delta_text (str): Text to consider and parse content from.
reasoning_content (str): reasoning content from current_text.
response_content (str): response content from current_text.
current_text (str): The full previous + delta text.
response_seq_len (int): Length of the complete response sequence used.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# Always have content; take length to the end
delta_content
=
delta_text
[
-
len
(
response_content
):]
reasoning_end_idx
=
len
(
delta_text
)
-
(
len
(
response_content
)
+
response_seq_len
)
if
reasoning_end_idx
<
0
:
delta_reasoning_content
=
None
else
:
# Get the starting offset
start_reasoning_content_idx
=
len
(
reasoning_content
)
+
response_seq_len
+
len
(
response_content
)
-
1
delta_offset
=
len
(
current_text
)
-
len
(
delta_text
)
start_offset
=
start_reasoning_content_idx
-
delta_offset
if
start_offset
<
0
:
start_offset
=
0
delta_reasoning_content
=
delta_text
[
start_offset
:
reasoning_end_idx
]
return
DeltaMessage
(
reasoning_content
=
delta_reasoning_content
,
content
=
delta_content
,
)
def
_get_content_sections
(
self
,
current_text
:
str
)
->
tuple
[
Optional
[
str
],
Optional
[
int
],
Optional
[
str
]]:
"""Parse the text to extract the reasoning content / content
if we have them.
Args:
current_text (str): The full previous + delta text.
Returns:
tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
containing the reasoning content, the length of the response seq
(if there is one) and the non-reasoning content.
"""
current_chunk_start
=
0
start_reasoning_content
=
None
parsed_content
=
False
delimiter_idxs
=
[
idx
for
idx
,
char
in
enumerate
(
current_text
)
if
char
==
self
.
seq_boundary_end
]
for
current_chunk_end
in
delimiter_idxs
:
current_chunk
=
current_text
[
current_chunk_start
:
current_chunk_end
]
# Check to see if the start of reasoning seq is complete
if
start_reasoning_content
is
None
:
for
think_start
in
self
.
valid_think_starts
:
if
current_chunk
==
think_start
[:
-
1
]:
start_reasoning_content
=
current_chunk_end
+
1
current_chunk_start
=
current_chunk_end
+
1
break
# Check to see if the start of response seq is complete
elif
not
parsed_content
:
for
response_start
in
self
.
valid_response_starts
:
if
current_chunk
[
-
len
(
response_start
)
+
1
:]
==
response_start
[:
-
1
]:
# Mark end of reasoning and start response content
# after the start of response sequence.
end_reasoning_content
=
current_chunk_end
-
len
(
response_start
)
reasoning_content
=
current_text
[
start_reasoning_content
:
end_reasoning_content
]
response_content
=
current_text
[
current_chunk_end
+
1
:]
return
reasoning_content
,
len
(
response_start
),
response_content
if
start_reasoning_content
and
not
parsed_content
:
return
current_text
[
start_reasoning_content
:],
None
,
None
return
None
,
None
,
None
vllm/transformers_utils/config.py
View file @
fcfc474d
...
...
@@ -37,8 +37,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MLPSpeculatorConfig
,
MPTConfig
,
NemotronConfig
,
NVLM_D_Config
,
Olmo2Config
,
RWConfig
,
S
olarConfig
,
Telechat2
Config
,
UltravoxConfig
)
S
kyworkR1VChatConfig
,
Solar
Config
,
Telechat2Config
,
UltravoxConfig
)
# yapf: enable
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
resolve_obj_by_qualname
...
...
@@ -76,6 +76,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"NVLM_D"
:
NVLM_D_Config
,
"olmo2"
:
Olmo2Config
,
"solar"
:
SolarConfig
,
"skywork_chat"
:
SkyworkR1VChatConfig
,
"telechat"
:
Telechat2Config
,
"ultravox"
:
UltravoxConfig
,
**
_CONFIG_REGISTRY_OVERRIDE_HF
...
...
@@ -261,6 +262,11 @@ def get_config(
MISTRAL_CONFIG_NAME
,
revision
=
revision
):
config_format
=
ConfigFormat
.
MISTRAL
else
:
raise
ValueError
(
"Could not detect config format for no config file found. "
"Ensure your model has either config.json (HF format) "
"or params.json (Mistral format)."
)
except
Exception
as
e
:
error_message
=
(
...
...
@@ -323,7 +329,14 @@ def get_config(
elif
config_format
==
ConfigFormat
.
MISTRAL
:
config
=
load_params_config
(
model
,
revision
,
token
=
HF_TOKEN
,
**
kwargs
)
else
:
raise
ValueError
(
f
"Unsupported config format:
{
config_format
}
"
)
supported_formats
=
[
fmt
.
value
for
fmt
in
ConfigFormat
if
fmt
!=
ConfigFormat
.
AUTO
]
raise
ValueError
(
f
"Unsupported config format:
{
config_format
}
. "
f
"Supported formats are:
{
', '
.
join
(
supported_formats
)
}
. "
f
"Ensure your model uses one of these configuration formats "
f
"or specify the correct format explicitly."
)
# Special architecture mapping check for GGUF models
if
is_gguf
:
...
...
vllm/transformers_utils/configs/__init__.py
View file @
fcfc474d
...
...
@@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from
vllm.transformers_utils.configs.nemotron
import
NemotronConfig
from
vllm.transformers_utils.configs.nvlm_d
import
NVLM_D_Config
from
vllm.transformers_utils.configs.olmo2
import
Olmo2Config
from
vllm.transformers_utils.configs.skyworkr1v
import
SkyworkR1VChatConfig
from
vllm.transformers_utils.configs.solar
import
SolarConfig
from
vllm.transformers_utils.configs.telechat2
import
Telechat2Config
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
...
...
@@ -42,6 +43,7 @@ __all__ = [
"NemotronConfig"
,
"NVLM_D_Config"
,
"Olmo2Config"
,
"SkyworkR1VChatConfig"
,
"SolarConfig"
,
"Telechat2Config"
,
"UltravoxConfig"
,
...
...
vllm/transformers_utils/configs/eagle.py
View file @
fcfc474d
...
...
@@ -5,6 +5,8 @@ from typing import Optional, Union
from
transformers
import
AutoConfig
,
PretrainedConfig
from
vllm.transformers_utils.configs.deepseek_vl2
import
DeepseekV2Config
class
EAGLEConfig
(
PretrainedConfig
):
model_type
=
"eagle"
...
...
@@ -14,8 +16,17 @@ class EAGLEConfig(PretrainedConfig):
truncated_vocab_size
:
Optional
[
int
]
=
None
,
**
kwargs
):
model_config
=
None
if
model
is
None
else
(
AutoConfig
.
for_model
(
**
model
)
if
isinstance
(
model
,
dict
)
else
model
)
model_config
:
Union
[
PretrainedConfig
,
DeepseekV2Config
,
None
]
if
isinstance
(
model
,
dict
):
archs
=
model
.
get
(
"architectures"
,
[])
target_archs
=
[
"DeepseekV2ForCausalLM"
,
"DeepseekV3ForCausalLM"
]
if
any
(
target_arch
in
archs
for
target_arch
in
target_archs
):
# AutoConfig does not support DeepSeek MoE models yet
model_config
=
DeepseekV2Config
(
**
model
)
else
:
model_config
=
AutoConfig
.
for_model
(
**
model
)
else
:
model_config
=
model
for
k
,
v
in
kwargs
.
items
():
if
k
!=
"architectures"
and
k
!=
"model_type"
and
hasattr
(
...
...
vllm/transformers_utils/configs/skyworkr1v.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
transformers.configuration_utils
import
PretrainedConfig
class SkyworkR1VChatConfig(PretrainedConfig):
    """Composite configuration bundling a vision and a text sub-config.

    ``vision_config`` and ``llm_config`` are keyword mappings; each is
    expanded into a plain ``PretrainedConfig`` and stored as
    ``self.vision_config`` and ``self.text_config`` respectively. Every
    other named argument is stored unchanged as an attribute of the same
    name, and any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = 'internvl_chat'
    is_composition = True

    def __init__(self,
                 vision_config=None,
                 llm_config=None,
                 use_backbone_lora=0,
                 use_llm_lora=0,
                 select_layer=-1,
                 force_image_size=None,
                 downsample_ratio=0.5,
                 template=None,
                 dynamic_image_size=False,
                 use_thumbnail=False,
                 ps_version='v1',
                 min_dynamic_patch=1,
                 max_dynamic_patch=6,
                 **kwargs):
        super().__init__(**kwargs)

        # A missing sub-config is treated as an empty set of keyword args.
        vision_kwargs = {} if vision_config is None else vision_config
        llm_kwargs = {} if llm_config is None else llm_config
        self.vision_config = PretrainedConfig(**vision_kwargs)
        self.text_config = PretrainedConfig(**llm_kwargs)

        # LoRA / layer-selection options, stored as given.
        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.select_layer = select_layer

        # Image handling options, stored as given.
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        self.template = template
vllm/transformers_utils/processors/deepseek_vl2.py
View file @
fcfc474d
...
...
@@ -226,7 +226,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
input_ids
[
input_ids
<
0
]
=
self
.
pad_id
if
inference_mode
:
#
去掉结尾的
eos token
#
Remove the ending
eos token
assert
input_ids
[
-
1
]
==
self
.
eos_id
input_ids
=
input_ids
[:
-
1
]
target_ids
=
target_ids
[:
-
1
]
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
fcfc474d
...
...
@@ -124,12 +124,14 @@ def find_tokenizer_file(files: List[str]):
matched_files
=
[
file
for
file
in
files
if
file_pattern
.
match
(
file
)]
if
len
(
matched_files
)
>
1
:
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern:
{
file_pattern
}
. Make sure only one Mistral "
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern: `
{
file_pattern
.
pattern
}
`. Make sure only one Mistral "
f
"tokenizer is present in
{
files
}
."
)
elif
len
(
matched_files
)
==
0
:
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern:
{
file_pattern
}
. Make sure that a Mistral "
raise
OSError
(
f
"Found
{
len
(
matched_files
)
}
files matching the "
f
"pattern: `
{
file_pattern
.
pattern
}
`. Make sure that a Mistral "
f
"tokenizer is present in
{
files
}
."
)
return
matched_files
[
0
]
...
...
vllm/transformers_utils/utils.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
functools
import
cache
from
os
import
PathLike
from
pathlib
import
Path
from
typing
import
List
,
Optional
,
Union
from
vllm.envs
import
VLLM_MODEL_REDIRECT_PATH
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def is_s3(model_or_path: str) -> bool:
    """Return True when *model_or_path* starts with ``s3://`` (any case)."""
    lowered = model_or_path.lower()
    return lowered.startswith('s3://')
...
...
@@ -17,9 +23,14 @@ def check_gguf_file(model: Union[str, PathLike]) -> bool:
elif
model
.
suffix
==
".gguf"
:
return
True
with
open
(
model
,
"rb"
)
as
f
:
try
:
with
model
.
open
(
"rb"
)
as
f
:
header
=
f
.
read
(
4
)
return
header
==
b
"GGUF"
except
Exception
as
e
:
logger
.
debug
(
"Error reading file %s: %s"
,
model
,
e
)
return
False
def
modelscope_list_repo_files
(
...
...
@@ -38,3 +49,35 @@ def modelscope_list_repo_files(
if
file
[
'Type'
]
==
'blob'
]
return
files
@cache
def maybe_model_redirect(model: str) -> str:
    """
    Use model_redirect to redirect the model name to a local folder.

    The file named by ``VLLM_MODEL_REDIRECT_PATH`` is scanned line by line.
    A usable line consists of exactly two tab-separated fields: the model
    name and the redirect target. The first line whose name equals *model*
    wins, and its target is returned with surrounding whitespace stripped.
    Lines with any other number of fields are ignored. Results are
    memoized per *model* by ``functools.cache``.

    :param model: hf model name
    :return: maybe redirect to a local folder
    """
    model_redirect_path = VLLM_MODEL_REDIRECT_PATH

    # No redirect file configured, or the configured file does not exist.
    if not model_redirect_path:
        return model

    if not Path(model_redirect_path).exists():
        return model

    with open(model_redirect_path) as f:
        # Iterate the file lazily instead of loading every line up front.
        for line in f:
            fields = line.split("\t")
            # Skip malformed lines explicitly rather than swallowing every
            # exception raised inside the loop.
            if len(fields) != 2:
                continue
            model_name, redirect_name = fields
            if model == model_name:
                redirect_name = redirect_name.strip()
                logger.info("model redirect: [ %s ] -> [ %s ]", model,
                            redirect_name)
                return redirect_name

    return model
vllm/utils.py
View file @
fcfc474d
...
...
@@ -10,6 +10,7 @@ import datetime
import
enum
import
gc
import
getpass
import
hashlib
import
importlib
import
importlib.metadata
import
importlib.util
...
...
@@ -17,6 +18,7 @@ import inspect
import
ipaddress
import
multiprocessing
import
os
import
pickle
import
re
import
signal
import
socket
...
...
@@ -31,15 +33,17 @@ import uuid
import
warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
OrderedDict
,
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
Mapping
)
Iterable
,
Iterator
,
KeysView
,
Mapping
)
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
types
import
MappingProxyType
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Optional
,
Type
,
TypeVar
,
Union
)
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
overload
)
from
uuid
import
uuid4
import
cachetools
import
cloudpickle
import
numpy
as
np
import
numpy.typing
as
npt
...
...
@@ -57,7 +61,7 @@ import vllm.envs as envs
from
vllm.logger
import
enable_trace_function_call
,
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
logger
=
init_logger
(
__name__
)
import
json
...
...
@@ -172,6 +176,7 @@ U = TypeVar("U")
_K
=
TypeVar
(
"_K"
,
bound
=
Hashable
)
_V
=
TypeVar
(
"_V"
)
_T
=
TypeVar
(
"_T"
)
class
_Sentinel
:
...
...
@@ -205,6 +210,19 @@ class Counter:
self
.
counter
=
0
class _MappingOrderCacheView(UserDict[_K, _V]):
    """Dict-like view whose key order comes from a separate mapping.

    The values are taken from ``data`` (loaded through ``UserDict``),
    while iteration and ``keys()`` follow ``ordered_keys``.
    """

    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K,
                                                                    None]):
        super().__init__(data)
        # Kept by reference; drives __iter__ and keys() below.
        self.ordered_keys = ordered_keys

    def __iter__(self) -> Iterator[_K]:
        """Iterate over the keys in the order given by ``ordered_keys``."""
        order = self.ordered_keys
        return iter(order)

    def keys(self) -> KeysView[_K]:
        """Return a keys view backed by ``ordered_keys``."""
        order = self.ordered_keys
        return KeysView(order)
class
CacheInfo
(
NamedTuple
):
hits
:
int
total
:
int
...
...
@@ -217,45 +235,62 @@ class CacheInfo(NamedTuple):
return
self
.
hits
/
self
.
total
class
LRUCache
(
Generic
[
_K
,
_V
]):
"""Note: This class is not thread safe!"""
class
LRUCache
(
cachetools
.
LRUCache
[
_K
,
_V
],
Generic
[
_K
,
_V
]):
def
__init__
(
self
,
capacity
:
int
)
->
None
:
self
.
cache
=
OrderedDict
[
_K
,
_V
]()
def
__init__
(
self
,
capacity
:
float
,
getsizeof
:
Optional
[
Callable
[[
_V
],
float
]]
=
None
):
super
().
__init__
(
capacity
,
getsizeof
)
self
.
pinned_items
=
set
[
_K
]()
self
.
capacity
=
capacity
self
.
_hits
=
0
self
.
_total
=
0
def
__contains__
(
self
,
key
:
_K
)
->
bool
:
return
key
in
self
.
cache
def
__len__
(
self
)
->
int
:
return
len
(
self
.
cache
)
def
__getitem__
(
self
,
key
:
_K
)
->
_V
:
value
=
self
.
cache
[
key
]
# Raise KeyError if not exists
self
.
cache
.
move_to_end
(
key
)
return
value
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
run_on_remove
=
key
in
self
value
=
self
.
__getitem__
(
key
)
super
().
__delitem__
(
key
)
if
key
in
self
.
pinned_items
:
# Todo: add warning to inform that del pinned item
self
.
_unpin
(
key
)
if
run_on_remove
:
self
.
_on_remove
(
key
,
value
)
def
__setitem__
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
put
(
key
,
value
)
@
property
def
cache
(
self
)
->
Mapping
[
_K
,
_V
]:
"""Return the internal cache dictionary in order (read-only)."""
return
_MappingOrderCacheView
(
self
.
_Cache__data
,
# type: ignore
self
.
order
)
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
self
.
pop
(
key
)
@
property
def
order
(
self
)
->
Mapping
[
_K
,
None
]:
"""Return the internal order dictionary (read-only)."""
return
MappingProxyType
(
self
.
_LRUCache__order
)
# type: ignore
def
stat
(
self
)
->
CacheInfo
:
return
CacheInfo
(
hits
=
self
.
_hits
,
total
=
self
.
_total
)
def
touch
(
self
,
key
:
_K
)
->
None
:
self
.
cache
.
move_to_end
(
key
)
self
.
_LRUCache__update
(
key
)
# type: ignore
def
get
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
value
:
Optional
[
_V
]
if
key
in
self
.
cache
:
value
=
self
.
cache
[
key
]
self
.
cache
.
move_to_end
(
key
)
@
overload
def
get
(
self
,
key
:
_K
,
/
)
->
Optional
[
_V
]:
...
@
overload
def
get
(
self
,
key
:
_K
,
/
,
default
:
Union
[
_V
,
_T
])
->
Union
[
_V
,
_T
]:
...
def
get
(
self
,
key
:
_K
,
/
,
default
:
Optional
[
Union
[
_V
,
_T
]]
=
None
)
->
Optional
[
Union
[
_V
,
_T
]]:
value
:
Optional
[
Union
[
_V
,
_T
]]
if
key
in
self
:
value
=
self
.
__getitem__
(
key
)
self
.
_hits
+=
1
else
:
...
...
@@ -264,60 +299,76 @@ class LRUCache(Generic[_K, _V]):
self
.
_total
+=
1
return
value
@
overload
def
pop
(
self
,
key
:
_K
)
->
_V
:
...
@
overload
def
pop
(
self
,
key
:
_K
,
default
:
Union
[
_V
,
_T
])
->
Union
[
_V
,
_T
]:
...
def
pop
(
self
,
key
:
_K
,
default
:
Optional
[
Union
[
_V
,
_T
]]
=
None
)
->
Optional
[
Union
[
_V
,
_T
]]:
value
:
Optional
[
Union
[
_V
,
_T
]]
if
key
not
in
self
:
return
default
value
=
self
[
key
]
del
self
[
key
]
return
value
def
put
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
self
.
_remove_old_if_needed
()
self
.
__setitem__
(
key
,
value
)
def
pin
(
self
,
key
:
_K
)
->
None
:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
"""
if
key
not
in
self
.
cache
:
if
key
not
in
self
:
raise
ValueError
(
f
"Cannot pin key:
{
key
}
not in cache."
)
self
.
pinned_items
.
add
(
key
)
def
_unpin
(
self
,
key
:
_K
)
->
None
:
"""
Unpins a key in the cache allowing it to be
evicted in the LRU order.
"""
self
.
pinned_items
.
remove
(
key
)
def
_on_remove
(
self
,
key
:
_K
,
value
:
Optional
[
_V
])
->
None
:
pass
def
remove_oldest
(
self
,
*
,
remove_pinned
:
bool
=
False
)
->
None
:
if
not
self
.
cache
:
if
len
(
self
)
==
0
:
return
self
.
popitem
(
remove_pinned
=
remove_pinned
)
def
_remove_old_if_needed
(
self
)
->
None
:
while
self
.
currsize
>
self
.
capacity
:
self
.
remove_oldest
()
def
clear
(
self
)
->
None
:
while
len
(
self
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
def
popitem
(
self
,
remove_pinned
:
bool
=
False
):
"""Remove and return the `(key, value)` pair least recently used."""
if
not
remove_pinned
:
# pop the oldest item in the cache that is not pinned
lru_key
=
next
(
(
key
for
key
in
self
.
cache
if
key
not
in
self
.
pinned_items
),
(
key
for
key
in
self
.
order
if
key
not
in
self
.
pinned_items
),
ALL_PINNED_SENTINEL
)
if
lru_key
is
ALL_PINNED_SENTINEL
:
raise
RuntimeError
(
"All items are pinned, "
"cannot remove oldest from the cache."
)
else
:
lru_key
=
next
(
iter
(
self
.
cache
))
self
.
pop
(
lru_key
)
# type: ignore
def
_remove_old_if_needed
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
self
.
remove_oldest
()
def
pop
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
run_on_remove
=
key
in
self
.
cache
value
=
self
.
cache
.
pop
(
key
,
default
)
# remove from pinned items
if
key
in
self
.
pinned_items
:
self
.
_unpin
(
key
)
if
run_on_remove
:
self
.
_on_remove
(
key
,
value
)
return
value
def
clear
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
self
.
cache
.
clear
()
lru_key
=
next
(
iter
(
self
.
order
))
value
=
self
.
pop
(
cast
(
_K
,
lru_key
))
return
(
lru_key
,
value
)
class
PyObjectCache
:
...
...
@@ -528,7 +579,7 @@ def get_open_port() -> int:
dp_port
=
envs
.
VLLM_DP_MASTER_PORT
while
True
:
port
=
_get_open_port
()
if
port
>=
dp_port
and
port
<
dp_port
+
10
:
if
dp_port
<=
port
<
dp_port
+
10
:
continue
return
port
return
_get_open_port
()
...
...
@@ -745,6 +796,14 @@ def is_pin_memory_available() -> bool:
return
current_platform
.
is_pin_memory_available
()
@cache
def is_uva_available() -> bool:
    """Check if Unified Virtual Addressing (UVA) is available.

    The answer is computed once and memoized by ``functools.cache``.
    """
    # Pinned memory is currently the only prerequisite that is checked.
    # TODO: Add more requirements for UVA if needed.
    pinned_memory_ok = is_pin_memory_available()
    return pinned_memory_ok
class
DeviceMemoryProfiler
:
def
__init__
(
self
,
device
:
Optional
[
torch
.
types
.
Device
]
=
None
):
...
...
@@ -1162,7 +1221,7 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'."
)
class
SortedHelpFormatter
(
argparse
.
HelpFormatter
):
class
SortedHelpFormatter
(
argparse
.
ArgumentDefaults
HelpFormatter
):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def
add_arguments
(
self
,
actions
):
...
...
@@ -1183,6 +1242,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
if
args
is
None
:
args
=
sys
.
argv
[
1
:]
# Check for --model in command line arguments first
if
args
and
args
[
0
]
==
"serve"
:
model_in_cli_args
=
any
(
arg
==
'--model'
for
arg
in
args
)
if
model_in_cli_args
:
raise
ValueError
(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option."
)
if
'--config'
in
args
:
args
=
self
.
_pull_args_from_config
(
args
)
...
...
@@ -1266,19 +1335,29 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
config_args
=
self
.
_load_config_file
(
file_path
)
# 0th index is for {serve,chat,complete}
# followed by model_tag (only for serve)
#
optionally
followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if
args
[
0
]
==
"serve"
:
if
index
==
1
:
model_in_cli
=
len
(
args
)
>
1
and
not
args
[
1
].
startswith
(
'-'
)
model_in_config
=
any
(
arg
==
'--model'
for
arg
in
config_args
)
if
not
model_in_cli
and
not
model_in_config
:
raise
ValueError
(
"No model_tag specified! Please check your command-line"
" arguments."
)
"No model specified! Please specify model either "
"as a positional argument or in a config file."
)
if
model_in_cli
:
# Model specified as positional arg, keep CLI version
args
=
[
args
[
0
]]
+
[
args
[
1
]
]
+
config_args
+
args
[
2
:
index
]
+
args
[
index
+
2
:]
else
:
# No model in CLI, use config if available
args
=
[
args
[
0
]
]
+
config_args
+
args
[
1
:
index
]
+
args
[
index
+
2
:]
else
:
args
=
[
args
[
0
]]
+
config_args
+
args
[
1
:
index
]
+
args
[
index
+
2
:]
...
...
@@ -1296,9 +1375,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension
:
str
=
file_path
.
split
(
'.'
)[
-
1
]
if
extension
not
in
(
'yaml'
,
'yml'
):
raise
ValueError
(
...
...
@@ -1691,18 +1768,21 @@ class ClassRegistry(UserDict[Type[T], _V]):
return
any
(
cls
in
self
.
data
for
cls
in
key
.
mro
())
def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.

    Any argument that is not a ``torch.Tensor`` is handed back unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return torch.ops._C.weak_ref_tensor(tensor)
def
weak_ref_tensors
(
tensors
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
]
]:
)
->
Union
[
torch
.
Tensor
,
list
[
Any
],
tuple
[
Any
],
Any
]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
...
...
@@ -1716,6 +1796,14 @@ def weak_ref_tensors(
raise
ValueError
(
"Invalid type for tensors"
)
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).

    The tensor must report ``is_pinned()``; otherwise the assertion below
    fails before the custom op is looked up.
    """
    pinned = cpu_tensor.is_pinned()
    assert pinned, "CPU tensor must be pinned"
    uva_view_op = torch.ops._C.get_cuda_view_from_cpu_tensor
    return uva_view_op(cpu_tensor)
def
is_in_doc_build
()
->
bool
:
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
...
...
@@ -2247,11 +2335,11 @@ def make_zmq_socket(
if
socket_type
==
zmq
.
constants
.
PULL
:
socket
.
setsockopt
(
zmq
.
constants
.
RCVHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
RCVBUF
,
buf_size
)
socket
.
connect
(
path
)
socket
.
bind
(
path
)
elif
socket_type
==
zmq
.
constants
.
PUSH
:
socket
.
setsockopt
(
zmq
.
constants
.
SNDHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
SNDBUF
,
buf_size
)
socket
.
bind
(
path
)
socket
.
connect
(
path
)
else
:
raise
ValueError
(
f
"Unknown Socket Type:
{
socket_type
}
"
)
...
...
@@ -2259,7 +2347,11 @@ def make_zmq_socket(
@
contextlib
.
contextmanager
def
zmq_socket_ctx
(
path
:
str
,
socket_type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
def
zmq_socket_ctx
(
path
:
str
,
socket_type
:
Any
,
linger
:
int
=
0
,
)
->
Iterator
[
zmq
.
Socket
]:
"""Context manager for a ZMQ socket"""
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
...
...
@@ -2270,7 +2362,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
logger
.
debug
(
"Got Keyboard Interrupt."
)
finally
:
ctx
.
destroy
(
linger
=
0
)
ctx
.
destroy
(
linger
=
linger
)
def
is_in_ray_actor
():
...
...
@@ -2567,3 +2659,33 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
return
wrapper
return
decorator
# Only relevant for models using ALiBi (e.g, MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
    """Report whether the model configuration indicates ALiBi.

    Four indicators are checked in order, and the first truthy one is
    returned as-is; if none is truthy, the result of the last check is
    returned:

    1. ``hf_text_config.alibi`` (Falcon)
    2. ``"BloomForCausalLM"`` listed in ``hf_config.architectures`` (Bloom)
    3. ``hf_text_config.position_encoding_type == "alibi"``
       (codellm_1b_alibi)
    4. ``hf_text_config.attn_config.get("alibi", False)`` (MPT)
    """
    text_config = model_config.hf_text_config

    # Falcon
    alibi_flag = getattr(text_config, "alibi", False)
    if alibi_flag:
        return alibi_flag

    # Bloom
    architectures = getattr(model_config.hf_config, "architectures", [])
    is_bloom = "BloomForCausalLM" in architectures
    if is_bloom:
        return is_bloom

    # codellm_1b_alibi
    alibi_position_encoding = getattr(text_config, "position_encoding_type",
                                      "") == "alibi"
    if alibi_position_encoding:
        return alibi_position_encoding

    # MPT
    return (hasattr(text_config, "attn_config")
            and text_config.attn_config.get("alibi", False))
def sha256(input) -> int:
    """Hash any picklable Python object using SHA-256.

    The object is first serialized with pickle (highest protocol), so any
    picklable value can be hashed. No hash seed is applied; callers that
    need one should prepend it to the input themselves.

    Args:
        input: Any picklable Python object.

    Returns:
        An integer representing the SHA-256 hash of the serialized input.
    """
    serialized = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    # Parsing the hex digest base-16 yields the same integer as reading the
    # raw digest bytes in big-endian order.
    return int(hashlib.sha256(serialized).hexdigest(), 16)
vllm/v1/attention/backends/flash_attn.py
View file @
fcfc474d
...
...
@@ -96,6 +96,183 @@ class FlashAttentionMetadata:
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
# for local attention
@dataclass
class LocalAttentionMetadata:
    """Attention inputs for local attention over "virtual" batches.

    When present, these values are used in place of the regular query
    start locations, KV lengths, maximum lengths and block table.
    """
    # Cumulative query start offsets of the virtual batches.
    local_query_start_loc: torch.Tensor
    # Number of KV tokens used by each virtual batch.
    local_seqused_k: torch.Tensor
    # Block table for the virtual batches.
    local_block_table: torch.Tensor
    # Largest query length among the virtual batches.
    local_max_query_len: int
    # Largest KV length among the virtual batches.
    local_max_seq_len: int
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    """Split each sequence into local-attention blocks ("virtual" batches).

    Args:
        attn_chunk_size: Size of a local attention block, in tokens.
        query_start_loc_np: Cumulative query start offsets; consecutive
            differences give the per-sequence query lengths.
        seq_lens_np: Per-sequence KV lengths; its first dimension is taken
            as the batch size.
        block_table: Block table indexed as ``[batch_index, block_index]``.
        page_size: Number of tokens per block-table page. It must divide
            ``attn_chunk_size``. NOTE(review): the default of 0 cannot
            satisfy that check (it divides by zero), so callers presumably
            always pass the real page size; confirm.

    Returns:
        A 4-tuple ``(seqlens_q_local, cu_seqlens_q_local, seqlens_k_local,
        block_table_local)``: the query length of every virtual batch,
        their zero-prefixed cumulative sum, the KV length of every virtual
        batch, and the block table reshaped to one row per virtual batch.
    """
    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    # The negative modulus maps an exact multiple of attn_chunk_size to a
    # full block (attn_chunk_size) instead of 0.
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First get the batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
        .astype(np.int32)

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \
            np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // page_size
    assert attn_chunk_size % page_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by page_size {page_size}"
    pages_per_local_batch = attn_chunk_size // page_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming page_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch)) \
            + np.expand_dims(block_starts, axis=1)
    block_indices = block_indices.flatten()
    # One original-batch index per gathered page, so the flat gather below
    # pairs every page index with the sequence it belongs to.
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local
class
FlashAttentionMetadataBuilder
:
...
...
@@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
runner
.
device
,
query_start_loc_cpu
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
]
query_start_loc
=
query_start_loc_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
# for local attention
local_attn_metadata
=
None
if
self
.
runner
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table
=
make_local_attention_virtual_batches
(
self
.
runner
.
attention_chunk_size
,
self
.
runner
.
query_start_loc_np
[:
num_reqs
+
1
],
self
.
runner
.
seq_lens_np
[:
num_reqs
],
block_table
,
self
.
runner
.
block_size
,
)
local_attn_metadata
=
FlashAttentionMetadata
.
LocalAttentionMetadata
(
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
),
local_block_table
=
virt_block_table
,
local_max_query_len
=
seqlens_q_local_np
.
max
(),
local_max_seq_len
=
virt_k_seqlens_np
.
max
(),
)
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
...
...
@@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
)
return
attn_metadata
...
...
@@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
...
...
@@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
self
.
use_irope
=
use_irope
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
\
and
not
flash_attn_supports_fp8
():
...
...
@@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_v_scale
,
)
descale_shape
=
(
attn_metadata
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
...
...
@@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
if
not
attn_metadata
.
use_cascade
:
# Regular attention (common case).
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
not
attn_metadata
.
use_cascade
or
use_local_attn
:
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
seqused_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_actual_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
max_query_
len
,
seqused_k
=
attn_metadata
.
seq_lens
,
max_seqlen_k
=
attn_metadata
.
max_seq
_
len
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seq
len
_q
,
seqused_k
=
seqused_k
,
max_seqlen_k
=
max_seqlen
_k
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
...
...
@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
)
return
output
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
cascade_attention
(
output
[:
num_actual_tokens
],
...
...
vllm/v1/attention/backends/pallas.py
View file @
fcfc474d
...
...
@@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK
=
32
NUM_KV_PAGES_PER_BLOCK
=
128
class
PallasAttentionBackend
(
AttentionBackend
):
...
...
@@ -41,7 +37,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
num_kv_heads
*
head_size
)
return
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
@
staticmethod
def
swap_blocks
(
...
...
@@ -92,6 +88,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
sliding_window
=
sliding_window
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -99,15 +97,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise
NotImplementedError
(
"Head size must be a multiple of 128."
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
if
sliding_window
is
not
None
:
raise
NotImplementedError
(
"Sliding window is not supported."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
logits_soft_cap
is
not
None
:
raise
NotImplementedError
(
"Attention logits soft-capping is not supported."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
...
...
@@ -118,13 +111,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
tpu_version
=
torch_xla
.
tpu
.
version
()
if
tpu_version
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
if
tpu_version
==
4
:
self
.
vmem_limit_bytes
=
16
*
1024
*
1024
else
:
self
.
vmem_limit_bytes
=
64
*
1024
*
1024
def
forward
(
self
,
...
...
@@ -132,7 +118,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
PallasMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -142,14 +128,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
[num_blocks, block_size, num_kv_heads * head_size])
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# For determine_available_memory case.
if
kv_cache
[
0
]
.
numel
()
==
0
:
if
kv_cache
.
numel
()
==
0
:
if
output
is
None
:
output
=
torch
.
ones_like
(
query
)
return
output
...
...
@@ -158,24 +143,28 @@ class PallasAttentionBackendImpl(AttentionImpl):
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
key_cache
,
value_cache
=
kv_cache
if
kv_cache
[
0
].
numel
()
>
0
:
if
kv_cache
.
numel
()
>
0
:
slot_mapping
=
attn_metadata
.
slot_mapping
write_to_kv_cache
(
key
,
value
,
k
ey_cache
,
value
_cache
,
slot_mapping
)
write_to_kv_cache
(
key
,
value
,
k
v
_cache
,
slot_mapping
)
output
=
torch
.
ops
.
xla
.
ragged_paged_attention
(
query
,
key_cache
,
value_cache
,
kv_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
attn_metadata
.
query_start_loc
,
attn_metadata
.
num_seqs
,
num_kv_pages_per_block
=
NUM_KV_PAGES_PER_BLOCK
,
num_queries_per_block
=
NUM_QUERIES_PER_BLOCK
,
vmem_limit_bytes
=
self
.
vmem_limit_bytes
,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block
=
None
,
num_queries_per_block
=
None
,
vmem_limit_bytes
=
None
,
use_kernel
=
True
,
sm_scale
=
self
.
scale
)
sm_scale
=
self
.
scale
,
sliding_window
=
self
.
sliding_window
,
soft_cap
=
self
.
logits_soft_cap
,
)
return
output
.
reshape
(
num_tokens
,
hidden_size
)
...
...
@@ -183,8 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
def
write_to_kv_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
""" Write the key and values to the KV cache.
...
...
@@ -192,14 +180,19 @@ def write_to_kv_cache(
Args:
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
k_cache = [num_blocks, block_size, num_kv_heads * head_size]
v_cache = [num_blocks, block_size, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
"""
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
value_cache
,
True
)
_
,
_
,
num_combined_kv_heads
,
head_size
=
kv_cache
.
shape
num_kv_heads
=
num_combined_kv_heads
//
2
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
kv
=
torch
.
cat
([
key
,
value
],
axis
=-
1
).
reshape
(
-
1
,
num_combined_kv_heads
,
head_size
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
kv_cache
,
True
)
key_cache
=
key_cache
.
flatten
(
0
,
1
)
value_cache
=
value_cache
.
flatten
(
0
,
1
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
kv_cache
=
kv_cache
.
flatten
(
0
,
1
)
kv_cache
.
index_copy_
(
0
,
slot_mapping
,
kv
)
vllm/v1/attention/backends/triton_attn.py
View file @
fcfc474d
...
...
@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
use_irope
:
bool
=
False
,
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
...
...
@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
use_irope
=
use_irope
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
@@ -156,19 +158,37 @@ class TritonAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
)
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
if
use_local_attn
:
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
sequesd_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
sequesd_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
chunked_prefill_paged_decode
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
value
=
value
[:
num_actual_tokens
],
output
=
output
[:
num_actual_tokens
],
kv_cache_dtype
=
self
.
kv_cache_dtype
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
block_table
=
attn_metadata
.
block_table
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq_lens
,
max_query_len
=
attn_metadata
.
max_query_len
,
block_table
=
block_table
,
query_start_loc
=
cu_seqlens_q
,
seq_lens
=
sequesd_k
,
max_seq_len
=
max_seqlen_k
,
max_query_len
=
max_seqlen_q
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
alibi_slopes
=
self
.
alibi_slopes
,
...
...
vllm/v1/core/block_pool.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
collections
import
defaultdict
from
collections.abc
import
Iterable
from
typing
import
Optional
from
typing
import
Callable
,
Optional
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
...
...
@@ -27,6 +27,7 @@ class BlockPool:
"""
def
__init__
(
self
,
num_gpu_blocks
:
int
,
enable_caching
:
bool
):
assert
isinstance
(
num_gpu_blocks
,
int
)
and
num_gpu_blocks
>
0
self
.
num_gpu_blocks
=
num_gpu_blocks
self
.
enable_caching
=
enable_caching
# All kv-cache blocks.
...
...
@@ -50,6 +51,11 @@ class BlockPool:
self
.
cached_block_hash_to_block
:
dict
[
BlockHashType
,
dict
[
int
,
KVCacheBlock
]]
=
defaultdict
(
dict
)
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self
.
null_block
=
self
.
free_block_queue
.
popleft
()
def
get_cached_block
(
self
,
block_hash
:
BlockHashType
)
->
Optional
[
KVCacheBlock
]:
"""Get a cached block by the block hash, or None if cache miss.
...
...
@@ -75,6 +81,7 @@ class BlockPool:
num_cached_blocks
:
int
,
num_full_blocks
:
int
,
block_size
:
int
,
hash_fn
:
Callable
,
)
->
None
:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
...
...
@@ -93,6 +100,7 @@ class BlockPool:
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
hash_fn: The hash function to use for block hashes.
"""
if
num_cached_blocks
==
num_full_blocks
:
return
...
...
@@ -138,7 +146,7 @@ class BlockPool:
request
,
start_token_idx
,
end_token_idx
,
-
1
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_hash
=
hash_block_tokens
(
hash_fn
,
prev_block_hash_value
,
block_tokens
,
extra_keys
)
block_hashes
.
append
(
block_hash
)
...
...
@@ -212,7 +220,7 @@ class BlockPool:
for
block
in
blocks
:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if
block
.
ref_cnt
==
0
:
if
block
.
ref_cnt
==
0
and
block
!=
self
.
null_block
:
self
.
free_block_queue
.
remove
(
block
)
block
.
incr_ref
()
...
...
@@ -226,7 +234,8 @@ class BlockPool:
"""
for
block
in
ordered_blocks
:
block
.
decr_ref
()
if
block
.
ref_cnt
==
0
:
# null_block should not be added to the free list.
if
block
.
ref_cnt
==
0
and
block
!=
self
.
null_block
:
self
.
free_block_queue
.
append
(
block
)
def
reset_prefix_cache
(
self
)
->
bool
:
...
...
@@ -239,10 +248,10 @@ class BlockPool:
False otherwise.
"""
num_used_blocks
=
(
self
.
num_gpu_blocks
-
self
.
get_num_free_blocks
())
if
num_used_blocks
>
0
:
if
num_used_blocks
!=
1
:
# The null block is always marked as used
logger
.
warning
(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet"
,
num_used_blocks
)
"blocks (%d) are not freed yet"
,
num_used_blocks
-
1
)
return
False
# Remove all hashes so that no new blocks will hit.
...
...
vllm/v1/core/encoder_cache_manager.py
View file @
fcfc474d
...
...
@@ -3,7 +3,7 @@
from
typing
import
TYPE_CHECKING
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
M
ULTIMODAL_REGISTRY
from
vllm.multimodal
import
M
ultiModalRegistry
from
vllm.v1.request
import
Request
if
TYPE_CHECKING
:
...
...
@@ -67,6 +67,7 @@ class EncoderCacheManager:
def
compute_encoder_budget
(
model_config
:
"ModelConfig"
,
scheduler_config
:
"SchedulerConfig"
,
mm_registry
:
MultiModalRegistry
,
)
->
tuple
[
int
,
int
]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
...
...
@@ -74,6 +75,7 @@ def compute_encoder_budget(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
...
...
@@ -89,7 +91,11 @@ def compute_encoder_budget(
(
encoder_compute_budget
,
encoder_cache_size
,
)
=
_compute_encoder_budget_multimodal
(
model_config
,
scheduler_config
)
)
=
_compute_encoder_budget_multimodal
(
model_config
,
scheduler_config
,
mm_registry
,
)
return
encoder_compute_budget
,
encoder_cache_size
...
...
@@ -97,6 +103,7 @@ def compute_encoder_budget(
def
_compute_encoder_budget_multimodal
(
model_config
:
"ModelConfig"
,
scheduler_config
:
"SchedulerConfig"
,
mm_registry
:
MultiModalRegistry
,
)
->
tuple
[
int
,
int
]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
...
...
@@ -104,6 +111,7 @@ def _compute_encoder_budget_multimodal(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
...
...
@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal(
in the input sequence.
"""
max_tokens_by_modality_dict
=
MULTIMODAL_REGISTRY
.
get_max_tokens_per_item_by_nonzero_modality
(
# noqa: E501
model_config
)
max_tokens_by_modality_dict
=
mm_registry
\
.
get_max_tokens_per_item_by_nonzero_modality
(
model_config
)
if
not
max_tokens_by_modality_dict
:
logger
.
warning
(
...
...
Prev
1
…
19
20
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment