Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a810671a
Commit
a810671a
authored
Jan 08, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori
parents
86b5aefe
6a09612b
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2039 additions
and
323 deletions
+2039
-323
vllm/tool_parsers/glm4_moe_tool_parser.py
vllm/tool_parsers/glm4_moe_tool_parser.py
+2
-1
vllm/tool_parsers/minimax_m2_tool_parser.py
vllm/tool_parsers/minimax_m2_tool_parser.py
+18
-4
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+38
-1
vllm/utils/mem_utils.py
vllm/utils/mem_utils.py
+27
-4
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+75
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+13
-0
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+403
-269
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+27
-0
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+15
-9
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+6
-2
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+36
-10
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+23
-12
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+40
-2
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+2
-1
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+19
-2
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+6
-5
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+16
-1
vllm/v1/metrics/perf.py
vllm/v1/metrics/perf.py
+1244
-0
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+3
-0
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+26
-0
No files found.
vllm/tool_parsers/glm4_moe_tool_parser.py
View file @
a810671a
...
@@ -114,7 +114,8 @@ class Glm4MoeModelToolParser(ToolParser):
...
@@ -114,7 +114,8 @@ class Glm4MoeModelToolParser(ToolParser):
ToolCall
(
ToolCall
(
type
=
"function"
,
type
=
"function"
,
function
=
FunctionCall
(
function
=
FunctionCall
(
name
=
tc_name
,
arguments
=
json
.
dumps
(
arg_dct
)
name
=
tc_name
,
arguments
=
json
.
dumps
(
arg_dct
,
ensure_ascii
=
False
),
),
),
)
)
)
)
...
...
vllm/tool_parsers/minimax_m2_tool_parser.py
View file @
a810671a
...
@@ -122,6 +122,8 @@ class MinimaxM2ToolParser(ToolParser):
...
@@ -122,6 +122,8 @@ class MinimaxM2ToolParser(ToolParser):
self
.
streaming_request
=
None
self
.
streaming_request
=
None
# Clear previous tool call history to avoid state pollution
# Clear previous tool call history to avoid state pollution
self
.
prev_tool_call_arr
.
clear
()
self
.
prev_tool_call_arr
.
clear
()
# Reset streamed args tracking
self
.
streamed_args_for_tool
.
clear
()
def
_extract_name
(
self
,
name_str
:
str
)
->
str
:
def
_extract_name
(
self
,
name_str
:
str
)
->
str
:
"""Extract name from quoted string."""
"""Extract name from quoted string."""
...
@@ -421,9 +423,12 @@ class MinimaxM2ToolParser(ToolParser):
...
@@ -421,9 +423,12 @@ class MinimaxM2ToolParser(ToolParser):
self
.
prev_tool_call_arr
.
append
(
self
.
prev_tool_call_arr
.
append
(
{
{
"name"
:
self
.
current_function_name
,
"name"
:
self
.
current_function_name
,
"arguments"
:
"
{}
"
,
# Placeholder, will be updated later
"arguments"
:
{},
# Placeholder, will be updated later
}
}
)
)
# Initialize streamed_args_for_tool for this tool call
if
len
(
self
.
streamed_args_for_tool
)
<=
self
.
current_tool_index
:
self
.
streamed_args_for_tool
.
append
(
""
)
# Send header with function info
# Send header with function info
return
DeltaMessage
(
return
DeltaMessage
(
...
@@ -445,6 +450,9 @@ class MinimaxM2ToolParser(ToolParser):
...
@@ -445,6 +450,9 @@ class MinimaxM2ToolParser(ToolParser):
# Send opening brace if not sent yet
# Send opening brace if not sent yet
if
self
.
in_function
and
not
self
.
json_started
:
if
self
.
in_function
and
not
self
.
json_started
:
self
.
json_started
=
True
self
.
json_started
=
True
# Update streamed_args_for_tool for opening brace
if
self
.
current_tool_index
<
len
(
self
.
streamed_args_for_tool
):
self
.
streamed_args_for_tool
[
self
.
current_tool_index
]
+=
"{"
return
DeltaMessage
(
return
DeltaMessage
(
tool_calls
=
[
tool_calls
=
[
DeltaToolCall
(
DeltaToolCall
(
...
@@ -493,7 +501,7 @@ class MinimaxM2ToolParser(ToolParser):
...
@@ -493,7 +501,7 @@ class MinimaxM2ToolParser(ToolParser):
args
=
parsed_tool
.
function
.
arguments
args
=
parsed_tool
.
function
.
arguments
self
.
prev_tool_call_arr
[
self
.
current_tool_index
][
self
.
prev_tool_call_arr
[
self
.
current_tool_index
][
"arguments"
"arguments"
]
=
args
]
=
json
.
loads
(
args
)
except
Exception
:
except
Exception
:
pass
# Ignore parsing errors during streaming
pass
# Ignore parsing errors during streaming
...
@@ -505,7 +513,9 @@ class MinimaxM2ToolParser(ToolParser):
...
@@ -505,7 +513,9 @@ class MinimaxM2ToolParser(ToolParser):
)
)
]
]
)
)
# Update streamed_args_for_tool for closing brace
if
self
.
current_tool_index
<
len
(
self
.
streamed_args_for_tool
):
self
.
streamed_args_for_tool
[
self
.
current_tool_index
]
+=
"}"
# Reset state for next tool
# Reset state for next tool
self
.
json_closed
=
True
self
.
json_closed
=
True
self
.
in_function
=
False
self
.
in_function
=
False
...
@@ -630,7 +640,11 @@ class MinimaxM2ToolParser(ToolParser):
...
@@ -630,7 +640,11 @@ class MinimaxM2ToolParser(ToolParser):
)
)
self
.
param_count
+=
1
self
.
param_count
+=
1
# Update streamed_args_for_tool for this tool call
if
self
.
current_tool_index
<
len
(
self
.
streamed_args_for_tool
):
self
.
streamed_args_for_tool
[
self
.
current_tool_index
]
+=
(
json_fragment
)
return
DeltaMessage
(
return
DeltaMessage
(
tool_calls
=
[
tool_calls
=
[
DeltaToolCall
(
DeltaToolCall
(
...
...
vllm/utils/flashinfer.py
View file @
a810671a
...
@@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
...
@@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
)
)
@
functools
.
cache
def
has_flashinfer_trtllm_fused_moe
()
->
bool
:
"""Return `True` if FlashInfer TRTLLM fused MoE is available."""
if
not
has_flashinfer_moe
():
return
False
required_functions
=
[
(
"flashinfer.fused_moe"
,
"trtllm_fp8_block_scale_moe"
),
(
"flashinfer.fused_moe"
,
"trtllm_fp8_per_tensor_scale_moe"
),
(
"flashinfer.fused_moe"
,
"trtllm_fp4_block_scale_moe"
),
]
for
module_name
,
attr_name
in
required_functions
:
mod
=
_get_submodule
(
module_name
)
if
not
mod
or
not
hasattr
(
mod
,
attr_name
):
return
False
return
True
@
functools
.
cache
@
functools
.
cache
def
has_flashinfer_cutlass_fused_moe
()
->
bool
:
def
has_flashinfer_cutlass_fused_moe
()
->
bool
:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
...
@@ -288,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
...
@@ -288,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
if
force_use_trtllm_attention
()
is
False
:
if
force_use_trtllm_attention
()
is
False
:
return
False
return
False
has_trtllm
=
supports_trtllm_attention
()
has_trtllm
=
supports_trtllm_attention
()
return
has_trtllm
and
(
num_qo_heads
%
num_kv_heads
==
0
)
# num_kv_heads=1 is not supported due to TMA descriptor building limitations.
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if
has_trtllm
and
num_kv_heads
==
1
:
logger
.
warning_once
(
"TRTLLM attention does not support num_kv_heads=1. "
"This configuration causes TMA descriptor building to fail due to "
"degenerate tensor strides. Falling back to FlashInfer attention."
)
return
has_trtllm
and
(
num_qo_heads
%
num_kv_heads
==
0
)
and
(
num_kv_heads
!=
1
)
def
use_trtllm_attention
(
def
use_trtllm_attention
(
...
@@ -338,6 +366,15 @@ def use_trtllm_attention(
...
@@ -338,6 +366,15 @@ def use_trtllm_attention(
)
)
return
False
return
False
# num_kv_heads=1 is not supported
if
num_kv_heads
==
1
:
if
force_use_trtllm
:
logger
.
warning_once
(
"TRTLLM attention does not support num_kv_heads=1, "
"but --attention-config.use_trtllm_attention is set to 1"
)
return
False
if
has_spec
and
not
is_prefill
:
if
has_spec
and
not
is_prefill
:
# Speculative decoding requires TRTLLM attention for decodes
# Speculative decoding requires TRTLLM attention for decodes
logger
.
info_once
(
"Using TRTLLM attention (enabled for speculative decoding)."
)
logger
.
info_once
(
"Using TRTLLM attention (enabled for speculative decoding)."
)
...
...
vllm/utils/mem_utils.py
View file @
a810671a
...
@@ -66,27 +66,43 @@ class MemorySnapshot:
...
@@ -66,27 +66,43 @@ class MemorySnapshot:
torch_memory
:
int
=
0
torch_memory
:
int
=
0
non_torch_memory
:
int
=
0
non_torch_memory
:
int
=
0
timestamp
:
float
=
0.0
timestamp
:
float
=
0.0
device
:
torch
.
types
.
Device
=
None
auto_measure
:
bool
=
True
auto_measure
:
bool
=
True
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
if
self
.
device
is
None
:
from
vllm.platforms
import
current_platform
device_fn
=
current_platform
.
current_device
assert
device_fn
is
not
None
self
.
device_
=
torch
.
device
(
device_fn
())
else
:
self
.
device_
=
torch
.
device
(
self
.
device
)
if
self
.
auto_measure
:
if
self
.
auto_measure
:
self
.
measure
()
self
.
measure
()
def
measure
(
self
)
->
None
:
def
measure
(
self
)
->
None
:
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
device
=
self
.
device_
# we measure the torch peak memory usage via allocated_bytes,
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
# when we call `torch.cuda.empty_cache()` or OOM happens.
self
.
torch_peak
=
torch
.
cuda
.
memory_stats
().
get
(
"allocated_bytes.all.peak"
,
0
)
self
.
torch_peak
=
torch
.
cuda
.
memory_stats
(
device
).
get
(
"allocated_bytes.all.peak"
,
0
)
self
.
free_memory
,
self
.
total_memory
=
torch
.
cuda
.
mem_get_info
()
self
.
free_memory
,
self
.
total_memory
=
torch
.
cuda
.
mem_get_info
(
device
)
shared_sysmem_device_mem_sms
=
((
8
,
7
),
(
11
,
0
),
(
12
,
1
))
# Orin, Thor, Spark
shared_sysmem_device_mem_sms
=
((
8
,
7
),
(
11
,
0
),
(
12
,
1
))
# Orin, Thor, Spark
if
(
if
(
current_platform
.
is_cuda
()
current_platform
.
is_cuda
()
and
current_platform
.
get_device_capability
()
in
shared_sysmem_device_mem_sms
and
current_platform
.
get_device_capability
(
device
.
index
)
in
shared_sysmem_device_mem_sms
):
):
# On UMA (Orin, Thor and Spark) platform,
# On UMA (Orin, Thor and Spark) platform,
# where both CPU and GPU rely on system memory,
# where both CPU and GPU rely on system memory,
...
@@ -106,12 +122,18 @@ class MemorySnapshot:
...
@@ -106,12 +122,18 @@ class MemorySnapshot:
# torch.cuda.memory_reserved() is how many bytes
# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage
# this is used to measure the non-torch memory usage
self
.
torch_memory
=
torch
.
cuda
.
memory_reserved
()
self
.
torch_memory
=
torch
.
cuda
.
memory_reserved
(
device
)
self
.
non_torch_memory
=
self
.
cuda_memory
-
self
.
torch_memory
self
.
non_torch_memory
=
self
.
cuda_memory
-
self
.
torch_memory
self
.
timestamp
=
time
.
time
()
self
.
timestamp
=
time
.
time
()
def
__sub__
(
self
,
other
:
"MemorySnapshot"
)
->
"MemorySnapshot"
:
def
__sub__
(
self
,
other
:
"MemorySnapshot"
)
->
"MemorySnapshot"
:
if
self
.
device_
!=
other
.
device_
:
raise
ValueError
(
"The two snapshots should be from the same device! "
f
"Found:
{
self
.
device_
}
vs.
{
other
.
device_
}
"
)
return
MemorySnapshot
(
return
MemorySnapshot
(
torch_peak
=
self
.
torch_peak
-
other
.
torch_peak
,
torch_peak
=
self
.
torch_peak
-
other
.
torch_peak
,
free_memory
=
self
.
free_memory
-
other
.
free_memory
,
free_memory
=
self
.
free_memory
-
other
.
free_memory
,
...
@@ -120,6 +142,7 @@ class MemorySnapshot:
...
@@ -120,6 +142,7 @@ class MemorySnapshot:
torch_memory
=
self
.
torch_memory
-
other
.
torch_memory
,
torch_memory
=
self
.
torch_memory
-
other
.
torch_memory
,
non_torch_memory
=
self
.
non_torch_memory
-
other
.
non_torch_memory
,
non_torch_memory
=
self
.
non_torch_memory
-
other
.
non_torch_memory
,
timestamp
=
self
.
timestamp
-
other
.
timestamp
,
timestamp
=
self
.
timestamp
-
other
.
timestamp
,
device
=
self
.
device_
,
auto_measure
=
False
,
auto_measure
=
False
,
)
)
...
...
vllm/utils/torch_utils.py
View file @
a810671a
...
@@ -24,6 +24,10 @@ else:
...
@@ -24,6 +24,10 @@ else:
ModelConfig
=
object
ModelConfig
=
object
IntermediateTensors
=
object
IntermediateTensors
=
object
import
logging
logger
=
logging
.
getLogger
(
__name__
)
STR_DTYPE_TO_TORCH_DTYPE
=
{
STR_DTYPE_TO_TORCH_DTYPE
=
{
"float32"
:
torch
.
float32
,
"float32"
:
torch
.
float32
,
...
@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
...
@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
}
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
=
{
# TODO: Add more modelopt kv cache dtype
# mappings here when it supported by some attention backend
# (for example supports nvfp4).
"fp8"
:
"fp8_e4m3"
,
}
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
...
@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
...
@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
return
torch_dtype
return
torch_dtype
def
get_kv_cache_quant_algo_string
(
quant_cfg
:
dict
[
str
,
Any
])
->
str
|
None
:
"""Get the KV cache quantization algorithm string from the quantization config.
Maps various FP8 format names to vLLM's standard cache dtype strings.
Returns None if no kv_cache_quant_algo is specified.
Returns "auto" if the value is not recognized/supported.
"""
# Mapping from model config values to vLLM cache_dtype strings
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
)
if
quant_method
.
startswith
(
"modelopt"
):
quantization_inner
=
quant_cfg
.
get
(
"quantization"
,
quant_cfg
)
# Check if quant config is specified and use kv cache quant algo
kv_algo
=
quantization_inner
.
get
(
"kv_cache_quant_algo"
)
or
quant_cfg
.
get
(
"kv_cache_quant_algo"
)
if
isinstance
(
kv_algo
,
str
):
kv_algo_lower
=
kv_algo
.
lower
()
# Try to map to vLLM's standard format
if
kv_algo_lower
in
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
:
return
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
[
kv_algo_lower
]
else
:
# Unknown/unsupported format - return "auto" as safe fallback
logger
.
warning
(
"WARNING: Unknown kv_cache_quant_algo '%s' in model "
"config. Supported values: %s. Falling back to 'auto'."
,
kv_algo
,
list
(
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
.
keys
()),
)
return
"auto"
return
None
def
get_kv_cache_quant_algo_dtype
(
quant_cfg
:
dict
[
str
,
Any
])
->
torch
.
dtype
|
None
:
"""Get the KV cache quantization algorithm dtype from the quantization config."""
kv_algo_str
=
get_kv_cache_quant_algo_string
(
quant_cfg
)
if
kv_algo_str
is
not
None
and
kv_algo_str
!=
"auto"
:
# Only convert if we have a valid dtype string (not "auto" fallback)
return
STR_DTYPE_TO_TORCH_DTYPE
[
kv_algo_str
]
return
None
def
resolve_kv_cache_dtype_string
(
kv_cache_dtype
:
str
,
model_config
:
ModelConfig
)
->
str
:
"""Resolve 'auto' kv_cache_dtype to the actual string value from model config.
Returns the resolved cache_dtype string.
"""
if
kv_cache_dtype
!=
"auto"
:
return
kv_cache_dtype
hf_cfg
=
getattr
(
model_config
,
"hf_config"
,
None
)
if
hf_cfg
is
not
None
:
quant_cfg
=
getattr
(
hf_cfg
,
"quantization_config"
,
None
)
if
quant_cfg
is
not
None
:
kv_algo_str
=
get_kv_cache_quant_algo_string
(
quant_cfg
)
if
kv_algo_str
is
not
None
:
return
kv_algo_str
# Default to auto (will be handled by downstream code)
return
"auto"
def
kv_cache_dtype_str_to_dtype
(
def
kv_cache_dtype_str_to_dtype
(
kv_cache_dtype
:
str
,
model_config
:
ModelConfig
kv_cache_dtype
:
str
,
model_config
:
ModelConfig
)
->
torch
.
dtype
:
)
->
torch
.
dtype
:
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
a810671a
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
import
copy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
ClassVar
from
typing
import
ClassVar
...
@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
...
@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
if
get_flash_attn_version
()
==
3
if
get_flash_attn_version
()
==
3
else
AttentionCGSupport
.
UNIFORM_BATCH
else
AttentionCGSupport
.
UNIFORM_BATCH
)
)
supports_update_block_table
:
bool
=
True
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
...
@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
)
)
return
attn_metadata
return
attn_metadata
def
update_block_table
(
self
,
metadata
:
FlashAttentionMetadata
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
FlashAttentionMetadata
:
new_metadata
=
copy
.
copy
(
metadata
)
new_metadata
.
block_table
=
blk_table
new_metadata
.
slot_mapping
=
slot_mapping
return
new_metadata
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
use_cascade_attention
(
*
args
,
**
kwargs
)
return
use_cascade_attention
(
*
args
,
**
kwargs
)
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
a810671a
...
@@ -16,6 +16,7 @@ from flashinfer import (
...
@@ -16,6 +16,7 @@ from flashinfer import (
from
flashinfer.decode
import
_get_range_buf
,
trtllm_batch_decode_with_kv_cache
from
flashinfer.decode
import
_get_range_buf
,
trtllm_batch_decode_with_kv_cache
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
from
flashinfer.utils
import
FP4Tensor
from
flashinfer.utils
import
FP4Tensor
from
typing_extensions
import
override
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
from
vllm.attention.backends.abstract
import
(
...
@@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
...
@@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills
,
split_decodes_and_prefills
,
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.utils
import
CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
...
@@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
...
@@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
paged_kv_indptr_cpu
:
torch
.
Tensor
,
paged_kv_indptr_cpu
:
torch
.
Tensor
,
paged_kv_indices
:
torch
.
Tensor
,
paged_kv_indices
:
torch
.
Tensor
,
paged_kv_last_page_len_cpu
:
torch
.
Tensor
,
paged_kv_last_page_len_cpu
:
torch
.
Tensor
,
prefill_start
:
int
,
page_size
:
int
,
page_size
:
int
,
num_qo_heads
:
int
,
num_qo_heads
:
int
,
dcp_world_size
:
int
,
dcp_world_size
:
int
,
...
@@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
...
@@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
qo_indptr_cpu
,
qo_indptr_cpu
,
paged_kv_indptr_cpu
,
paged_kv_indptr_cpu
,
paged_kv_indices
,
paged_kv_indices
,
paged_kv_last_page_len_cpu
[
prefill_start
:]
,
paged_kv_last_page_len_cpu
,
num_qo_heads
*
dcp_world_size
,
num_qo_heads
*
dcp_world_size
,
num_kv_heads
,
num_kv_heads
,
head_dim
,
head_dim
,
...
@@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
...
@@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
@
dataclass
@
dataclass
class
F
lashInferMetadata
:
class
F
IPrefill
:
num_actual_tokens
:
int
# Number of tokens excluding padding.
"""Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""
# The data type of the query
wrapper
:
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
q_data_type
:
torch
.
dtype
slot_mapping
:
torch
.
Tensor
# For flashinfer trtllm batch decode
@
dataclass
class
FIDecode
:
"""Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""
wrapper
:
BatchDecodeWithPagedKVCacheWrapper
@
dataclass
class
TRTLLMPrefill
:
"""Metadata for the TRTLLM prefill pathway."""
block_tables
:
torch
.
Tensor
"""
The slice of the block table tensor corresponding *only* to prefill requests.
Shape: [num_prefills, max_num_blocks_per_seq]
"""
seq_lens
:
torch
.
Tensor
"""
The slice of the sequence lengths tensor corresponding *only* to prefill requests.
Shape: [num_prefills]
"""
cum_seq_lens_q
:
torch
.
Tensor
cum_seq_lens_kv
:
torch
.
Tensor
max_q_len
:
int
max_q_len
:
int
max_q_len_prefill
:
int
"""
The maximum query length *among prefill requests*.
"""
max_seq_len
:
int
max_seq_len
:
int
"""The maximum sequence length for KV Cache."""
@
dataclass
class
TRTLLMDecode
:
"""Metadata for the TRTLLM decode pathway."""
block_tables
:
torch
.
Tensor
"""
The slice of the block table tensor corresponding *only* to decode requests.
Shape: [num_decodes, max_num_blocks_per_seq]
"""
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
block_table_tensor
:
torch
.
Tensor
"""
prefill_use_trtllm
:
bool
The slice of the sequence lengths tensor corresponding *only* to decode requests.
decode_use_trtllm
:
bool
Shape: [num_decodes]
"""
max_seq_len
:
int
"""The maximum sequence length for KV Cache."""
@
dataclass
class
FlashInferMetadata
:
num_actual_tokens
:
int
"""Total number of tokens in the batch (excluding padding)."""
slot_mapping
:
torch
.
Tensor
"""Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""
q_data_type
:
torch
.
dtype
# For handling prefill decode split
num_decodes
:
int
num_decodes
:
int
num_decode_tokens
:
int
num_decode_tokens
:
int
num_prefills
:
int
num_prefills
:
int
num_prefill_tokens
:
int
num_prefill_tokens
:
int
# For cascade attention (CPU for planning).
prefill
:
FIPrefill
|
TRTLLMPrefill
|
None
use_cascade
:
bool
"""
Holds the metadata for the prefill portion of the batch.
Will be `None` if `num_prefill_tokens == 0`.
"""
decode
:
FIDecode
|
TRTLLMDecode
|
None
"""
Holds the metadata for the decode portion of the batch.
Will be `None` if `num_decode_tokens == 0`.
"""
prefill_wrapper
:
(
# --- Special Case: Cascade Attention ---
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
|
None
)
=
None
decode_wrapper
:
BatchDecodeWithPagedKVCacheWrapper
|
None
=
None
cascade_wrapper
:
MultiLevelCascadeAttentionWrapper
|
None
=
None
qo_indptr_gpu
:
torch
.
Tensor
|
None
=
None
use_cascade
:
bool
paged_kv_indptr_gpu
:
torch
.
Tensor
|
None
=
None
"""
If True, the entire batch is a cascade attention call, and the
`prefill` and `decode` fields will both be None.
"""
cascade_wrapper
:
MultiLevelCascadeAttentionWrapper
|
None
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
...
@@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
dcp_world_size
=
1
self
.
dcp_world_size
=
1
self
.
dcp_rank
=
0
self
.
dcp_rank
=
0
self
.
dcp_kv_cache_interleave_size
=
1
self
.
dcp_kv_cache_interleave_size
=
1
self
.
use_dcp
=
self
.
dcp_world_size
>
1
self
.
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
self
.
vllm_config
.
parallel_config
...
@@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"sinks, please use trtllm on blackwell or flash attention on "
"sinks, please use trtllm on blackwell or flash attention on "
"earlier GPUs."
"earlier GPUs."
)
)
# Preparing persistent buffers (device-side)
# Preparing persistent buffers
self
.
paged_kv_indptr
=
torch
.
zeros
(
self
.
pin_memory
=
is_pin_memory_available
()
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
self
.
paged_kv_indptr
=
self
.
_make_buffer
(
max_num_reqs
+
1
)
)
self
.
paged_kv_indptr_cpu_buffer
=
torch
.
zeros_like
(
self
.
paged_kv_indices
=
torch
.
zeros
(
self
.
paged_kv_indptr
.
cpu
,
pin_memory
=
self
.
pin_memory
max_num_pages
,
# max num pages possible
)
# Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
dtype
=
torch
.
int32
,
self
.
paged_kv_indices
=
self
.
_make_buffer
(
max_num_pages
)
device
=
self
.
device
,
self
.
paged_kv_last_page_len
=
self
.
_make_buffer
(
max_num_reqs
)
)
self
.
paged_kv_last_page_len
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# host-side buffer
pin_memory
=
is_pin_memory_available
()
self
.
paged_kv_indptr_cpu
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_indptr_np
=
self
.
paged_kv_indptr_cpu
.
numpy
()
self
.
paged_kv_indptr_buffer
=
torch
.
zeros_like
(
self
.
paged_kv_indptr_cpu
,
pin_memory
=
pin_memory
)
self
.
paged_kv_indices_cpu
=
torch
.
zeros
(
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_last_page_len_cpu
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_last_page_len_np
=
self
.
paged_kv_last_page_len_cpu
.
numpy
()
if
self
.
head_dim
==
256
and
current_platform
.
is_device_capability_family
(
100
):
if
self
.
head_dim
==
256
and
current_platform
.
is_device_capability_family
(
100
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
...
@@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"passing --block-size 32 or --block-size 64."
"passing --block-size 32 or --block-size 64."
)
)
def
_make_buffer
(
self
,
*
size
:
int
|
torch
.
SymInt
,
dtype
:
torch
.
dtype
=
torch
.
int32
)
->
CpuGpuBuffer
:
return
CpuGpuBuffer
(
*
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
with_numpy
=
True
,
)
@
override
# type: ignore[misc]
@
classmethod
@
classmethod
def
get_cudagraph_support
(
def
get_cudagraph_support
(
cls
:
type
[
"FlashInferMetadataBuilder"
],
cls
:
type
[
"FlashInferMetadataBuilder"
],
...
@@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
,
self
,
)
->
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
:
)
->
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
:
if
self
.
_prefill_wrapper
is
None
:
if
self
.
_prefill_wrapper
is
None
:
if
self
.
dcp_world_size
>
1
:
if
self
.
use_dcp
:
self
.
_prefill_wrapper
=
BatchDCPPrefillWrapper
(
self
.
_prefill_wrapper
=
BatchDCPPrefillWrapper
(
workspace_buffer
=
self
.
_get_workspace_buffer
(),
workspace_buffer
=
self
.
_get_workspace_buffer
(),
)
)
...
@@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
decode_wrapper
is
None
:
if
decode_wrapper
is
None
:
if
use_cudagraph
:
if
use_cudagraph
:
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
batch_size
+
1
]
paged_kv_indptr
=
self
.
paged_kv_indptr
.
gpu
[:
batch_size
+
1
]
paged_kv_indices
=
self
.
paged_kv_indices
paged_kv_indices
=
self
.
paged_kv_indices
.
gpu
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
batch_size
]
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
gpu
[:
batch_size
]
else
:
else
:
paged_kv_indptr
=
None
paged_kv_indptr
=
None
paged_kv_indices
=
None
paged_kv_indices
=
None
...
@@ -661,99 +718,43 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -661,99 +718,43 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
)
return
self
.
_cascade_wrapper
return
self
.
_cascade_wrapper
def
build
(
def
_compute_flashinfer_kv_metadata
(
self
,
self
,
common_prefix_len
:
int
,
num_blocks_np
:
np
.
ndarray
,
common_attn_metadata
:
CommonAttentionMetadata
,
seq_lens_np
:
np
.
ndarray
,
fast_build
:
bool
=
False
,
block_table_tensor
:
torch
.
Tensor
,
)
->
FlashInferMetadata
:
num_reqs
:
int
,
num_reqs
=
common_attn_metadata
.
num_reqs
page_size
:
int
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
)
->
torch
.
Tensor
:
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
"""
split_decodes_and_prefills
(
Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
common_attn_metadata
,
attention.
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
True
,
)
)
page_size
=
self
.
page_size
max_q_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
common_attn_metadata
.
max_seq_len
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
qo_indptr_cpu
=
common_attn_metadata
.
query_start_loc_cpu
if
self
.
dcp_world_size
>
1
:
if
num_prefills
>
0
:
qo_indptr_prefill_cpu
=
(
qo_indptr_cpu
[
num_decodes
:]
-
qo_indptr_cpu
[
num_decodes
]
)
query_lens_prefill_cpu
=
(
qo_indptr_prefill_cpu
[
1
:]
-
qo_indptr_prefill_cpu
[:
-
1
]
)
seq_lens_cpu
[
num_decodes
:]
=
(
seq_lens_cpu
[
num_decodes
:]
-
query_lens_prefill_cpu
)
seq_lens_cpu
=
get_dcp_local_seq_lens
(
seq_lens_cpu
,
self
.
dcp_world_size
,
self
.
dcp_rank
,
self
.
dcp_kv_cache_interleave_size
,
)
seq_lens_np
=
seq_lens_cpu
.
numpy
()
num_blocks_np
=
(
seq_lens_np
+
(
page_size
-
1
))
//
page_size
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# Grab the blocks of the shared prefix from the first request.
assert
common_prefix_len
%
page_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_common_kv_blocks
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indices_cpu
=
block_table_tensor
[
0
,
:
num_common_kv_blocks
]
shared_kv_last_page_len_cpu
=
torch
.
tensor
(
[
page_size
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
# Remove the blocks of the shared prefix from all requests.
Results are stored in self.paged_kv_indptr,
block_table_tensor
=
block_table_tensor
[:,
num_common_kv_blocks
:]
self.paged_kv_indices, self.paged_kv_last_page_len buffers.
num_blocks_np
-=
num_common_kv_blocks
else
:
shared_qo_indptr_cpu
=
None
shared_kv_page_indptr_cpu
=
None
shared_kv_page_indices_cpu
=
None
shared_kv_last_page_len_cpu
=
None
Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
"""
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np
.
cumsum
(
np
.
cumsum
(
num_blocks_np
,
num_blocks_np
,
dtype
=
np
.
int32
,
dtype
=
np
.
int32
,
out
=
self
.
paged_kv_indptr
_
np
[
1
:
num_reqs
+
1
],
out
=
self
.
paged_kv_indptr
.
np
[
1
:
num_reqs
+
1
],
)
)
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition.
# self.paged_kv_indptr_buffer to avoid race condition.
self
.
paged_kv_indptr_buffer
[:
num_reqs
+
1
]
=
self
.
paged_kv_indptr
_
cpu
[
self
.
paged_kv_indptr_
cpu_
buffer
[:
num_reqs
+
1
]
=
self
.
paged_kv_indptr
.
cpu
[
:
num_reqs
+
1
:
num_reqs
+
1
]
]
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
num_reqs
+
1
]
paged_kv_indptr
=
self
.
paged_kv_indptr
.
gpu
[:
num_reqs
+
1
]
paged_kv_indptr
.
copy_
(
paged_kv_indptr
.
copy_
(
self
.
paged_kv_indptr_buffer
[:
num_reqs
+
1
],
non_blocking
=
True
self
.
paged_kv_indptr_
cpu_
buffer
[:
num_reqs
+
1
],
non_blocking
=
True
)
)
# write self.paged_kv_indices inplace
# write self.paged_kv_indices inplace
num_actual_pages
=
self
.
paged_kv_indptr
_
np
[
num_reqs
]
num_actual_pages
=
self
.
paged_kv_indptr
.
np
[
num_reqs
]
paged_kv_indices
=
self
.
paged_kv_indices
[:
num_actual_pages
]
paged_kv_indices
=
self
.
paged_kv_indices
.
gpu
[:
num_actual_pages
]
_copy_page_indices_kernel
[(
num_reqs
,)](
_copy_page_indices_kernel
[(
num_reqs
,)](
paged_kv_indices
,
paged_kv_indices
,
block_table_tensor
,
block_table_tensor
,
...
@@ -764,12 +765,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -764,12 +765,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# write self.paged_kv_last_page_len_cpu inplace
# write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np
=
seq_lens_np
%
page_size
paged_kv_last_page_len_np
=
seq_lens_np
%
page_size
self
.
paged_kv_last_page_len
_
np
[:
num_reqs
]
=
np
.
where
(
self
.
paged_kv_last_page_len
.
np
[:
num_reqs
]
=
np
.
where
(
(
paged_kv_last_page_len_np
==
0
)
&
(
seq_lens_np
!=
0
),
(
paged_kv_last_page_len_np
==
0
)
&
(
seq_lens_np
!=
0
),
page_size
,
page_size
,
paged_kv_last_page_len_np
,
paged_kv_last_page_len_np
,
)
)
return
paged_kv_indices
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
FlashInferMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
True
,
)
)
page_size
=
self
.
page_size
max_seq_len
=
common_attn_metadata
.
max_seq_len
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
qo_indptr
=
common_attn_metadata
.
query_start_loc
qo_indptr_cpu
=
common_attn_metadata
.
query_start_loc_cpu
# Step 1: Decide which dispatch modes to use:
# - Cascade attention (distinct mode)
# - Prefill (FI native or TRTLLM)
# - Decode (FI native or TRTLLM)
use_cascade
=
common_prefix_len
>
0
uses_spec_reorder
=
self
.
reorder_batch_threshold
>
1
uses_spec_reorder
=
self
.
reorder_batch_threshold
>
1
prefill_use_trtllm
=
use_trtllm_attention
(
prefill_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
self
.
num_qo_heads
,
...
@@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
use_trtllm_decode_attention
and
self
.
dcp_world_size
<=
1
self
.
use_trtllm_decode_attention
and
self
.
dcp_world_size
<=
1
)
)
if
not
(
prefill_use_trtllm
and
decode_use_trtllm
):
all_uses_trtllm
=
(
num_prefills
==
0
or
prefill_use_trtllm
)
and
(
num_decodes
==
0
or
decode_use_trtllm
)
is_only_trtllm_decode
=
num_prefills
==
0
and
(
num_decodes
>
0
and
decode_use_trtllm
)
if
not
all_uses_trtllm
:
if
self
.
has_sinks
:
if
self
.
has_sinks
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"FlashInfer backend currently does not support attention "
"FlashInfer backend currently does not support attention "
...
@@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# fall back to model dtype.
# fall back to model dtype.
self
.
q_data_type
=
self
.
model_config
.
dtype
self
.
q_data_type
=
self
.
model_config
.
dtype
# Step 2: Initialize the output metadata
# Leave prefill/decode/cascade_wrapper empty, to be populated
# case by case depending on the batch contents and backend selection.
attn_metadata
=
FlashInferMetadata
(
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
q_data_type
=
self
.
q_data_type
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
max_q_len
=
max_q_len
,
q_data_type
=
self
.
q_data_type
,
max_q_len_prefill
=
max_q_len
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table_tensor
=
block_table_tensor
,
prefill_use_trtllm
=
prefill_use_trtllm
,
decode_use_trtllm
=
decode_use_trtllm
,
num_decodes
=
num_decodes
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
use_cascade
=
use_cascade
,
use_cascade
=
use_cascade
,
prefill
=
None
,
decode
=
None
,
cascade_wrapper
=
None
,
)
)
paged_kv_indptr_cpu
=
self
.
paged_kv_indptr_cpu
[:
1
+
num_reqs
]
# Guard access to seq_lens_cpu, which may not always be needed
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
]
# and can be expensive to retrieve in async mode.
needs_seq_lens_cpu
=
self
.
use_dcp
or
use_cascade
or
not
is_only_trtllm_decode
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
if
needs_seq_lens_cpu
else
None
seq_lens_np
=
seq_lens_cpu
.
numpy
()
if
seq_lens_cpu
is
not
None
else
None
num_blocks_np
=
(
(
seq_lens_np
+
(
page_size
-
1
))
//
page_size
if
seq_lens_np
is
not
None
else
None
)
# Adjust seq_lens_cpu for DCP
if
self
.
use_dcp
:
assert
seq_lens_cpu
is
not
None
if
num_prefills
>
0
:
qo_indptr_prefill_cpu
=
(
qo_indptr_cpu
[
num_decodes
:]
-
qo_indptr_cpu
[
num_decodes
]
)
query_lens_prefill_cpu
=
(
qo_indptr_prefill_cpu
[
1
:]
-
qo_indptr_prefill_cpu
[:
-
1
]
)
seq_lens_cpu
[
num_decodes
:]
=
(
seq_lens_cpu
[
num_decodes
:]
-
query_lens_prefill_cpu
)
seq_lens_cpu
=
get_dcp_local_seq_lens
(
seq_lens_cpu
,
self
.
dcp_world_size
,
self
.
dcp_rank
,
self
.
dcp_kv_cache_interleave_size
,
)
# Adjust num_block_np for cascade attention
if
use_cascade
:
assert
num_blocks_np
is
not
None
assert
common_prefix_len
%
page_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
page_size
num_blocks_np
-=
num_common_kv_blocks
# Compute paged_kv_indices if necessary
needs_paged_kv_indices
=
use_cascade
or
not
is_only_trtllm_decode
if
needs_paged_kv_indices
:
assert
num_blocks_np
is
not
None
assert
seq_lens_np
is
not
None
paged_kv_indices
=
self
.
_compute_flashinfer_kv_metadata
(
num_blocks_np
,
seq_lens_np
,
block_table_tensor
,
num_reqs
,
page_size
,
)
else
:
paged_kv_indices
=
None
# Early-out for cascade attention
if
use_cascade
:
# Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks
=
common_prefix_len
//
page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_common_kv_blocks
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indices_cpu
=
block_table_tensor
[
0
,
:
num_common_kv_blocks
]
shared_kv_last_page_len_cpu
=
torch
.
tensor
(
[
page_size
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor
=
block_table_tensor
[:,
num_common_kv_blocks
:]
num_blocks_np
-=
num_common_kv_blocks
assert
paged_kv_indices
is
not
None
paged_kv_indptr_cpu
=
self
.
paged_kv_indptr
.
cpu
[:
1
+
num_reqs
]
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len
.
cpu
[:
num_reqs
]
if
attn_metadata
.
use_cascade
:
attn_metadata
.
cascade_wrapper
=
self
.
_get_cascade_wrapper
()
attn_metadata
.
cascade_wrapper
=
self
.
_get_cascade_wrapper
()
attn_metadata
.
cascade_wrapper
.
plan
(
attn_metadata
.
cascade_wrapper
.
plan
(
[
shared_qo_indptr_cpu
,
qo_indptr_cpu
],
[
shared_qo_indptr_cpu
,
qo_indptr_cpu
],
...
@@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type
=
self
.
q_data_type
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
)
else
:
return
attn_metadata
# Regular attention (common case).
# Decodes are at the front and prefills are at the back.
# Step 3: Handle prefill and decode pathways case by case
num_prefills
=
attn_metadata
.
num_prefills
## PREFILL PATHWAY
num_decodes
=
attn_metadata
.
num_decodes
if
num_prefills
>
0
:
if
num_prefills
>
0
:
# Slices for shared prefill metadata
# Decodes are first so prefills start after the last decode
prefill_start
=
num_decodes
prefill_start
=
num_decodes
qo_indptr_prefill_cpu
=
(
attn_metadata
.
prefill_wrapper
=
self
.
_get_prefill_wrapper
()
qo_indptr_cpu
[
prefill_start
:]
-
qo_indptr_cpu
[
prefill_start
]
assert
qo_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
)
assert
paged_kv_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
qo_indptr_prefill_cpu
.
shape
[
0
]
==
num_prefills
+
1
assert
(
paged_kv_last_page_len_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
if
prefill_use_trtllm
:
# Create GPU versions
qo_indptr_prefill_gpu
=
(
qo_indptr
[
prefill_start
:]
-
qo_indptr
[
prefill_start
]
)
paged_kv_indptr_prefill_gpu
=
self
.
paged_kv_indptr
.
gpu
[
prefill_start
:
num_reqs
+
1
]
# Compute max_q_len for prefill requests
query_lens_prefill_cpu
=
(
qo_indptr_prefill_cpu
[
1
:]
-
qo_indptr_prefill_cpu
[:
-
1
]
)
)
# Since prefill_wrapper.run() will be called with
max_q_len_prefill
=
int
(
query_lens_prefill_cpu
.
max
().
item
())
# query[num_decode_tokens:] we need to adjust the qo_indptr
attn_metadata
.
prefill
=
TRTLLMPrefill
(
# to be relative to the start of the prefill queries.
block_tables
=
block_table_tensor
[
prefill_start
:],
qo_indptr_cpu
=
(
seq_lens
=
seq_lens
[
prefill_start
:],
qo_indptr_cpu
[
prefill_start
:]
-
qo_indptr_cpu
[
prefill_start
]
cum_seq_lens_q
=
qo_indptr_prefill_gpu
,
cum_seq_lens_kv
=
paged_kv_indptr_prefill_gpu
,
max_q_len
=
max_q_len_prefill
,
max_seq_len
=
max_seq_len
,
)
)
paged_kv_indptr_cpu
=
paged_kv_indptr_cpu
[
prefill_start
:]
else
:
prefill_wrapper
=
self
.
_get_prefill_wrapper
()
# Recompute max_q_len for the slice of requests we are using
# Slicing CPU buffers that are only needed for FI native prefills
# for prefills. This can be different from max_q_len when
paged_kv_last_page_len_prefill_cpu
=
self
.
paged_kv_last_page_len
.
cpu
[
# we have a non-uniform batch with some short decodes offloaded
prefill_start
:
num_reqs
# to the prefill pathway
]
query_lens_prefill
=
qo_indptr_cpu
[
1
:]
-
qo_indptr_cpu
[:
-
1
]
assert
paged_kv_last_page_len_prefill_cpu
.
shape
[
0
]
==
num_prefills
attn_metadata
.
max_q_len_prefill
=
int
(
query_lens_prefill
.
max
().
item
())
paged_kv_indptr_prefill_cpu
=
self
.
paged_kv_indptr
.
cpu
[
prefill_start
:
num_reqs
+
1
if
not
attn_metadata
.
prefill_use_trtllm
:
]
if
self
.
dcp_world_size
>
1
:
assert
paged_kv_indptr_prefill_cpu
.
shape
[
0
]
==
num_prefills
+
1
assert
isinstance
(
if
self
.
use_dcp
:
attn_metadata
.
prefill_wrapper
,
BatchDCPPrefillWrapper
assert
isinstance
(
prefill_wrapper
,
BatchDCPPrefillWrapper
)
)
prefill_wrapper
.
plan
(
attn_metadata
.
prefill_wrapper
.
plan
(
qo_indptr_cpu
=
qo_indptr_prefill_cpu
,
qo_indptr_cpu
=
qo_indptr_cpu
,
paged_kv_indptr_cpu
=
paged_kv_indptr_prefill_cpu
,
paged_kv_indptr_cpu
=
paged_kv_indptr_cpu
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len_cpu
=
paged_kv_last_page_len_prefill_cpu
,
paged_kv_last_page_len_cpu
=
paged_kv_last_page_len_cpu
,
page_size
=
self
.
page_size
,
prefill_start
=
prefill_start
,
num_qo_heads
=
self
.
num_qo_heads
,
page_size
=
self
.
page_size
,
dcp_world_size
=
self
.
dcp_world_size
,
num_qo_heads
=
self
.
num_qo_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
dcp_world_size
=
self
.
dcp_world_size
,
head_dim
=
self
.
head_dim
,
num_kv_heads
=
self
.
num_kv_heads
,
sm_scale
=
self
.
sm_scale
,
head_dim
=
self
.
head_dim
,
window_left
=
self
.
window_left
,
sm_scale
=
self
.
sm_scale
,
logits_soft_cap
=
self
.
logits_soft_cap
,
window_left
=
self
.
window_left
,
q_data_type
=
self
.
q_data_type
,
logits_soft_cap
=
self
.
logits_soft_cap
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
q_data_type
=
self
.
q_data_type
,
prefill_fixed_split_size
=
self
.
prefill_fixed_split_size
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
disable_split_kv
=
self
.
disable_split_kv
,
prefill_fixed_split_size
=
self
.
prefill_fixed_split_size
,
)
disable_split_kv
=
self
.
disable_split_kv
,
)
else
:
assert
isinstance
(
attn_metadata
.
prefill_wrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
)
attn_metadata
.
prefill_wrapper
.
plan
(
qo_indptr_cpu
,
paged_kv_indptr_cpu
,
paged_kv_indices
,
paged_kv_last_page_len_cpu
[
prefill_start
:],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
fixed_split_size
=
self
.
prefill_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
else
:
else
:
attn_metadata
.
qo_indptr_gpu
=
qo_indptr_cpu
.
to
(
assert
isinstance
(
self
.
device
,
non_blocking
=
True
prefill_wrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
)
)
attn_metadata
.
paged_kv_indptr_gpu
=
paged_kv_indptr_cpu
.
to
(
prefill_wrapper
.
plan
(
self
.
device
,
non_blocking
=
True
qo_indptr_prefill_cpu
,
paged_kv_indptr_prefill_cpu
,
paged_kv_indices
,
paged_kv_last_page_len_prefill_cpu
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
fixed_split_size
=
self
.
prefill_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
)
attn_metadata
.
prefill
=
FIPrefill
(
wrapper
=
prefill_wrapper
)
if
num_decodes
>
0
:
## DECODE PATHWAY
if
num_decodes
>
0
:
if
decode_use_trtllm
:
assert
num_decode_tokens
%
num_decodes
==
0
,
(
"TRTLLM decode requires uniform query lengths per request."
)
attn_metadata
.
decode
=
TRTLLMDecode
(
block_tables
=
block_table_tensor
[:
num_decodes
],
seq_lens
=
seq_lens
[:
num_decodes
],
max_seq_len
=
max_seq_len
,
)
else
:
pure_decode
=
num_prefills
==
0
pure_decode
=
num_prefills
==
0
use_cudagraph
=
(
use_cudagraph
=
(
self
.
enable_cuda_graph
self
.
enable_cuda_graph
...
@@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
)
num_input_tokens
=
num_decode_tokens
num_input_tokens
=
num_decode_tokens
attn_metadata
.
decode_wrapper
=
self
.
_get_decode_wrapper
(
decode_wrapper
=
self
.
_get_decode_wrapper
(
num_input_tokens
,
use_cudagraph
num_input_tokens
,
use_cudagraph
)
)
if
not
attn_metadata
.
decode_use_trtllm
:
# Use the persistent buffer with padding length,
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
# instead of the sa
me
ad
dress but chunked version
# in atten_
me
t
ad
ata when using cudagraph.
# in atten_metadata when using cudagraph.
fast_plan_decode
(
fast_plan_decode
(
decode_wrapper
,
attn_metadata
.
decode_wrapper
,
self
.
paged_kv_indptr
.
cpu
[:
num_input_tokens
+
1
]
,
self
.
paged_kv_ind
ptr_cpu
[:
num_input_tokens
+
1
]
,
paged_kv_ind
ices
,
paged_kv_
indices
,
self
.
paged_kv_
last_page_len
.
cpu
[:
num_input_tokens
]
,
self
.
paged_kv_last_page
_len_cpu
[:
num_input_tokens
],
seq
_len
s
_cpu
[:
num_input_tokens
],
seq_lens_cpu
[:
num_input_tokens
]
,
self
.
num_qo_heads
*
self
.
dcp_world_size
,
self
.
num_
qo
_heads
*
self
.
dcp_world_size
,
self
.
num_
kv
_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
page_size
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
# Disable flashinfer's
pos
encoding
and use vllm's rope.
pos
_
encoding
_mode
=
"NONE"
,
pos_encoding_mode
=
"NONE"
,
sm_scale
=
self
.
sm_scale
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
q
_data_type
=
self
.
q_data_
type
,
kv
_data_type
=
self
.
kv_cache_d
type
,
kv_data_type
=
self
.
kv_cache_dtyp
e
,
fixed_split_size
=
self
.
decode_fixed_split_siz
e
,
fixed
_split_
size
=
self
.
d
ecode_fixed
_split_
size
,
disable
_split_
kv
=
self
.
d
isable
_split_
kv
,
disable_split_kv
=
self
.
disable_split_kv
,
)
)
attn_metadata
.
decode
=
FIDecode
(
wrapper
=
decode_wrapper
)
return
attn_metadata
return
attn_metadata
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
...
@@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
if
self
.
bmm2_scale
is
None
:
if
self
.
bmm2_scale
is
None
:
self
.
bmm2_scale
=
layer
.
_v_scale_float
self
.
bmm2_scale
=
layer
.
_v_scale_float
prefill_use_trtllm
=
isinstance
(
attn_metadata
.
prefill
,
TRTLLMPrefill
)
decode_use_trtllm
=
isinstance
(
attn_metadata
.
decode
,
TRTLLMDecode
)
# The attn+quant fusion happens when output_scale is provided.
# The attn+quant fusion happens when output_scale is provided.
if
output_scale
is
None
:
if
output_scale
is
None
:
assert
output_block_scale
is
None
,
(
assert
output_block_scale
is
None
,
(
...
@@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
assert
attn_metadata
.
q_data_type
==
FP8_DTYPE
,
(
assert
attn_metadata
.
q_data_type
==
FP8_DTYPE
,
(
"Query must be FP8 when attn+quant fusion happened."
"Query must be FP8 when attn+quant fusion happened."
)
)
assert
(
assert
(
attn_metadata
.
num_prefills
==
0
or
prefill_use_trtllm
)
and
(
attn_metadata
.
prefill_use_trtllm
and
attn_metadata
.
decode_use_trtllm
attn_metadata
.
num_decodes
==
0
or
decode_use_trtllm
),
"Must use TRT-LLM attn"
),
"Must use TRT-LLM attn"
if
output
.
dtype
==
FP8_DTYPE
:
if
output
.
dtype
==
FP8_DTYPE
:
...
@@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
# When using spec decoding, num_decodes can be < num_decode_tokens
# When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token.
# because some decode requests may have more than one query token.
num_decodes
=
attn_metadata
.
num_decodes
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
kv_cache_permute
=
kv_cache
.
permute
(
*
stride_order
)
kv_cache_permute
=
kv_cache
.
permute
(
*
stride_order
)
use_dcp
=
self
.
dcp_world_size
>
1
# Regular attention (common case).
# Regular attention (common case).
# Decodes are at the front and prefills are at the back.
# Decodes are at the front and prefills are at the back.
if
num_prefill_tokens
>
0
:
if
num_prefill_tokens
>
0
:
prefill_wrapper
=
attn_metadata
.
prefill_wrapper
prefill_query
=
query
[
num_decode_tokens
:]
prefill_query
=
query
[
num_decode_tokens
:]
assert
prefill_query
.
shape
[
0
]
==
num_prefill_tokens
assert
prefill_query
.
shape
[
0
]
==
num_prefill_tokens
assert
prefill_wrapper
is
not
None
if
not
attn_metadata
.
prefill_use_trtllm
:
if
not
prefill_use_trtllm
:
if
self
.
dcp_world_size
>
1
:
assert
isinstance
(
attn_metadata
.
prefill
,
FIPrefill
)
prefill_wrapper
=
attn_metadata
.
prefill
.
wrapper
assert
prefill_wrapper
is
not
None
if
use_dcp
:
assert
isinstance
(
prefill_wrapper
,
BatchDCPPrefillWrapper
)
assert
isinstance
(
prefill_wrapper
,
BatchDCPPrefillWrapper
)
assert
prefill_wrapper
.
_context
.
_window_left
==
self
.
window_left
assert
prefill_wrapper
.
_context
.
_window_left
==
self
.
window_left
assert
prefill_wrapper
.
_context
.
_logits_soft_cap
==
(
assert
prefill_wrapper
.
_context
.
_logits_soft_cap
==
(
...
@@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
out
=
output
[
num_decode_tokens
:],
out
=
output
[
num_decode_tokens
:],
)
)
else
:
else
:
assert
isinstance
(
attn_metadata
.
prefill
,
TRTLLMPrefill
)
# prefill_query may be non-contiguous
# prefill_query may be non-contiguous
prefill_query
=
prefill_query
.
contiguous
()
prefill_query
=
prefill_query
.
contiguous
()
workspace_buffer
=
_get_trtllm_gen_workspace_buffer
()
workspace_buffer
=
_get_trtllm_gen_workspace_buffer
()
block_tables_prefill
=
attn_metadata
.
block_table
_tensor
[
num_decodes
:]
block_tables_prefill
=
attn_metadata
.
prefill
.
block_table
s
seq_lens_prefill
=
attn_metadata
.
seq_lens
[
num_decodes
:]
seq_lens_prefill
=
attn_metadata
.
prefill
.
seq_lens
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert
get_kv_cache_layout
()
==
"HND"
assert
get_kv_cache_layout
()
==
"HND"
...
@@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer
=
workspace_buffer
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
mock_block_table
,
block_tables
=
mock_block_table
,
seq_lens
=
seq_lens_prefill
,
seq_lens
=
seq_lens_prefill
,
max_q_len
=
attn_metadata
.
max_q_len
_prefill
,
max_q_len
=
attn_metadata
.
prefill
.
max_q_len
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
max_kv_len
=
attn_metadata
.
prefill
.
max_seq_len
,
bmm1_scale
=
self
.
bmm1_scale
,
bmm1_scale
=
self
.
bmm1_scale
,
bmm2_scale
=
self
.
bmm2_scale
,
bmm2_scale
=
self
.
bmm2_scale
,
batch_size
=
attn_metadata
.
num_prefills
,
batch_size
=
attn_metadata
.
num_prefills
,
cum_seq_lens_q
=
attn_metadata
.
qo_indptr_gpu
,
cum_seq_lens_q
=
attn_metadata
.
prefill
.
cum_seq_lens_q
,
cum_seq_lens_kv
=
attn_metadata
.
p
aged_kv_indptr_gpu
,
cum_seq_lens_kv
=
attn_metadata
.
p
refill
.
cum_seq_lens_kv
,
window_left
=
self
.
window_left
,
window_left
=
self
.
window_left
,
sinks
=
self
.
sinks
,
sinks
=
self
.
sinks
,
o_sf_scale
=
self
.
o_sf_scale
,
o_sf_scale
=
self
.
o_sf_scale
,
...
@@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
)
)
if
num_decode_tokens
>
0
:
if
num_decode_tokens
>
0
:
decode_wrapper
=
attn_metadata
.
decode_wrapper
decode_query
=
query
[:
num_decode_tokens
]
decode_query
=
query
[:
num_decode_tokens
]
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_wrapper
is
not
None
if
not
attn_metadata
.
decode_use_trtllm
:
if
not
decode_use_trtllm
:
assert
isinstance
(
attn_metadata
.
decode
,
FIDecode
)
decode_wrapper
=
attn_metadata
.
decode
.
wrapper
assert
decode_wrapper
is
not
None
assert
decode_wrapper
.
_window_left
==
self
.
window_left
assert
decode_wrapper
.
_window_left
==
self
.
window_left
assert
decode_wrapper
.
_logits_soft_cap
==
(
self
.
logits_soft_cap
or
0.0
)
assert
decode_wrapper
.
_logits_soft_cap
==
(
self
.
logits_soft_cap
or
0.0
)
assert
decode_wrapper
.
_sm_scale
==
self
.
scale
assert
decode_wrapper
.
_sm_scale
==
self
.
scale
if
se
lf
.
dcp_world_size
>
1
:
if
u
se
_dcp
:
decode_query
=
get_dcp_group
().
all_gather
(
decode_query
=
get_dcp_group
().
all_gather
(
decode_query
.
contiguous
(),
dim
=-
2
decode_query
.
contiguous
(),
dim
=-
2
)
)
...
@@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
)
)
else
:
else
:
# decode_query may be non-contiguous
# decode_query may be non-contiguous
assert
isinstance
(
attn_metadata
.
decode
,
TRTLLMDecode
)
decode_query
=
decode_query
.
contiguous
()
decode_query
=
decode_query
.
contiguous
()
workspace_buffer
=
_get_trtllm_gen_workspace_buffer
()
workspace_buffer
=
_get_trtllm_gen_workspace_buffer
()
block_tables_decode
=
attn_metadata
.
block_table_tensor
[
block_tables_decode
=
attn_metadata
.
decode
.
block_tables
:
num_decode_tokens
seq_lens_decode
=
attn_metadata
.
decode
.
seq_lens
]
seq_lens_decode
=
attn_metadata
.
seq_lens
[:
num_decode_tokens
]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert
get_kv_cache_layout
()
==
"HND"
assert
get_kv_cache_layout
()
==
"HND"
...
@@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer
=
workspace_buffer
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables_decode
,
block_tables
=
block_tables_decode
,
seq_lens
=
seq_lens_decode
,
seq_lens
=
seq_lens_decode
,
max_seq_len
=
attn_metadata
.
max_seq_len
,
max_seq_len
=
attn_metadata
.
decode
.
max_seq_len
,
bmm1_scale
=
self
.
bmm1_scale
,
bmm1_scale
=
self
.
bmm1_scale
,
bmm2_scale
=
self
.
bmm2_scale
,
bmm2_scale
=
self
.
bmm2_scale
,
window_left
=
self
.
window_left
,
window_left
=
self
.
window_left
,
...
...
vllm/v1/attention/backends/mamba2_attn.py
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
itertools
import
itertools
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
...
@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
class
Mamba2AttentionMetadataBuilder
(
class
Mamba2AttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
Mamba2AttentionMetadata
]
BaseMambaAttentionMetadataBuilder
[
Mamba2AttentionMetadata
]
):
):
supports_update_block_table
:
bool
=
True
def
__init__
(
def
__init__
(
self
,
self
,
kv_cache_spec
:
AttentionSpec
,
kv_cache_spec
:
AttentionSpec
,
...
@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
...
@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_p
=
num_computed_tokens_p
,
num_computed_tokens_p
=
num_computed_tokens_p
,
)
)
return
attn_metadata
return
attn_metadata
def
update_block_table
(
self
,
metadata
:
Mamba2AttentionMetadata
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
Mamba2AttentionMetadata
:
new_metadata
=
copy
.
copy
(
metadata
)
prefix_caching
=
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
state_indices_t
=
blk_table
if
prefix_caching
else
blk_table
[:,
0
]
num_reqs
=
blk_table
.
shape
[
0
]
# For CUDA graphs, copy to persistent buffer
if
(
metadata
.
num_prefills
==
0
and
num_reqs
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
persistent_state_indices_t
=
self
.
state_indices_tensor
[:
num_reqs
]
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
state_indices_t
=
persistent_state_indices_t
new_metadata
.
state_indices_tensor
=
state_indices_t
return
new_metadata
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
a810671a
...
@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
...
@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonMetadataBuilder
,
QueryLenSupport
,
)
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
...
@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
qo_indptr
:
torch
.
Tensor
|
None
=
None
qo_indptr
:
torch
.
Tensor
|
None
=
None
# The dtype of MLA out tensor
# The dtype of MLA out tensor
attn_out_dtype
:
torch
.
dtype
=
torch
.
bfloat16
attn_out_dtype
:
torch
.
dtype
=
torch
.
bfloat16
# The max query output length: int
max_qo_len
:
int
|
None
=
None
class
AiterMLAMetadata
(
MLACommonMetadata
[
AiterMLADecodeMetadata
]):
class
AiterMLAMetadata
(
MLACommonMetadata
[
AiterMLADecodeMetadata
]):
...
@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
...
@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
# TODO(luka, lucas): audit this as part of:
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
# https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
query_len_support
:
ClassVar
[
QueryLenSupport
]
=
QueryLenSupport
.
UNIFORM
)
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
)
self
.
qo_indptr
=
torch
.
arange
(
self
.
qo_indptr
=
torch
.
zeros
(
0
,
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
)
def
_build_decode
(
def
_build_decode
(
...
@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens_device
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
),
seq_lens_device
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
),
]
]
)
)
qo_len
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
max_qo_len
=
qo_len
.
max
().
item
()
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
num_actual_pages
=
paged_kv_indices
.
size
(
0
)
num_actual_pages
=
paged_kv_indices
.
size
(
0
)
...
@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_last_page_len
[
num_reqs
:].
fill_
(
1
)
self
.
paged_kv_last_page_len
[
num_reqs
:].
fill_
(
1
)
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_reqs
]
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_reqs
]
self
.
qo_indptr
[:
1
+
num_reqs
].
copy_
(
query_start_loc_device
,
non_blocking
=
True
)
self
.
qo_indptr
[
1
+
num_reqs
:]
=
query_start_loc_device
[
-
1
]
qo_indptr
=
self
.
qo_indptr
[:
1
+
num_reqs
]
qo_indptr
=
self
.
qo_indptr
[:
1
+
num_reqs
]
else
:
else
:
...
@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
...
@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len
=
paged_kv_last_page_len
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
qo_indptr
=
qo_indptr
,
qo_indptr
=
qo_indptr
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
max_qo_len
=
max_qo_len
,
attn_out_dtype
=
self
.
decode_attn_out_dtype
,
attn_out_dtype
=
self
.
decode_attn_out_dtype
,
)
)
...
@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo
=
1
rocm_aiter_ops
.
mla_decode_fwd
(
rocm_aiter_ops
.
mla_decode_fwd
(
q
,
q
,
kv_buffer
,
kv_buffer
,
o
,
o
,
self
.
scale
,
self
.
scale
,
attn_metadata
.
decode
.
qo_indptr
,
attn_metadata
.
decode
.
qo_indptr
,
max_seqlen_qo
,
attn_metadata
.
decode
.
max_qo_len
,
attn_metadata
.
decode
.
paged_kv_indptr
,
attn_metadata
.
decode
.
paged_kv_indptr
,
attn_metadata
.
decode
.
paged_kv_indices
,
attn_metadata
.
decode
.
paged_kv_indices
,
attn_metadata
.
decode
.
paged_kv_last_page_len
,
attn_metadata
.
decode
.
paged_kv_last_page_len
,
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
a810671a
...
@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
...
@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class
RocmAttentionBackend
(
AttentionBackend
):
class
RocmAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
,
]
@
classmethod
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
...
@@ -165,7 +169,7 @@ class RocmAttentionBackend(AttentionBackend):
...
@@ -165,7 +169,7 @@ class RocmAttentionBackend(AttentionBackend):
raise
ValueError
(
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
f
"Supported head sizes are:
{
cls
.
get_supported_head_sizes
()
}
. "
f
"Supported head sizes are:
{
cls
.
get_supported_head_sizes
()
}
. "
"Set --attention-
config.
backend=FLEX_ATTENTION to use "
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
"FlexAttention backend which supports all head sizes."
)
)
...
...
vllm/v1/attention/backends/utils.py
View file @
a810671a
...
@@ -4,6 +4,7 @@ import abc
...
@@ -4,6 +4,7 @@ import abc
import
enum
import
enum
import
functools
import
functools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
,
field
,
fields
,
make_dataclass
from
dataclasses
import
dataclass
,
field
,
fields
,
make_dataclass
from
typing
import
(
from
typing
import
(
TYPE_CHECKING
,
TYPE_CHECKING
,
...
@@ -201,10 +202,11 @@ def _make_metadata_with_slice(
...
@@ -201,10 +202,11 @@ def _make_metadata_with_slice(
)
)
# NOTE: last token can be outside of the last request if we have CG padding.
# NOTE: last token can be outside of the last request if we have CG padding.
# If the "middle" request has tokens in both ubatches, we have to split it.
# If the request is split across ubatches, we have to adjust the metadata.
# If ubatch_slice is the first ubatch then we will be splitting the last
# splits_first_request: The first request in this slice is the continuation of
# request. If it's the second microbatch, then we will be splitting the
# a request that started in a previous slice.
# first request
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request
=
first_tok
>
start_locs
[
first_req
]
splits_first_request
=
first_tok
>
start_locs
[
first_req
]
splits_last_request
=
last_tok
<
start_locs
[
last_req
+
1
]
-
1
splits_last_request
=
last_tok
<
start_locs
[
last_req
+
1
]
-
1
...
@@ -225,7 +227,10 @@ def _make_metadata_with_slice(
...
@@ -225,7 +227,10 @@ def _make_metadata_with_slice(
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
if
splits_last_request
:
if
splits_last_request
:
tokens_skipped
=
query_start_loc_cpu
[
-
1
]
-
token_slice
.
stop
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped
=
start_locs
[
last_req
+
1
]
-
token_slice
.
stop
query_start_loc
[
-
1
]
-=
tokens_skipped
query_start_loc
[
-
1
]
-=
tokens_skipped
query_start_loc_cpu
[
-
1
]
-=
tokens_skipped
query_start_loc_cpu
[
-
1
]
-=
tokens_skipped
...
@@ -313,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
...
@@ -313,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# If not, set this to None. Otherwise set it to the query
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
# length that will be pulled into the front of the batch.
reorder_batch_threshold
:
int
|
None
=
None
reorder_batch_threshold
:
int
|
None
=
None
# Does this backend/builder support updating the block table in existing
# metadata
supports_update_block_table
:
bool
=
False
@
abstractmethod
@
abstractmethod
def
__init__
(
def
__init__
(
...
@@ -383,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
...
@@ -383,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
update_block_table
(
self
,
metadata
:
M
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
M
:
"""
Update the block table for the attention metadata.
Faster when theres multiple kv-cache groups that create virtually the
same metadata but just with different block tables.
Only needs to be implemented if supports_update_block_table is True.
"""
raise
NotImplementedError
def
build_for_cudagraph_capture
(
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
)
->
M
:
...
@@ -599,7 +622,7 @@ def make_local_attention_virtual_batches(
...
@@ -599,7 +622,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size
:
int
,
attn_chunk_size
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
block_size
:
int
=
0
,
block_size
:
int
=
0
,
)
->
CommonAttentionMetadata
:
)
->
tuple
[
CommonAttentionMetadata
,
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]]
:
query_start_loc_np
=
common_attn_metadata
.
query_start_loc_cpu
.
numpy
()
query_start_loc_np
=
common_attn_metadata
.
query_start_loc_cpu
.
numpy
()
seq_lens_np
=
common_attn_metadata
.
seq_lens_cpu
.
numpy
()
seq_lens_np
=
common_attn_metadata
.
seq_lens_cpu
.
numpy
()
block_table
=
common_attn_metadata
.
block_table_tensor
block_table
=
common_attn_metadata
.
block_table_tensor
...
@@ -711,9 +734,12 @@ def make_local_attention_virtual_batches(
...
@@ -711,9 +734,12 @@ def make_local_attention_virtual_batches(
# tensor first, which recovers perf.
# tensor first, which recovers perf.
batch_indices_torch
=
torch
.
from_numpy
(
batch_indices
)
batch_indices_torch
=
torch
.
from_numpy
(
batch_indices
)
block_indices_torch
=
torch
.
from_numpy
(
block_indices
)
block_indices_torch
=
torch
.
from_numpy
(
block_indices
)
block_table_local
=
block_table
[
batch_indices_torch
,
block_indices_torch
].
view
(
virtual_batches
,
-
1
# Save as a lambda so we can return this for update_block_table
)
make_block_table
=
lambda
block_table
:
block_table
[
batch_indices_torch
,
block_indices_torch
].
view
(
virtual_batches
,
-
1
)
block_table_local
=
make_block_table
(
block_table
)
query_start_loc_cpu
=
torch
.
from_numpy
(
cu_seqlens_q_local
)
query_start_loc_cpu
=
torch
.
from_numpy
(
cu_seqlens_q_local
)
seq_lens_cpu
=
torch
.
from_numpy
(
seqlens_k_local
)
seq_lens_cpu
=
torch
.
from_numpy
(
seqlens_k_local
)
...
@@ -732,7 +758,7 @@ def make_local_attention_virtual_batches(
...
@@ -732,7 +758,7 @@ def make_local_attention_virtual_batches(
causal
=
True
,
causal
=
True
,
_seq_lens_cpu
=
seq_lens_cpu
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
torch
.
from_numpy
(
num_computed_tokens_local
),
_num_computed_tokens_cpu
=
torch
.
from_numpy
(
num_computed_tokens_local
),
)
)
,
make_block_table
def
make_kv_sharing_fast_prefill_common_attn_metadata
(
def
make_kv_sharing_fast_prefill_common_attn_metadata
(
...
...
vllm/v1/core/sched/scheduler.py
View file @
a810671a
...
@@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
...
@@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from
vllm.v1.core.sched.utils
import
check_stop
,
remove_all
from
vllm.v1.core.sched.utils
import
check_stop
,
remove_all
from
vllm.v1.engine
import
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.engine
import
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.perf
import
ModelMetrics
,
PerfStats
from
vllm.v1.metrics.stats
import
(
from
vllm.v1.metrics.stats
import
(
PrefixCacheStats
,
PrefixCacheStats
,
SchedulerStats
,
SchedulerStats
,
...
@@ -187,6 +188,12 @@ class Scheduler(SchedulerInterface):
...
@@ -187,6 +188,12 @@ class Scheduler(SchedulerInterface):
if
self
.
is_encoder_decoder
if
self
.
is_encoder_decoder
else
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
else
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
)
)
# For encoder-decoder models, allocate the maximum number of tokens for Cross
# Attn blocks, as for Whisper its input is always padded to the maximum length.
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
self
.
_num_encoder_max_input_tokens
=
(
MULTIMODAL_REGISTRY
.
get_encdec_max_encoder_len
(
vllm_config
.
model_config
)
)
speculative_config
=
vllm_config
.
speculative_config
speculative_config
=
vllm_config
.
speculative_config
self
.
use_eagle
=
False
self
.
use_eagle
=
False
...
@@ -213,6 +220,10 @@ class Scheduler(SchedulerInterface):
...
@@ -213,6 +220,10 @@ class Scheduler(SchedulerInterface):
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
self
.
use_v2_model_runner
=
envs
.
VLLM_USE_V2_MODEL_RUNNER
self
.
use_v2_model_runner
=
envs
.
VLLM_USE_V2_MODEL_RUNNER
self
.
perf_metrics
:
ModelMetrics
|
None
=
None
if
self
.
log_stats
and
vllm_config
.
observability_config
.
enable_mfu_metrics
:
self
.
perf_metrics
=
ModelMetrics
(
vllm_config
)
def
schedule
(
self
)
->
SchedulerOutput
:
def
schedule
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# There's no "decoding phase" nor "prefill phase" in the scheduler.
...
@@ -568,17 +579,11 @@ class Scheduler(SchedulerInterface):
...
@@ -568,17 +579,11 @@ class Scheduler(SchedulerInterface):
0
if
request
.
num_computed_tokens
==
0
else
self
.
num_lookahead_tokens
0
if
request
.
num_computed_tokens
==
0
else
self
.
num_lookahead_tokens
)
)
# Determine if we need to allocate cross-attention blocks.
num_encoder_tokens
=
(
if
self
.
is_encoder_decoder
and
request
.
has_encoder_inputs
:
self
.
_num_encoder_max_input_tokens
# TODO(russellb): For Whisper, we know that the input is
if
self
.
is_encoder_decoder
and
request
.
has_encoder_inputs
# always padded to the maximum length. If we support other
else
0
# encoder-decoder models, this will need to be updated if we
)
# want to only allocate what is needed.
num_encoder_tokens
=
(
self
.
scheduler_config
.
max_num_encoder_input_tokens
)
else
:
num_encoder_tokens
=
0
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
request
,
...
@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface):
...
@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface):
kv_connector_output
=
model_runner_output
.
kv_connector_output
kv_connector_output
=
model_runner_output
.
kv_connector_output
cudagraph_stats
=
model_runner_output
.
cudagraph_stats
cudagraph_stats
=
model_runner_output
.
cudagraph_stats
perf_stats
:
PerfStats
|
None
=
None
if
self
.
perf_metrics
and
self
.
perf_metrics
.
is_enabled
():
perf_stats
=
self
.
perf_metrics
.
get_step_perf_stats_per_gpu
(
scheduler_output
)
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
kv_connector_stats
:
KVConnectorStats
|
None
=
(
kv_connector_stats
:
KVConnectorStats
|
None
=
(
...
@@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface):
if
(
if
(
stats
:
=
self
.
make_stats
(
stats
:
=
self
.
make_stats
(
spec_decoding_stats
,
kv_connector_stats
,
cudagraph_stats
spec_decoding_stats
,
kv_connector_stats
,
cudagraph_stats
,
perf_stats
)
)
)
is
not
None
:
)
is
not
None
:
# Return stats to only one of the front-ends.
# Return stats to only one of the front-ends.
...
@@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
,
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
,
kv_connector_stats
:
KVConnectorStats
|
None
=
None
,
kv_connector_stats
:
KVConnectorStats
|
None
=
None
,
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
,
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
,
perf_stats
:
PerfStats
|
None
=
None
,
)
->
SchedulerStats
|
None
:
)
->
SchedulerStats
|
None
:
if
not
self
.
log_stats
:
if
not
self
.
log_stats
:
return
None
return
None
...
@@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface):
...
@@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats
=
spec_stats
,
spec_decoding_stats
=
spec_stats
,
kv_connector_stats
=
connector_stats_payload
,
kv_connector_stats
=
connector_stats_payload
,
cudagraph_stats
=
cudagraph_stats
,
cudagraph_stats
=
cudagraph_stats
,
perf_stats
=
perf_stats
,
)
)
def
make_spec_decoding_stats
(
def
make_spec_decoding_stats
(
...
...
vllm/v1/engine/core.py
View file @
a810671a
...
@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import (
...
@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import (
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.engine
import
(
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestType
,
FinishReason
,
ReconfigureDistributedRequest
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
ReconfigureRankType
,
UtilityOutput
,
UtilityOutput
,
...
@@ -923,6 +925,13 @@ class EngineCoreProc(EngineCore):
...
@@ -923,6 +925,13 @@ class EngineCoreProc(EngineCore):
# Post-step hook.
# Post-step hook.
self
.
post_step
(
model_executed
)
self
.
post_step
(
model_executed
)
# If no model execution happened but there are waiting requests
# (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
# background threads (like NIXL handshake) to make progress.
# Without this, the tight polling loop can starve background threads.
if
not
model_executed
and
self
.
scheduler
.
has_unfinished_requests
():
time
.
sleep
(
0.001
)
return
model_executed
return
model_executed
def
_handle_client_request
(
def
_handle_client_request
(
...
@@ -1048,9 +1057,14 @@ class EngineCoreProc(EngineCore):
...
@@ -1048,9 +1057,14 @@ class EngineCoreProc(EngineCore):
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
# Deserialize the request data.
request
:
Any
if
request_type
==
EngineCoreRequestType
.
ADD
:
if
request_type
==
EngineCoreRequestType
.
ADD
:
request
=
add_request_decoder
.
decode
(
data_frames
)
req
:
EngineCoreRequest
=
add_request_decoder
.
decode
(
data_frames
)
request
=
self
.
preprocess_add_request
(
request
)
try
:
request
=
self
.
preprocess_add_request
(
req
)
except
Exception
:
self
.
_handle_request_preproc_error
(
req
)
continue
else
:
else
:
request
=
generic_decoder
.
decode
(
data_frames
)
request
=
generic_decoder
.
decode
(
data_frames
)
...
@@ -1134,6 +1148,30 @@ class EngineCoreProc(EngineCore):
...
@@ -1134,6 +1148,30 @@ class EngineCoreProc(EngineCore):
# Limit the number of buffers to reuse.
# Limit the number of buffers to reuse.
reuse_buffers
.
append
(
buffer
)
reuse_buffers
.
append
(
buffer
)
def
_handle_request_preproc_error
(
self
,
request
:
EngineCoreRequest
)
->
None
:
"""Log and return a request-scoped error response for exceptions raised
from the add request preprocessing in the input socket processing thread.
"""
logger
.
exception
(
"Unexpected error pre-processing request %s"
,
request
.
request_id
)
self
.
output_queue
.
put_nowait
(
(
request
.
client_index
,
EngineCoreOutputs
(
engine_index
=
self
.
engine_index
,
finished_requests
=
{
request
.
request_id
},
outputs
=
[
EngineCoreOutput
(
request_id
=
request
.
request_id
,
new_token_ids
=
[],
finish_reason
=
FinishReason
.
ERROR
,
)
],
),
)
)
class
DPEngineCoreProc
(
EngineCoreProc
):
class
DPEngineCoreProc
(
EngineCoreProc
):
"""ZMQ-wrapper for running EngineCore in background process
"""ZMQ-wrapper for running EngineCore in background process
...
...
vllm/v1/engine/core_client.py
View file @
a810671a
...
@@ -269,7 +269,8 @@ class InprocClient(EngineCoreClient):
...
@@ -269,7 +269,8 @@ class InprocClient(EngineCoreClient):
self
.
engine_core
=
EngineCore
(
*
args
,
**
kwargs
)
self
.
engine_core
=
EngineCore
(
*
args
,
**
kwargs
)
def
get_output
(
self
)
->
EngineCoreOutputs
:
def
get_output
(
self
)
->
EngineCoreOutputs
:
outputs
,
_
=
self
.
engine_core
.
step_fn
()
outputs
,
model_executed
=
self
.
engine_core
.
step_fn
()
self
.
engine_core
.
post_step
(
model_executed
=
model_executed
)
return
outputs
and
outputs
.
get
(
0
)
or
EngineCoreOutputs
()
return
outputs
and
outputs
.
get
(
0
)
or
EngineCoreOutputs
()
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
...
...
vllm/v1/engine/input_processor.py
View file @
a810671a
...
@@ -24,7 +24,10 @@ from vllm.tokenizers.mistral import MistralTokenizer
...
@@ -24,7 +24,10 @@ from vllm.tokenizers.mistral import MistralTokenizer
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
from
vllm.v1.structured_output.backend_guidance
import
validate_guidance_grammar
from
vllm.v1.structured_output.backend_guidance
import
(
has_guidance_unsupported_json_features
,
validate_guidance_grammar
,
)
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
validate_structured_output_request_lm_format_enforcer
,
validate_structured_output_request_lm_format_enforcer
,
)
)
...
@@ -340,8 +343,22 @@ class InputProcessor:
...
@@ -340,8 +343,22 @@ class InputProcessor:
# The request either failed validation
# The request either failed validation
# or includes some jsonschema feature(s) that
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
# are not supported in xgrammar.
if
isinstance
(
self
.
tokenizer
,
MistralTokenizer
):
# Check if schema has features unsupported by guidance
so_params
=
params
.
structured_outputs
skip_guidance
=
False
if
so_params
.
json
:
if
isinstance
(
so_params
.
json
,
str
):
import
json
schema
=
json
.
loads
(
so_params
.
json
)
else
:
schema
=
so_params
.
json
skip_guidance
=
has_guidance_unsupported_json_features
(
schema
)
if
isinstance
(
self
.
tokenizer
,
MistralTokenizer
)
or
skip_guidance
:
# Fall back to outlines if the tokenizer is Mistral
# Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
validate_structured_output_request_outlines
(
params
)
validate_structured_output_request_outlines
(
params
)
params
.
structured_outputs
.
_backend
=
"outlines"
params
.
structured_outputs
.
_backend
=
"outlines"
else
:
else
:
...
...
vllm/v1/engine/output_processor.py
View file @
a810671a
...
@@ -8,6 +8,7 @@ from typing import Any, cast
...
@@ -8,6 +8,7 @@ from typing import Any, cast
import
torch
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
(
from
vllm.outputs
import
(
CompletionOutput
,
CompletionOutput
,
PoolingOutput
,
PoolingOutput
,
...
@@ -93,7 +94,7 @@ class RequestState:
...
@@ -93,7 +94,7 @@ class RequestState:
request_id
:
str
,
request_id
:
str
,
parent_req
:
ParentRequest
|
None
,
parent_req
:
ParentRequest
|
None
,
request_index
:
int
,
request_index
:
int
,
lora_
name
:
st
r
|
None
,
lora_
request
:
LoRAReque
st
|
None
,
output_kind
:
RequestOutputKind
,
output_kind
:
RequestOutputKind
,
prompt
:
str
|
None
,
prompt
:
str
|
None
,
prompt_token_ids
:
list
[
int
]
|
None
,
prompt_token_ids
:
list
[
int
]
|
None
,
...
@@ -112,7 +113,8 @@ class RequestState:
...
@@ -112,7 +113,8 @@ class RequestState:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
parent_req
=
parent_req
self
.
parent_req
=
parent_req
self
.
request_index
=
request_index
self
.
request_index
=
request_index
self
.
lora_name
=
lora_name
self
.
lora_request
=
lora_request
self
.
lora_name
=
lora_request
.
lora_name
if
lora_request
is
not
None
else
None
self
.
output_kind
=
output_kind
self
.
output_kind
=
output_kind
self
.
prompt
=
prompt
self
.
prompt
=
prompt
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
...
@@ -178,9 +180,7 @@ class RequestState:
...
@@ -178,9 +180,7 @@ class RequestState:
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
parent_req
=
parent_req
,
parent_req
=
parent_req
,
request_index
=
request_index
,
request_index
=
request_index
,
lora_name
=
(
lora_request
=
request
.
lora_request
,
request
.
lora_request
.
name
if
request
.
lora_request
is
not
None
else
None
),
output_kind
=
output_kind
,
output_kind
=
output_kind
,
prompt
=
prompt
,
prompt
=
prompt
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt_token_ids
=
request
.
prompt_token_ids
,
...
@@ -289,6 +289,7 @@ class RequestState:
...
@@ -289,6 +289,7 @@ class RequestState:
return
RequestOutput
(
return
RequestOutput
(
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
self
.
lora_request
,
prompt
=
self
.
prompt
,
prompt
=
self
.
prompt
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
prompt_logprobs
,
prompt_logprobs
=
prompt_logprobs
,
...
...
vllm/v1/metrics/loggers.py
View file @
a810671a
...
@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
...
@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.plugins
import
STAT_LOGGER_PLUGINS_GROUP
,
load_plugins_by_group
from
vllm.plugins
import
STAT_LOGGER_PLUGINS_GROUP
,
load_plugins_by_group
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.perf
import
PerfMetricsLogging
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.stats
import
(
from
vllm.v1.metrics.stats
import
(
CachingMetrics
,
CachingMetrics
,
...
@@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
engine_is_idle
=
False
self
.
engine_is_idle
=
False
self
.
aggregated
=
False
self
.
aggregated
=
False
if
self
.
_enable_perf_stats
():
self
.
perf_metrics_logging
=
PerfMetricsLogging
(
vllm_config
)
def
_reset
(
self
,
now
):
def
_reset
(
self
,
now
):
self
.
last_log_time
=
now
self
.
last_log_time
=
now
...
@@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_corrupted_reqs
:
int
=
0
self
.
num_corrupted_reqs
:
int
=
0
self
.
num_preemptions
:
int
=
0
self
.
num_preemptions
:
int
=
0
def
_enable_perf_stats
(
self
)
->
bool
:
return
self
.
vllm_config
.
observability_config
.
enable_mfu_metrics
def
_track_iteration_stats
(
self
,
iteration_stats
:
IterationStats
):
def
_track_iteration_stats
(
self
,
iteration_stats
:
IterationStats
):
# Save tracked stats for token counters.
# Save tracked stats for token counters.
self
.
num_prompt_tokens
+=
iteration_stats
.
num_prompt_tokens
self
.
num_prompt_tokens
+=
iteration_stats
.
num_prompt_tokens
...
@@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase):
self
.
cudagraph_logging
.
observe
(
scheduler_stats
.
cudagraph_stats
)
self
.
cudagraph_logging
.
observe
(
scheduler_stats
.
cudagraph_stats
)
if
not
self
.
aggregated
:
if
not
self
.
aggregated
:
self
.
last_scheduler_stats
=
scheduler_stats
self
.
last_scheduler_stats
=
scheduler_stats
if
(
perf_stats
:
=
scheduler_stats
.
perf_stats
)
and
self
.
_enable_perf_stats
():
self
.
perf_metrics_logging
.
observe
(
perf_stats
)
if
mm_cache_stats
:
if
mm_cache_stats
:
self
.
mm_caching_metrics
.
observe
(
mm_cache_stats
)
self
.
mm_caching_metrics
.
observe
(
mm_cache_stats
)
...
@@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase):
"Running: %d reqs"
,
"Running: %d reqs"
,
"Waiting: %d reqs"
,
"Waiting: %d reqs"
,
]
]
log_args
=
[
log_args
:
list
[
int
|
float
|
str
]
=
[
self
.
last_prompt_throughput
,
self
.
last_prompt_throughput
,
self
.
last_generation_throughput
,
self
.
last_generation_throughput
,
self
.
last_scheduler_stats
.
num_running_reqs
,
self
.
last_scheduler_stats
.
num_running_reqs
,
...
@@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase):
self
.
kv_connector_logging
.
log
(
log_fn
=
log_fn
)
self
.
kv_connector_logging
.
log
(
log_fn
=
log_fn
)
if
self
.
cudagraph_logging
is
not
None
:
if
self
.
cudagraph_logging
is
not
None
:
self
.
cudagraph_logging
.
log
(
log_fn
=
log_fn
)
self
.
cudagraph_logging
.
log
(
log_fn
=
log_fn
)
if
self
.
_enable_perf_stats
():
self
.
perf_metrics_logging
.
log
(
log_fn
=
log_fn
,
log_prefix
=
self
.
log_prefix
)
def
log_engine_initialized
(
self
):
def
log_engine_initialized
(
self
):
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
...
@@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
...
@@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
def
log_prefix
(
self
):
def
log_prefix
(
self
):
return
"{} Engines Aggregated: "
.
format
(
len
(
self
.
engine_indexes
))
return
"{} Engines Aggregated: "
.
format
(
len
(
self
.
engine_indexes
))
def
_enable_perf_stats
(
self
)
->
bool
:
# Adding per_gpu perf stats across engines can lead to misleading numbers.
return
False
def
record
(
def
record
(
self
,
self
,
scheduler_stats
:
SchedulerStats
|
None
,
scheduler_stats
:
SchedulerStats
|
None
,
...
...
vllm/v1/metrics/perf.py
0 → 100644
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Analytic flops/memory estimation module for transformer components,
to help derive MFU (Model Flops Utilization) stats for a running model.
"""
import
json
import
time
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Any
,
Protocol
import
torch
from
pydantic
import
BaseModel
,
Field
,
ValidationError
,
model_validator
from
typing_extensions
import
Self
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
,
get_kv_cache_torch_dtype
,
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
logger
=
init_logger
(
__name__
)
class
InvalidComponent
(
Exception
):
"""
Custom exception to indicate that a certain ComponentMetric is not
applicable to the given VllmConfig.
"""
pass
#### Basic Data Types ####
@
dataclass
class
DebugPerfStats
:
## Stats for debugging the metrics calculation
calc_duration
:
float
=
0.0
# time spent calculating these stats
num_prefill_requests
:
int
=
0
num_decode_requests
:
int
=
0
context_breakdown
:
dict
[
str
,
int
]
|
None
=
None
num_flops_per_gpu_breakdown
:
dict
[
str
,
int
]
|
None
=
None
num_read_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
|
None
=
None
num_write_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
|
None
=
None
@
dataclass
class
PerfStats
:
num_flops_per_gpu
:
int
=
0
num_read_bytes_per_gpu
:
int
=
0
num_write_bytes_per_gpu
:
int
=
0
debug_stats
:
DebugPerfStats
|
None
=
None
@
dataclass
class
ExecutionContext
:
"""
Represents an execution context for a batch of requests.
This class aggregates statistics across multiple requests in a batch,
separately tracking prefill and decode phases.
Example)
- Batch with one full prefill (2048 tokens) and one decode (1 token, 8192 context):
ctx = ExecutionContext()
ctx.add(2048, 2048, is_prefill=True)
ctx.add(1, 8192, is_prefill=False)
"""
# Prefill phase statistics
num_prefill_requests
:
int
=
0
prefill_num_tokens
:
int
=
0
# sum of num_tokens for prefill requests
prefill_context_len
:
int
=
0
# sum of context_len for prefill requests
prefill_token_context_product
:
int
=
0
# sum of (num_tokens * context_len)
# Decode phase statistics
num_decode_requests
:
int
=
0
decode_num_tokens
:
int
=
0
# sum of num_tokens for decode requests
decode_context_len
:
int
=
0
# sum of context_len for decode requests
decode_token_context_product
:
int
=
0
# sum of (num_tokens * context_len)
def
add
(
self
,
num_tokens
:
int
,
context_len
:
int
,
is_prefill
:
bool
)
->
None
:
"""Add a single request's statistics to this batch context."""
if
is_prefill
:
self
.
num_prefill_requests
+=
1
self
.
prefill_num_tokens
+=
num_tokens
self
.
prefill_context_len
+=
context_len
self
.
prefill_token_context_product
+=
num_tokens
*
context_len
else
:
self
.
num_decode_requests
+=
1
self
.
decode_num_tokens
+=
num_tokens
self
.
decode_context_len
+=
context_len
self
.
decode_token_context_product
+=
num_tokens
*
context_len
def
total_num_tokens
(
self
)
->
int
:
"""Total number of tokens across all requests in the batch."""
return
self
.
prefill_num_tokens
+
self
.
decode_num_tokens
def
total_token_context_product
(
self
)
->
int
:
"""Total sum of (num_tokens * context_len) across all requests."""
return
self
.
prefill_token_context_product
+
self
.
decode_token_context_product
@
classmethod
def
from_single_request
(
cls
,
num_tokens
:
int
,
context_len
:
int
,
is_prefill
:
bool
)
->
"ExecutionContext"
:
"""Create an ExecutionContext from a single request.
This is a convenience method primarily for testing.
"""
ctx
=
cls
()
ctx
.
add
(
num_tokens
,
context_len
,
is_prefill
)
return
ctx
class
ParsedArgs
:
"""
Syntactic sugar so that Parsers can use dot notations
to access/update the parsed arguments.
e.g.)
args = ParsedArgs()
args.x = 3
args.y = args.x + 1
"""
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' has no attribute '
{
name
}
'"
)
def
__setattr__
(
self
,
name
:
str
,
value
:
Any
)
->
None
:
object
.
__setattr__
(
self
,
name
,
value
)
def
model_dump
(
self
)
->
dict
[
str
,
Any
]:
return
vars
(
self
).
copy
()
#### Abstract ####
class
Parser
(
Protocol
):
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
"""
Parse the vllm config and update the current ParsedArgs and pass it on.
If the parser isn't applicable to the vllm_config, it will do nothing.
"""
...
class
ParserChain
:
"""
Applies chain of parser in a sequential order.
Later parsers might overwrite results from previous parsers,
so parsers should be chained in the appropriate order if they
are not mutually exclusive.
"""
def
__init__
(
self
,
*
parsers
:
Parser
)
->
None
:
self
.
parsers
=
list
(
parsers
)
def
add_parser
(
self
,
parser
:
Parser
)
->
None
:
self
.
parsers
.
append
(
parser
)
def
parse
(
self
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
args
=
ParsedArgs
()
for
parser
in
self
.
parsers
:
args
=
parser
.
parse
(
args
,
vllm_config
)
return
args
_COMPONENT_METRICS_REGISTRY
:
dict
[
str
,
type
[
"ComponentMetrics"
]]
=
{}
class
ComponentMetrics
(
BaseModel
,
ABC
):
"""
Each concrete ComponentMetrics class is associated with:
- fields that are required for metric derivation
(fields are specified/validated through pydantic model)
- parser to parse VllmConfig into fields
- metric methods that derive flops/bytes for a given execution context
"""
@
classmethod
@
abstractmethod
def
component_type
(
cls
)
->
str
:
...
@
classmethod
@
abstractmethod
def
get_parser
(
cls
)
->
ParserChain
:
"""
Return a ParserChain that provides values for all required fields.
The returned parser chain must populate ParsedArgs with values for every
field defined on this ComponentMetrics class. Missing fields will cause
a ValidationError when from_vllm_config() is called.
See individual Parser docstrings for which args they provide, and field
comments on ComponentMetrics subclasses for which parser provides each field.
"""
...
def
__init_subclass__
(
cls
):
_COMPONENT_METRICS_REGISTRY
[
cls
.
component_type
()]
=
cls
@
classmethod
def
from_vllm_config
(
cls
,
vllm_config
:
VllmConfig
)
->
Self
:
"""
Instantiate this class from VllmConfig.
Raises ValidationError if parsing fails.
"""
parser
=
cls
.
get_parser
()
parsed_args
=
parser
.
parse
(
vllm_config
)
try
:
return
cls
.
model_validate
(
parsed_args
.
model_dump
())
except
ValidationError
as
e
:
raise
InvalidComponent
(
f
"Invalid
{
cls
.
component_type
()
}
config:
{
e
}
"
)
from
e
@
classmethod
def
registered_metrics
(
cls
)
->
Iterable
[
type
[
"ComponentMetrics"
]]:
return
iter
(
_COMPONENT_METRICS_REGISTRY
.
values
())
@
abstractmethod
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
...
@
abstractmethod
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
...
@
abstractmethod
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
...
def
get_num_flops
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
self
.
get_num_flops_breakdown
(
ctx
,
per_gpu
).
values
())
def
get_read_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
self
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
).
values
())
def
get_write_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
self
.
get_write_bytes_breakdown
(
ctx
,
per_gpu
).
values
())
#### parsers ####
class
BaseConfigParser
(
Parser
):
"""
Parses base model configuration.
Provides: vocab_size, hidden_size, num_attention_heads, num_hidden_layers,
weight_byte_size, activation_byte_size, dp_size, tp_size, pp_size, enable_ep
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
model_config
=
vllm_config
.
model_config
args
.
vocab_size
=
model_config
.
get_vocab_size
()
args
.
hidden_size
=
model_config
.
get_hidden_size
()
# NOTE: model_config.get_attention_heads() divide by TP
# so we access field manually here to get total num_heads
args
.
num_attention_heads
=
get_required
(
model_config
.
hf_text_config
,
"num_attention_heads"
)
args
.
num_hidden_layers
=
get_required
(
model_config
.
hf_text_config
,
"num_hidden_layers"
)
model_dtype
=
vllm_config
.
model_config
.
dtype
if
isinstance
(
model_dtype
,
torch
.
dtype
):
torch_dtype
=
model_dtype
elif
isinstance
(
model_dtype
,
str
)
and
model_dtype
in
STR_DTYPE_TO_TORCH_DTYPE
:
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
model_dtype
]
else
:
# FIXME: handle this better
logger
.
warning
(
"Unknown model_dtype %s, defaulting to bfloat16"
,
model_dtype
,
)
torch_dtype
=
torch
.
bfloat16
args
.
weight_byte_size
=
get_dtype_size
(
torch_dtype
)
# FIXME: handle this better by parsing whether activations use
# bf16, fp32, etc...
args
.
activation_byte_size
=
2
args
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
args
.
tp_size
=
vllm_config
.
parallel_config
.
tensor_parallel_size
args
.
pp_size
=
vllm_config
.
parallel_config
.
pipeline_parallel_size
args
.
enable_ep
=
vllm_config
.
parallel_config
.
enable_expert_parallel
return
args
#### Attention ####
class
BaseAttentionConfigParser
(
Parser
):
"""
Parses attention-specific configuration.
Provides: num_key_value_heads, head_dim, cache_byte_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
model_config
=
vllm_config
.
model_config
args
.
num_key_value_heads
=
model_config
.
get_total_num_kv_heads
()
args
.
head_dim
=
model_config
.
get_head_size
()
model_dtype
=
vllm_config
.
model_config
.
dtype
cache_dtype
=
vllm_config
.
cache_config
.
cache_dtype
kv_cache_torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
args
.
cache_byte_size
=
get_dtype_size
(
kv_cache_torch_dtype
)
return
args
class
AttentionQuantizationConfigParser
(
Parser
):
"""
Parses quantization configuration for attention layers.
Overrides: weight_byte_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
quant_config
if
cfg
is
None
:
return
args
quant_method
=
cfg
.
get_name
()
if
quant_method
in
[
"fp8"
,
"fbgemm_fp8"
]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args
.
weight_byte_size
=
1
elif
quant_method
==
"mxfp4"
:
# FIXME: Also has "ignored layers" issue above
args
.
weight_byte_size
=
0.5
else
:
# FIXME: Add more parsing logic for different quant methods.
raise
InvalidComponent
return
args
class
AttentionMetrics
(
ComponentMetrics
):
# From BaseConfigParser
num_hidden_layers
:
int
=
Field
(...,
gt
=
0
)
hidden_size
:
int
=
Field
(...,
gt
=
0
)
num_attention_heads
:
int
=
Field
(...,
gt
=
0
)
activation_byte_size
:
int
=
Field
(...,
gt
=
0
)
tp_size
:
int
=
Field
(...,
gt
=
0
)
pp_size
:
int
=
Field
(...,
gt
=
0
)
# From BaseAttentionConfigParser
num_key_value_heads
:
int
=
Field
(...,
gt
=
0
)
head_dim
:
int
=
Field
(...,
gt
=
0
)
cache_byte_size
:
int
=
Field
(...,
gt
=
0
)
# From BaseConfig Parser, overridden by AttentionQuantizationConfigParser
weight_byte_size
:
int
|
float
=
Field
(...,
gt
=
0
)
# TODO: discern cases where we have mixture of different attention layer types
# such as SWA, MLA, etc.
@
classmethod
def
component_type
(
cls
)
->
str
:
return
"attn"
@
classmethod
def
get_parser
(
cls
)
->
ParserChain
:
return
ParserChain
(
BaseConfigParser
(),
BaseAttentionConfigParser
(),
AttentionQuantizationConfigParser
(),
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
L
,
D
,
q
,
kv
,
d
=
(
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
num_attention_heads
,
self
.
num_key_value_heads
,
self
.
head_dim
,
)
T
=
ctx
.
total_num_tokens
()
TC
=
ctx
.
total_token_context_product
()
if
per_gpu
:
L
//=
self
.
pp_size
# tensor parallel along heads
q
=
max
(
1
,
q
//
self
.
tp_size
)
kv
=
max
(
1
,
kv
//
self
.
tp_size
)
return
{
"qkv_proj"
:
2
*
T
*
D
*
(
q
+
2
*
kv
)
*
d
*
L
,
"attn_qk"
:
2
*
q
*
TC
*
d
*
L
,
"attn_av"
:
2
*
q
*
TC
*
d
*
L
,
"out_proj"
:
2
*
T
*
D
*
q
*
d
*
L
,
}
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
L
,
D
,
q
,
kv
,
d
=
(
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
num_attention_heads
,
self
.
num_key_value_heads
,
self
.
head_dim
,
)
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
L
//=
self
.
pp_size
# tensor parallel along heads
q
=
max
(
1
,
q
//
self
.
tp_size
)
kv
=
max
(
1
,
kv
//
self
.
tp_size
)
read_bytes
=
{}
read_bytes
[
"qkv_input"
]
=
T
*
D
*
self
.
activation_byte_size
*
L
read_bytes
[
"qkv_weight"
]
=
int
(
D
*
(
q
+
2
*
kv
)
*
d
*
self
.
weight_byte_size
*
L
)
# Attention input reads differ between prefill and decode
# Prefill: read Q, K, V activations (all in activation_byte_size)
if
ctx
.
prefill_num_tokens
>
0
:
read_bytes
[
"attn_input"
]
=
(
(
ctx
.
prefill_num_tokens
*
q
+
2
*
ctx
.
prefill_context_len
*
kv
)
*
d
*
self
.
activation_byte_size
*
L
)
# Decode: read Q activations + read K, V from cache (in cache_byte_size)
if
ctx
.
decode_num_tokens
>
0
:
read_bytes
[
"attn_input"
]
=
read_bytes
.
get
(
"attn_input"
,
0
)
+
(
ctx
.
decode_num_tokens
*
q
*
d
*
self
.
activation_byte_size
*
L
+
2
*
ctx
.
decode_context_len
*
kv
*
d
*
self
.
cache_byte_size
*
L
)
read_bytes
[
"out_input"
]
=
T
*
q
*
d
*
self
.
activation_byte_size
*
L
read_bytes
[
"out_weight"
]
=
int
(
q
*
d
*
D
*
self
.
weight_byte_size
*
L
)
return
read_bytes
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate write memory traffic for attention layers."""
L
,
D
,
q
,
kv
,
d
=
(
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
num_attention_heads
,
self
.
num_key_value_heads
,
self
.
head_dim
,
)
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
L
//=
self
.
pp_size
# tensor parallel along heads
q
=
max
(
1
,
q
//
self
.
tp_size
)
kv
=
max
(
1
,
kv
//
self
.
tp_size
)
return
{
"qkv_output"
:
T
*
(
q
+
2
*
kv
)
*
d
*
self
.
activation_byte_size
*
L
,
"kv_cache"
:
2
*
T
*
kv
*
d
*
self
.
cache_byte_size
*
L
,
"out_output"
:
T
*
D
*
self
.
activation_byte_size
*
L
,
}
#### Ffn ####
class
BaseFfnConfigParser
(
Parser
):
"""
Parses FFN and MoE configuration.
Provides: intermediate_size, num_experts, num_experts_per_tok,
moe_intermediate_size, num_shared_experts, num_moe_layers
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
model_config
.
hf_config
if
hasattr
(
cfg
,
"text_config"
)
and
cfg
.
text_config
is
not
None
:
cfg
=
cfg
.
text_config
args
.
intermediate_size
=
getattr
(
cfg
,
"intermediate_size"
,
args
.
hidden_size
*
4
)
# Try different naming conventions.
args
.
num_experts
=
vllm_config
.
model_config
.
get_num_experts
()
args
.
num_experts_per_tok
=
getattr_from_list
(
cfg
,
[
"num_experts_per_tok"
,
"moe_topk"
],
0
)
args
.
moe_intermediate_size
=
getattr_from_list
(
cfg
,
[
"moe_intermediate_size"
,
"intermediate_size"
],
0
)
args
.
num_shared_experts
=
getattr_from_list
(
cfg
,
[
"n_shared_experts"
,
"num_shared_experts"
],
0
)
is_moe
=
args
.
num_experts
!=
0
# Assume all MoE layers by default
args
.
num_moe_layers
=
args
.
num_hidden_layers
if
is_moe
else
0
return
args
class
FfnParallelParser
(
Parser
):
"""
Parses FFN parallelism configuration.
Provides: ffn_tp_size, ffn_ep_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
# NOTE: ffn tp_size does not equal the tp_size parameter directly.
# e.g.) If we use DP2TP4, ffn will use TP8 (or EP8 if EP is enabled.)
if
args
.
enable_ep
:
ffn_tp_size
,
ffn_ep_size
=
1
,
args
.
dp_size
*
args
.
tp_size
else
:
ffn_tp_size
,
ffn_ep_size
=
args
.
dp_size
*
args
.
tp_size
,
1
args
.
ffn_tp_size
=
ffn_tp_size
args
.
ffn_ep_size
=
ffn_ep_size
return
args
class
InterleaveMoeLayerStepParser
(
Parser
):
"""
Parses interleave_moe_layer_step field for models like Llama4.
Overrides: num_moe_layers
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
model_config
.
hf_config
if
hasattr
(
cfg
,
"text_config"
)
and
cfg
.
text_config
is
not
None
:
cfg
=
cfg
.
text_config
if
(
hasattr
(
cfg
,
"interleave_moe_layer_step"
)
and
cfg
.
interleave_moe_layer_step
>
0
):
args
.
num_moe_layers
=
len
(
[
layer
for
layer
in
range
(
args
.
num_hidden_layers
)
if
(
layer
+
1
)
%
cfg
.
interleave_moe_layer_step
==
0
]
)
return
args
class
MoeLayerFreqParser
(
Parser
):
"""
Parses moe_layer_freq and first_k_dense_replace fields for models like Deepseek.
Overrides: num_moe_layers
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
model_config
.
hf_config
if
hasattr
(
cfg
,
"text_config"
)
and
cfg
.
text_config
is
not
None
:
cfg
=
cfg
.
text_config
if
hasattr
(
cfg
,
"moe_layer_freq"
)
and
hasattr
(
cfg
,
"first_k_dense_replace"
):
args
.
num_moe_layers
=
len
(
[
layer
for
layer
in
range
(
args
.
num_hidden_layers
)
if
layer
>=
cfg
.
first_k_dense_replace
and
layer
%
cfg
.
moe_layer_freq
==
0
]
)
return
args
class
FfnQuantizationConfigParser
(
Parser
):
"""
Parses quantization configuration for FFN layers.
Overrides: weight_byte_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
quant_config
if
cfg
is
None
:
return
args
quant_method
=
cfg
.
get_name
()
if
quant_method
in
[
"fp8"
,
"fbgemm_fp8"
]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# (there might be more quantization methods for fp8).
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args
.
weight_byte_size
=
1
pass
elif
quant_method
==
"mxfp4"
:
# FIXME: Also has "ignored layers" issue above
args
.
weight_byte_size
=
0.5
else
:
# FIXME: Add more parsing logic for different quant methods.
raise
InvalidComponent
return
args
class
FfnMetrics
(
ComponentMetrics
):
# From BaseConfigParser
num_hidden_layers
:
int
=
Field
(...,
gt
=
0
)
hidden_size
:
int
=
Field
(...,
gt
=
0
)
activation_byte_size
:
int
=
Field
(...,
gt
=
0
)
pp_size
:
int
=
Field
(...,
gt
=
0
)
# From FfnParallelParser
ffn_tp_size
:
int
=
Field
(...,
gt
=
0
)
ffn_ep_size
:
int
=
Field
(...,
gt
=
0
)
# From BaseFfnConfigParser
intermediate_size
:
int
=
Field
(...,
gt
=
0
)
num_experts
:
int
=
Field
(
0
)
num_experts_per_tok
:
int
=
Field
(
1
)
moe_intermediate_size
:
int
=
Field
(
0
)
num_shared_experts
:
int
=
Field
(
0
)
# From BaseConfigParser, can be overridden InterleaveMoeLayerStep or MoeLayerFreq
num_moe_layers
:
int
=
Field
(...,
ge
=
0
)
# FIXME: might have to make this more granular
# (i.e. dense_weight_byte_size, moe_routed_weight_byte_size,
# moe_shared_weight_byte_size)
# since it can differ from byte size of other components (e.g. attn)
# and can differ even from each other.
# From BaseConfigParser, can be overridden by FfnQuantizationConfigParser
weight_byte_size
:
int
|
float
=
Field
(...,
gt
=
0
)
@
model_validator
(
mode
=
"after"
)
def
validate_moe_fields
(
self
)
->
Self
:
"""Validate that MoE-related fields are properly set when num_moe_layers > 0."""
if
self
.
num_moe_layers
>
0
:
assert
self
.
num_experts
,
f
"
{
self
.
num_experts
=
}
"
assert
self
.
num_experts_per_tok
,
f
"
{
self
.
num_experts_per_tok
=
}
"
assert
self
.
moe_intermediate_size
,
f
"
{
self
.
moe_intermediate_size
=
}
"
return
self
@
classmethod
def
component_type
(
cls
)
->
str
:
return
"ffn"
@
classmethod
def
get_parser
(
cls
)
->
ParserChain
:
return
ParserChain
(
BaseConfigParser
(),
FfnParallelParser
(),
BaseFfnConfigParser
(),
InterleaveMoeLayerStepParser
(),
MoeLayerFreqParser
(),
FfnQuantizationConfigParser
(),
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate flops breakdown for FFN layers."""
L
,
D
,
DI
=
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
intermediate_size
Lm
,
E
,
MI
,
S
=
(
self
.
num_moe_layers
,
self
.
num_experts_per_tok
,
self
.
moe_intermediate_size
,
self
.
num_shared_experts
,
)
T
=
ctx
.
total_num_tokens
()
Ld
=
L
-
Lm
num_activated_tokens
=
T
*
E
if
E
else
0
if
per_gpu
:
Ld
//=
self
.
pp_size
Lm
//=
self
.
pp_size
DI
//=
self
.
ffn_tp_size
if
MI
is
not
None
:
MI
//=
self
.
ffn_tp_size
if
E
:
num_activated_tokens
//=
self
.
ffn_ep_size
flops
=
{}
# Dense FFN layers (SwiGLU: 3 linear layers: up, gate, down)
if
Ld
:
flops
[
"dense_ffn"
]
=
2
*
D
*
3
*
DI
*
T
*
Ld
# MoE routed experts (each token activates E experts)
if
Lm
and
E
:
flops
[
"routed_ffn"
]
=
2
*
D
*
3
*
MI
*
num_activated_tokens
*
Lm
# MoE shared experts (all S shared experts run for every token)
if
Lm
and
S
:
flops
[
"shared_ffn"
]
=
2
*
D
*
3
*
MI
*
S
*
T
*
Lm
return
flops
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate read memory traffic for FFN layers."""
L
,
D
,
DI
=
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
intermediate_size
Lm
,
E
,
MI
,
S
=
(
self
.
num_moe_layers
,
self
.
num_experts_per_tok
,
self
.
moe_intermediate_size
,
self
.
num_shared_experts
,
)
T
=
ctx
.
total_num_tokens
()
num_experts
=
self
.
num_experts
Ld
=
L
-
Lm
num_activated_tokens
=
T
*
E
if
E
else
0
if
per_gpu
:
Ld
//=
self
.
pp_size
Lm
//=
self
.
pp_size
DI
//=
self
.
ffn_tp_size
if
MI
is
not
None
:
MI
//=
self
.
ffn_tp_size
if
E
:
num_activated_tokens
//=
self
.
ffn_ep_size
if
num_experts
is
not
None
:
num_experts
//=
self
.
ffn_ep_size
read_bytes
=
{}
# Dense FFN layers (3 GEMMs: up, gate, down projections + SiLU activation)
if
Ld
:
read_bytes
[
"dense_up_gate_input"
]
=
int
(
T
*
D
*
self
.
activation_byte_size
*
Ld
)
read_bytes
[
"dense_up_gate_weights"
]
=
int
(
2
*
D
*
DI
*
self
.
weight_byte_size
*
Ld
)
read_bytes
[
"dense_silu_input"
]
=
int
(
2
*
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
read_bytes
[
"dense_down_input"
]
=
int
(
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
read_bytes
[
"dense_down_weights"
]
=
int
(
D
*
DI
*
self
.
weight_byte_size
*
Ld
)
if
Lm
:
# MoE routed expert reads
if
E
:
# FIXME: Assume perfect load balancing for now.
num_activated_experts
=
min
(
num_activated_tokens
,
num_experts
)
read_bytes
[
"routed_up_gate_input"
]
=
int
(
num_activated_tokens
*
D
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"routed_up_gate_weights"
]
=
int
(
2
*
D
*
MI
*
num_activated_experts
*
self
.
weight_byte_size
*
Lm
)
read_bytes
[
"routed_silu_input"
]
=
int
(
2
*
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"routed_down_input"
]
=
int
(
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"routed_down_weights"
]
=
int
(
D
*
MI
*
num_activated_experts
*
self
.
weight_byte_size
*
Lm
)
# MoE shared expert reads
if
S
:
read_bytes
[
"shared_up_gate_input"
]
=
int
(
T
*
D
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"shared_up_gate_weights"
]
=
int
(
2
*
D
*
MI
*
S
*
self
.
weight_byte_size
*
Lm
)
read_bytes
[
"shared_silu_input"
]
=
int
(
2
*
T
*
MI
*
S
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"shared_down_input"
]
=
int
(
T
*
MI
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"shared_down_weights"
]
=
int
(
D
*
MI
*
S
*
self
.
weight_byte_size
*
Lm
)
return
read_bytes
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate write memory traffic for FFN layers."""
L
,
D
,
DI
=
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
intermediate_size
Lm
,
E
,
MI
,
S
=
(
self
.
num_moe_layers
,
self
.
num_experts_per_tok
,
self
.
moe_intermediate_size
,
self
.
num_shared_experts
,
)
T
=
ctx
.
total_num_tokens
()
Ld
=
L
-
Lm
num_activated_tokens
=
T
*
E
if
E
else
0
if
per_gpu
:
Ld
//=
self
.
pp_size
Lm
//=
self
.
pp_size
DI
//=
self
.
ffn_tp_size
if
MI
is
not
None
:
MI
//=
self
.
ffn_tp_size
if
E
:
num_activated_tokens
//=
self
.
ffn_ep_size
write_bytes
=
{}
# Dense FFN layers
if
Ld
:
write_bytes
[
"dense_up_gate_output"
]
=
int
(
2
*
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
write_bytes
[
"dense_silu_output"
]
=
int
(
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
write_bytes
[
"dense_down_output"
]
=
int
(
T
*
D
*
self
.
activation_byte_size
*
Ld
)
# MoE outputs
if
Lm
:
if
E
:
write_bytes
[
"routed_up_gate_output"
]
=
int
(
2
*
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"routed_silu_output"
]
=
int
(
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"routed_down_output"
]
=
int
(
num_activated_tokens
*
D
*
self
.
activation_byte_size
*
Lm
)
if
S
:
write_bytes
[
"shared_up_gate_output"
]
=
int
(
2
*
T
*
S
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"shared_silu_output"
]
=
int
(
T
*
S
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"shared_down_output"
]
=
int
(
T
*
S
*
D
*
self
.
activation_byte_size
*
Lm
)
return
write_bytes
#### Unembed ####
class
UnembedMetrics
(
ComponentMetrics
):
# From BaseConfigParser
hidden_size
:
int
=
Field
(...,
gt
=
0
)
vocab_size
:
int
=
Field
(...,
gt
=
0
)
weight_byte_size
:
int
=
Field
(...,
gt
=
0
)
activation_byte_size
:
int
=
Field
(...,
gt
=
0
)
tp_size
:
int
@
classmethod
def
component_type
(
cls
)
->
str
:
return
"unembed"
@
classmethod
def
get_parser
(
cls
)
->
ParserChain
:
return
ParserChain
(
BaseConfigParser
(),
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate flops breakdown for unembedding layer."""
D
,
V
=
self
.
hidden_size
,
self
.
vocab_size
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
V
//=
self
.
tp_size
return
{
"unembed"
:
2
*
T
*
D
*
V
,
}
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate read memory traffic for unembedding layer."""
D
,
V
=
self
.
hidden_size
,
self
.
vocab_size
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
V
//=
self
.
tp_size
return
{
"input"
:
T
*
D
*
self
.
activation_byte_size
,
"weight"
:
D
*
V
*
self
.
weight_byte_size
,
}
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate write memory traffic for unembedding layer."""
V
=
self
.
vocab_size
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
V
//=
self
.
tp_size
return
{
"output"
:
T
*
V
*
self
.
activation_byte_size
,
}
#### ModelMetrics ####
class
ModelMetrics
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
"""
Parse vllm_config to instantiate metrics for each component.
is_enabled() will return False if no component metrics could be instantiated.
"""
self
.
vllm_config
=
vllm_config
self
.
metrics
:
list
[
ComponentMetrics
]
=
[]
for
metric_cls
in
ComponentMetrics
.
registered_metrics
():
try
:
metric
=
metric_cls
.
from_vllm_config
(
vllm_config
)
self
.
metrics
.
append
(
metric
)
logger
.
info
(
"Instantiated ComponentMetrics [%s] with (%s)"
,
metric
.
component_type
(),
str
(
metric
),
)
except
InvalidComponent
as
e
:
logger
.
debug
(
"Failed to instantiate %s from %s"
,
metric_cls
.
component_type
(),
str
(
e
),
)
def
is_enabled
(
self
)
->
bool
:
return
len
(
self
.
metrics
)
>
0
def
get_num_flops
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
metric
.
get_num_flops
(
ctx
,
per_gpu
)
for
metric
in
self
.
metrics
)
def
get_read_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
metric
.
get_read_bytes
(
ctx
,
per_gpu
)
for
metric
in
self
.
metrics
)
def
get_write_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
metric
.
get_write_bytes
(
ctx
,
per_gpu
)
for
metric
in
self
.
metrics
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
total
=
{}
for
metric
in
self
.
metrics
:
breakdown
=
metric
.
get_num_flops_breakdown
(
ctx
,
per_gpu
)
component
=
metric
.
component_type
()
prefixed
=
{
f
"
{
component
}
.
{
key
}
"
:
val
for
key
,
val
in
breakdown
.
items
()}
total
.
update
(
prefixed
)
return
total
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
total
=
{}
for
metric
in
self
.
metrics
:
breakdown
=
metric
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
)
component
=
metric
.
component_type
()
prefixed
=
{
f
"
{
component
}
.
{
key
}
"
:
val
for
key
,
val
in
breakdown
.
items
()}
total
.
update
(
prefixed
)
return
total
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
total
=
{}
for
metric
in
self
.
metrics
:
breakdown
=
metric
.
get_write_bytes_breakdown
(
ctx
,
per_gpu
)
component
=
metric
.
component_type
()
prefixed
=
{
f
"
{
component
}
.
{
key
}
"
:
val
for
key
,
val
in
breakdown
.
items
()}
total
.
update
(
prefixed
)
return
total
def
get_step_perf_stats_per_gpu
(
self
,
scheduler_output
:
SchedulerOutput
)
->
PerfStats
:
"""
Calculate perf stats for the current step based on scheduled tokens.
"""
t0
=
time
.
monotonic
()
# Build a single batch context
ctx
=
ExecutionContext
()
# Process new requests (these are in prefill phase)
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req
.
req_id
num_tokens
=
scheduler_output
.
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens
==
0
:
continue
# For new requests, context_len = num_computed_tokens + num_tokens
# num_computed_tokens represents previously computed tokens in the sequence
context_len
=
new_req
.
num_computed_tokens
+
num_tokens
ctx
.
add
(
num_tokens
,
context_len
,
is_prefill
=
True
)
# Process cached requests (continuing requests)
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
num_tokens
=
scheduler_output
.
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens
==
0
:
continue
# For cached requests, we have the current num_computed_tokens
num_computed_tokens
=
cached_reqs
.
num_computed_tokens
[
i
]
context_len
=
num_computed_tokens
+
num_tokens
# Cached requests are typically in decode phase (num_tokens == 1)
# unless they're doing chunked prefill (num_tokens > 1)
is_prefill
=
num_tokens
>
1
ctx
.
add
(
num_tokens
,
context_len
,
is_prefill
)
num_flops_breakdown
=
self
.
get_num_flops_breakdown
(
ctx
,
True
)
read_bytes_breakdown
=
self
.
get_read_bytes_breakdown
(
ctx
,
True
)
write_bytes_breakdown
=
self
.
get_write_bytes_breakdown
(
ctx
,
True
)
perf_stats
=
PerfStats
(
sum
(
num_flops_breakdown
.
values
()),
sum
(
read_bytes_breakdown
.
values
()),
sum
(
write_bytes_breakdown
.
values
()),
)
if
envs
.
VLLM_DEBUG_MFU_METRICS
:
perf_stats
.
debug_stats
=
DebugPerfStats
(
time
.
monotonic
()
-
t0
,
ctx
.
num_prefill_requests
,
ctx
.
num_decode_requests
,
asdict
(
ctx
),
num_flops_breakdown
,
read_bytes_breakdown
,
write_bytes_breakdown
,
)
return
perf_stats
#### Logging ####
class
PerfMetricsDebugLogging
:
def
__init__
(
self
):
self
.
reset
()
def
reset
(
self
):
self
.
total_calc_duration
:
float
=
0.0
self
.
total_num_prefill_requests
:
int
=
0
self
.
total_num_decode_requests
:
int
=
0
self
.
total_num_batches
:
int
=
0
self
.
total_context_breakdown
:
dict
[
str
,
int
]
=
{}
self
.
total_num_flops_per_gpu_breakdown
:
dict
[
str
,
int
]
=
{}
self
.
total_read_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
=
{}
self
.
total_write_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
=
{}
def
observe
(
self
,
debug_stats
:
DebugPerfStats
)
->
None
:
self
.
total_calc_duration
+=
debug_stats
.
calc_duration
self
.
total_num_prefill_requests
+=
debug_stats
.
num_prefill_requests
self
.
total_num_decode_requests
+=
debug_stats
.
num_decode_requests
self
.
total_num_batches
+=
1
for
dst
,
src
in
zip
(
[
self
.
total_context_breakdown
,
self
.
total_num_flops_per_gpu_breakdown
,
self
.
total_read_bytes_per_gpu_breakdown
,
self
.
total_write_bytes_per_gpu_breakdown
,
],
[
debug_stats
.
context_breakdown
,
debug_stats
.
num_flops_per_gpu_breakdown
,
debug_stats
.
num_read_bytes_per_gpu_breakdown
,
debug_stats
.
num_write_bytes_per_gpu_breakdown
,
],
):
assert
isinstance
(
src
,
dict
)
for
key
,
val
in
src
.
items
():
dst
[
key
]
=
dst
.
get
(
key
,
0
)
+
val
def
log
(
self
,
log_fn
,
log_prefix
:
str
,
delta_time
:
float
):
# pretty print breakdowns
total_num_flops_per_gpu_breakdown
=
{
k
:
f
"
{
v
/
1e12
:.
1
f
}
TF"
for
k
,
v
in
self
.
total_num_flops_per_gpu_breakdown
.
items
()
}
total_read_bytes_per_gpu_breakdown
=
{
k
:
f
"
{
v
/
1e9
:.
1
f
}
GB"
for
k
,
v
in
self
.
total_read_bytes_per_gpu_breakdown
.
items
()
}
total_write_bytes_per_gpu_breakdown
=
{
k
:
f
"
{
v
/
1e9
:.
1
f
}
GB"
for
k
,
v
in
self
.
total_write_bytes_per_gpu_breakdown
.
items
()
}
logger
.
debug
(
"%sMFU details: %s"
,
log_prefix
,
json
.
dumps
(
{
"prefill_reqs"
:
self
.
total_num_prefill_requests
,
"decode_reqs"
:
self
.
total_num_decode_requests
,
"num_batches"
:
self
.
total_num_batches
,
"context_breakdown"
:
self
.
total_context_breakdown
,
"flops_breakdown"
:
total_num_flops_per_gpu_breakdown
,
"num_read_bytes_breakdown"
:
total_read_bytes_per_gpu_breakdown
,
"num_write_bytes_breakdown"
:
(
total_write_bytes_per_gpu_breakdown
),
"duration"
:
f
"
{
delta_time
:.
1
f
}
s"
,
"mfu_calc_overhead"
:
(
f
"
{
self
.
total_calc_duration
/
delta_time
:.
1
%
}
"
),
},
indent
=
2
,
),
)
class
PerfMetricsLogging
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
self
.
vllm_config
=
vllm_config
self
.
pp_size
=
vllm_config
.
parallel_config
.
pipeline_parallel_size
self
.
debug_logging
:
PerfMetricsDebugLogging
|
None
=
None
if
envs
.
VLLM_DEBUG_MFU_METRICS
:
self
.
debug_logging
=
PerfMetricsDebugLogging
()
self
.
reset
()
def
reset
(
self
):
self
.
last_log_time
=
time
.
monotonic
()
self
.
total_num_flops_per_gpu
:
int
=
0
self
.
total_read_bytes_per_gpu
:
int
=
0
self
.
total_write_bytes_per_gpu
:
int
=
0
if
self
.
debug_logging
:
self
.
debug_logging
.
reset
()
def
observe
(
self
,
perf_stats
:
PerfStats
)
->
None
:
self
.
total_num_flops_per_gpu
+=
perf_stats
.
num_flops_per_gpu
self
.
total_read_bytes_per_gpu
+=
perf_stats
.
num_read_bytes_per_gpu
self
.
total_write_bytes_per_gpu
+=
perf_stats
.
num_write_bytes_per_gpu
if
self
.
debug_logging
:
assert
perf_stats
.
debug_stats
is
not
None
self
.
debug_logging
.
observe
(
perf_stats
.
debug_stats
)
def
log
(
self
,
log_fn
=
logger
.
info
,
log_prefix
:
str
=
""
)
->
None
:
if
not
(
self
.
total_num_flops_per_gpu
or
self
.
total_read_bytes_per_gpu
or
self
.
total_write_bytes_per_gpu
):
return
now
=
time
.
monotonic
()
delta_time
=
now
-
self
.
last_log_time
if
delta_time
<=
0.0
:
avg_tflops_per_gpu
=
0.0
avg_gbps_per_gpu
=
0.0
else
:
avg_tflops_per_gpu
=
self
.
total_num_flops_per_gpu
/
delta_time
/
1e12
avg_gbps_per_gpu
=
(
(
self
.
total_read_bytes_per_gpu
+
self
.
total_write_bytes_per_gpu
)
/
delta_time
/
1e9
)
log_fn
(
"%sMFU: %.1f TF/s/GPU %.1f GB/s/GPU"
,
log_prefix
,
avg_tflops_per_gpu
,
avg_gbps_per_gpu
,
)
if
self
.
debug_logging
:
self
.
debug_logging
.
log
(
log_fn
,
log_prefix
,
delta_time
)
self
.
reset
()
## util functions
def
get_required
(
obj
:
object
,
attr
:
str
):
"""Get an attr from an object, or throw a InvalidComponentError if it's not set."""
if
not
hasattr
(
obj
,
attr
):
raise
InvalidComponent
(
f
"Missing required attr
{
attr
}
in config"
)
return
getattr
(
obj
,
attr
)
def
getattr_from_list
(
obj
:
object
,
attrs
:
list
[
str
],
default
:
object
=
None
):
"""Try to get the first attr that exists in the object
from a list of attrs. Otherwise return None."""
for
attr
in
attrs
:
if
hasattr
(
obj
,
attr
):
return
getattr
(
obj
,
attr
)
return
default
vllm/v1/metrics/stats.py
View file @
a810671a
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.cuda_graph
import
CUDAGraphStat
from
vllm.compilation.cuda_graph
import
CUDAGraphStat
from
vllm.v1.metrics.perf
import
PerfStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -186,6 +187,8 @@ class SchedulerStats:
...
@@ -186,6 +187,8 @@ class SchedulerStats:
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
perf_stats
:
PerfStats
|
None
=
None
@
dataclass
@
dataclass
class
RequestStateStats
:
class
RequestStateStats
:
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
a810671a
...
@@ -44,6 +44,32 @@ def _walk_json_for_additional_properties(data: object):
...
@@ -44,6 +44,32 @@ def _walk_json_for_additional_properties(data: object):
_walk_json_for_additional_properties
(
item
)
_walk_json_for_additional_properties
(
item
)
def
has_guidance_unsupported_json_features
(
schema
:
dict
[
str
,
Any
])
->
bool
:
"""Check if JSON schema contains features unsupported by guidance/llguidance."""
def
check_object
(
obj
:
dict
[
str
,
Any
])
->
bool
:
if
not
isinstance
(
obj
,
dict
):
return
False
# patternProperties is not supported by llguidance
if
"patternProperties"
in
obj
:
return
True
# Recursively check all nested objects and arrays
for
value
in
obj
.
values
():
if
isinstance
(
value
,
dict
):
if
check_object
(
value
):
return
True
elif
isinstance
(
value
,
list
):
for
item
in
value
:
if
isinstance
(
item
,
dict
)
and
check_object
(
item
):
return
True
return
False
return
check_object
(
schema
)
def
process_for_additional_properties
(
def
process_for_additional_properties
(
guide_json
:
str
|
dict
[
str
,
Any
],
guide_json
:
str
|
dict
[
str
,
Any
],
)
->
dict
[
str
,
Any
]:
)
->
dict
[
str
,
Any
]:
...
...
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment