Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a810671a
Commit
a810671a
authored
Jan 08, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori
parents
86b5aefe
6a09612b
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2039 additions
and
323 deletions
+2039
-323
vllm/tool_parsers/glm4_moe_tool_parser.py
vllm/tool_parsers/glm4_moe_tool_parser.py
+2
-1
vllm/tool_parsers/minimax_m2_tool_parser.py
vllm/tool_parsers/minimax_m2_tool_parser.py
+18
-4
vllm/utils/flashinfer.py
vllm/utils/flashinfer.py
+38
-1
vllm/utils/mem_utils.py
vllm/utils/mem_utils.py
+27
-4
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+75
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+13
-0
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+403
-269
vllm/v1/attention/backends/mamba2_attn.py
vllm/v1/attention/backends/mamba2_attn.py
+27
-0
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+15
-9
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+6
-2
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+36
-10
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+23
-12
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+40
-2
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+2
-1
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+19
-2
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+6
-5
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+16
-1
vllm/v1/metrics/perf.py
vllm/v1/metrics/perf.py
+1244
-0
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+3
-0
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+26
-0
No files found.
vllm/tool_parsers/glm4_moe_tool_parser.py
View file @
a810671a
...
...
@@ -114,7 +114,8 @@ class Glm4MoeModelToolParser(ToolParser):
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
tc_name
,
arguments
=
json
.
dumps
(
arg_dct
)
name
=
tc_name
,
arguments
=
json
.
dumps
(
arg_dct
,
ensure_ascii
=
False
),
),
)
)
...
...
vllm/tool_parsers/minimax_m2_tool_parser.py
View file @
a810671a
...
...
@@ -122,6 +122,8 @@ class MinimaxM2ToolParser(ToolParser):
self
.
streaming_request
=
None
# Clear previous tool call history to avoid state pollution
self
.
prev_tool_call_arr
.
clear
()
# Reset streamed args tracking
self
.
streamed_args_for_tool
.
clear
()
def
_extract_name
(
self
,
name_str
:
str
)
->
str
:
"""Extract name from quoted string."""
...
...
@@ -421,9 +423,12 @@ class MinimaxM2ToolParser(ToolParser):
self
.
prev_tool_call_arr
.
append
(
{
"name"
:
self
.
current_function_name
,
"arguments"
:
"
{}
"
,
# Placeholder, will be updated later
"arguments"
:
{},
# Placeholder, will be updated later
}
)
# Initialize streamed_args_for_tool for this tool call
if
len
(
self
.
streamed_args_for_tool
)
<=
self
.
current_tool_index
:
self
.
streamed_args_for_tool
.
append
(
""
)
# Send header with function info
return
DeltaMessage
(
...
...
@@ -445,6 +450,9 @@ class MinimaxM2ToolParser(ToolParser):
# Send opening brace if not sent yet
if
self
.
in_function
and
not
self
.
json_started
:
self
.
json_started
=
True
# Update streamed_args_for_tool for opening brace
if
self
.
current_tool_index
<
len
(
self
.
streamed_args_for_tool
):
self
.
streamed_args_for_tool
[
self
.
current_tool_index
]
+=
"{"
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
...
...
@@ -493,7 +501,7 @@ class MinimaxM2ToolParser(ToolParser):
args
=
parsed_tool
.
function
.
arguments
self
.
prev_tool_call_arr
[
self
.
current_tool_index
][
"arguments"
]
=
args
]
=
json
.
loads
(
args
)
except
Exception
:
pass
# Ignore parsing errors during streaming
...
...
@@ -505,7 +513,9 @@ class MinimaxM2ToolParser(ToolParser):
)
]
)
# Update streamed_args_for_tool for closing brace
if
self
.
current_tool_index
<
len
(
self
.
streamed_args_for_tool
):
self
.
streamed_args_for_tool
[
self
.
current_tool_index
]
+=
"}"
# Reset state for next tool
self
.
json_closed
=
True
self
.
in_function
=
False
...
...
@@ -630,7 +640,11 @@ class MinimaxM2ToolParser(ToolParser):
)
self
.
param_count
+=
1
# Update streamed_args_for_tool for this tool call
if
self
.
current_tool_index
<
len
(
self
.
streamed_args_for_tool
):
self
.
streamed_args_for_tool
[
self
.
current_tool_index
]
+=
(
json_fragment
)
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
...
...
vllm/utils/flashinfer.py
View file @
a810671a
...
...
@@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
)
@
functools
.
cache
def
has_flashinfer_trtllm_fused_moe
()
->
bool
:
"""Return `True` if FlashInfer TRTLLM fused MoE is available."""
if
not
has_flashinfer_moe
():
return
False
required_functions
=
[
(
"flashinfer.fused_moe"
,
"trtllm_fp8_block_scale_moe"
),
(
"flashinfer.fused_moe"
,
"trtllm_fp8_per_tensor_scale_moe"
),
(
"flashinfer.fused_moe"
,
"trtllm_fp4_block_scale_moe"
),
]
for
module_name
,
attr_name
in
required_functions
:
mod
=
_get_submodule
(
module_name
)
if
not
mod
or
not
hasattr
(
mod
,
attr_name
):
return
False
return
True
@
functools
.
cache
def
has_flashinfer_cutlass_fused_moe
()
->
bool
:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
...
...
@@ -288,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
if
force_use_trtllm_attention
()
is
False
:
return
False
has_trtllm
=
supports_trtllm_attention
()
return
has_trtllm
and
(
num_qo_heads
%
num_kv_heads
==
0
)
# num_kv_heads=1 is not supported due to TMA descriptor building limitations.
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if
has_trtllm
and
num_kv_heads
==
1
:
logger
.
warning_once
(
"TRTLLM attention does not support num_kv_heads=1. "
"This configuration causes TMA descriptor building to fail due to "
"degenerate tensor strides. Falling back to FlashInfer attention."
)
return
has_trtllm
and
(
num_qo_heads
%
num_kv_heads
==
0
)
and
(
num_kv_heads
!=
1
)
def
use_trtllm_attention
(
...
...
@@ -338,6 +366,15 @@ def use_trtllm_attention(
)
return
False
# num_kv_heads=1 is not supported
if
num_kv_heads
==
1
:
if
force_use_trtllm
:
logger
.
warning_once
(
"TRTLLM attention does not support num_kv_heads=1, "
"but --attention-config.use_trtllm_attention is set to 1"
)
return
False
if
has_spec
and
not
is_prefill
:
# Speculative decoding requires TRTLLM attention for decodes
logger
.
info_once
(
"Using TRTLLM attention (enabled for speculative decoding)."
)
...
...
vllm/utils/mem_utils.py
View file @
a810671a
...
...
@@ -66,27 +66,43 @@ class MemorySnapshot:
torch_memory
:
int
=
0
non_torch_memory
:
int
=
0
timestamp
:
float
=
0.0
device
:
torch
.
types
.
Device
=
None
auto_measure
:
bool
=
True
def
__post_init__
(
self
)
->
None
:
if
self
.
device
is
None
:
from
vllm.platforms
import
current_platform
device_fn
=
current_platform
.
current_device
assert
device_fn
is
not
None
self
.
device_
=
torch
.
device
(
device_fn
())
else
:
self
.
device_
=
torch
.
device
(
self
.
device
)
if
self
.
auto_measure
:
self
.
measure
()
def
measure
(
self
)
->
None
:
from
vllm.platforms
import
current_platform
device
=
self
.
device_
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self
.
torch_peak
=
torch
.
cuda
.
memory_stats
().
get
(
"allocated_bytes.all.peak"
,
0
)
self
.
torch_peak
=
torch
.
cuda
.
memory_stats
(
device
).
get
(
"allocated_bytes.all.peak"
,
0
)
self
.
free_memory
,
self
.
total_memory
=
torch
.
cuda
.
mem_get_info
()
self
.
free_memory
,
self
.
total_memory
=
torch
.
cuda
.
mem_get_info
(
device
)
shared_sysmem_device_mem_sms
=
((
8
,
7
),
(
11
,
0
),
(
12
,
1
))
# Orin, Thor, Spark
if
(
current_platform
.
is_cuda
()
and
current_platform
.
get_device_capability
()
in
shared_sysmem_device_mem_sms
and
current_platform
.
get_device_capability
(
device
.
index
)
in
shared_sysmem_device_mem_sms
):
# On UMA (Orin, Thor and Spark) platform,
# where both CPU and GPU rely on system memory,
...
...
@@ -106,12 +122,18 @@ class MemorySnapshot:
# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage
self
.
torch_memory
=
torch
.
cuda
.
memory_reserved
()
self
.
torch_memory
=
torch
.
cuda
.
memory_reserved
(
device
)
self
.
non_torch_memory
=
self
.
cuda_memory
-
self
.
torch_memory
self
.
timestamp
=
time
.
time
()
def
__sub__
(
self
,
other
:
"MemorySnapshot"
)
->
"MemorySnapshot"
:
if
self
.
device_
!=
other
.
device_
:
raise
ValueError
(
"The two snapshots should be from the same device! "
f
"Found:
{
self
.
device_
}
vs.
{
other
.
device_
}
"
)
return
MemorySnapshot
(
torch_peak
=
self
.
torch_peak
-
other
.
torch_peak
,
free_memory
=
self
.
free_memory
-
other
.
free_memory
,
...
...
@@ -120,6 +142,7 @@ class MemorySnapshot:
torch_memory
=
self
.
torch_memory
-
other
.
torch_memory
,
non_torch_memory
=
self
.
non_torch_memory
-
other
.
non_torch_memory
,
timestamp
=
self
.
timestamp
-
other
.
timestamp
,
device
=
self
.
device_
,
auto_measure
=
False
,
)
...
...
vllm/utils/torch_utils.py
View file @
a810671a
...
...
@@ -24,6 +24,10 @@ else:
ModelConfig
=
object
IntermediateTensors
=
object
import
logging
logger
=
logging
.
getLogger
(
__name__
)
STR_DTYPE_TO_TORCH_DTYPE
=
{
"float32"
:
torch
.
float32
,
...
...
@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
=
{
# TODO: Add more modelopt kv cache dtype
# mappings here when it supported by some attention backend
# (for example supports nvfp4).
"fp8"
:
"fp8_e4m3"
,
}
T
=
TypeVar
(
"T"
)
...
...
@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
return
torch_dtype
def
get_kv_cache_quant_algo_string
(
quant_cfg
:
dict
[
str
,
Any
])
->
str
|
None
:
"""Get the KV cache quantization algorithm string from the quantization config.
Maps various FP8 format names to vLLM's standard cache dtype strings.
Returns None if no kv_cache_quant_algo is specified.
Returns "auto" if the value is not recognized/supported.
"""
# Mapping from model config values to vLLM cache_dtype strings
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
)
if
quant_method
.
startswith
(
"modelopt"
):
quantization_inner
=
quant_cfg
.
get
(
"quantization"
,
quant_cfg
)
# Check if quant config is specified and use kv cache quant algo
kv_algo
=
quantization_inner
.
get
(
"kv_cache_quant_algo"
)
or
quant_cfg
.
get
(
"kv_cache_quant_algo"
)
if
isinstance
(
kv_algo
,
str
):
kv_algo_lower
=
kv_algo
.
lower
()
# Try to map to vLLM's standard format
if
kv_algo_lower
in
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
:
return
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
[
kv_algo_lower
]
else
:
# Unknown/unsupported format - return "auto" as safe fallback
logger
.
warning
(
"WARNING: Unknown kv_cache_quant_algo '%s' in model "
"config. Supported values: %s. Falling back to 'auto'."
,
kv_algo
,
list
(
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP
.
keys
()),
)
return
"auto"
return
None
def
get_kv_cache_quant_algo_dtype
(
quant_cfg
:
dict
[
str
,
Any
])
->
torch
.
dtype
|
None
:
"""Get the KV cache quantization algorithm dtype from the quantization config."""
kv_algo_str
=
get_kv_cache_quant_algo_string
(
quant_cfg
)
if
kv_algo_str
is
not
None
and
kv_algo_str
!=
"auto"
:
# Only convert if we have a valid dtype string (not "auto" fallback)
return
STR_DTYPE_TO_TORCH_DTYPE
[
kv_algo_str
]
return
None
def
resolve_kv_cache_dtype_string
(
kv_cache_dtype
:
str
,
model_config
:
ModelConfig
)
->
str
:
"""Resolve 'auto' kv_cache_dtype to the actual string value from model config.
Returns the resolved cache_dtype string.
"""
if
kv_cache_dtype
!=
"auto"
:
return
kv_cache_dtype
hf_cfg
=
getattr
(
model_config
,
"hf_config"
,
None
)
if
hf_cfg
is
not
None
:
quant_cfg
=
getattr
(
hf_cfg
,
"quantization_config"
,
None
)
if
quant_cfg
is
not
None
:
kv_algo_str
=
get_kv_cache_quant_algo_string
(
quant_cfg
)
if
kv_algo_str
is
not
None
:
return
kv_algo_str
# Default to auto (will be handled by downstream code)
return
"auto"
def
kv_cache_dtype_str_to_dtype
(
kv_cache_dtype
:
str
,
model_config
:
ModelConfig
)
->
torch
.
dtype
:
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
a810671a
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
import
copy
from
dataclasses
import
dataclass
from
typing
import
ClassVar
...
...
@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
if
get_flash_attn_version
()
==
3
else
AttentionCGSupport
.
UNIFORM_BATCH
)
supports_update_block_table
:
bool
=
True
def
__init__
(
self
,
...
...
@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
)
return
attn_metadata
def
update_block_table
(
self
,
metadata
:
FlashAttentionMetadata
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
FlashAttentionMetadata
:
new_metadata
=
copy
.
copy
(
metadata
)
new_metadata
.
block_table
=
blk_table
new_metadata
.
slot_mapping
=
slot_mapping
return
new_metadata
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
use_cascade_attention
(
*
args
,
**
kwargs
)
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
a810671a
...
...
@@ -16,6 +16,7 @@ from flashinfer import (
from
flashinfer.decode
import
_get_range_buf
,
trtllm_batch_decode_with_kv_cache
from
flashinfer.prefill
import
trtllm_batch_context_with_kv_cache
from
flashinfer.utils
import
FP4Tensor
from
typing_extensions
import
override
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
...
...
@@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.utils
import
CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
=
2048
*
1024
*
1024
...
...
@@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
paged_kv_indptr_cpu
:
torch
.
Tensor
,
paged_kv_indices
:
torch
.
Tensor
,
paged_kv_last_page_len_cpu
:
torch
.
Tensor
,
prefill_start
:
int
,
page_size
:
int
,
num_qo_heads
:
int
,
dcp_world_size
:
int
,
...
...
@@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
qo_indptr_cpu
,
paged_kv_indptr_cpu
,
paged_kv_indices
,
paged_kv_last_page_len_cpu
[
prefill_start
:]
,
paged_kv_last_page_len_cpu
,
num_qo_heads
*
dcp_world_size
,
num_kv_heads
,
head_dim
,
...
...
@@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
@
dataclass
class
F
lashInferMetadata
:
num_actual_tokens
:
int
# Number of tokens excluding padding.
class
F
IPrefill
:
"""Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""
# The data type of the query
q_data_type
:
torch
.
dtype
wrapper
:
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
slot_mapping
:
torch
.
Tensor
# For flashinfer trtllm batch decode
@
dataclass
class
FIDecode
:
"""Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""
wrapper
:
BatchDecodeWithPagedKVCacheWrapper
@
dataclass
class
TRTLLMPrefill
:
"""Metadata for the TRTLLM prefill pathway."""
block_tables
:
torch
.
Tensor
"""
The slice of the block table tensor corresponding *only* to prefill requests.
Shape: [num_prefills, max_num_blocks_per_seq]
"""
seq_lens
:
torch
.
Tensor
"""
The slice of the sequence lengths tensor corresponding *only* to prefill requests.
Shape: [num_prefills]
"""
cum_seq_lens_q
:
torch
.
Tensor
cum_seq_lens_kv
:
torch
.
Tensor
max_q_len
:
int
max_q_len_prefill
:
int
"""
The maximum query length *among prefill requests*.
"""
max_seq_len
:
int
"""The maximum sequence length for KV Cache."""
@
dataclass
class
TRTLLMDecode
:
"""Metadata for the TRTLLM decode pathway."""
block_tables
:
torch
.
Tensor
"""
The slice of the block table tensor corresponding *only* to decode requests.
Shape: [num_decodes, max_num_blocks_per_seq]
"""
seq_lens
:
torch
.
Tensor
block_table_tensor
:
torch
.
Tensor
prefill_use_trtllm
:
bool
decode_use_trtllm
:
bool
"""
The slice of the sequence lengths tensor corresponding *only* to decode requests.
Shape: [num_decodes]
"""
max_seq_len
:
int
"""The maximum sequence length for KV Cache."""
@
dataclass
class
FlashInferMetadata
:
num_actual_tokens
:
int
"""Total number of tokens in the batch (excluding padding)."""
slot_mapping
:
torch
.
Tensor
"""Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""
q_data_type
:
torch
.
dtype
# For handling prefill decode split
num_decodes
:
int
num_decode_tokens
:
int
num_prefills
:
int
num_prefill_tokens
:
int
# For cascade attention (CPU for planning).
use_cascade
:
bool
prefill
:
FIPrefill
|
TRTLLMPrefill
|
None
"""
Holds the metadata for the prefill portion of the batch.
Will be `None` if `num_prefill_tokens == 0`.
"""
decode
:
FIDecode
|
TRTLLMDecode
|
None
"""
Holds the metadata for the decode portion of the batch.
Will be `None` if `num_decode_tokens == 0`.
"""
prefill_wrapper
:
(
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
|
None
)
=
None
decode_wrapper
:
BatchDecodeWithPagedKVCacheWrapper
|
None
=
None
cascade_wrapper
:
MultiLevelCascadeAttentionWrapper
|
None
=
None
# --- Special Case: Cascade Attention ---
qo_indptr_gpu
:
torch
.
Tensor
|
None
=
None
paged_kv_indptr_gpu
:
torch
.
Tensor
|
None
=
None
use_cascade
:
bool
"""
If True, the entire batch is a cascade attention call, and the
`prefill` and `decode` fields will both be None.
"""
cascade_wrapper
:
MultiLevelCascadeAttentionWrapper
|
None
class
FlashInferMetadataBuilder
(
AttentionMetadataBuilder
[
FlashInferMetadata
]):
...
...
@@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
dcp_world_size
=
1
self
.
dcp_rank
=
0
self
.
dcp_kv_cache_interleave_size
=
1
self
.
use_dcp
=
self
.
dcp_world_size
>
1
self
.
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
vllm_config
.
parallel_config
...
...
@@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"sinks, please use trtllm on blackwell or flash attention on "
"earlier GPUs."
)
# Preparing persistent buffers (device-side)
self
.
paged_kv_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
paged_kv_indices
=
torch
.
zeros
(
max_num_pages
,
# max num pages possible
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
self
.
paged_kv_last_page_len
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# host-side buffer
pin_memory
=
is_pin_memory_available
()
self
.
paged_kv_indptr_cpu
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_indptr_np
=
self
.
paged_kv_indptr_cpu
.
numpy
()
self
.
paged_kv_indptr_buffer
=
torch
.
zeros_like
(
self
.
paged_kv_indptr_cpu
,
pin_memory
=
pin_memory
)
self
.
paged_kv_indices_cpu
=
torch
.
zeros
(
max_num_pages
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_last_page_len_cpu
=
torch
.
zeros
(
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
paged_kv_last_page_len_np
=
self
.
paged_kv_last_page_len_cpu
.
numpy
()
# Preparing persistent buffers
self
.
pin_memory
=
is_pin_memory_available
()
self
.
paged_kv_indptr
=
self
.
_make_buffer
(
max_num_reqs
+
1
)
self
.
paged_kv_indptr_cpu_buffer
=
torch
.
zeros_like
(
self
.
paged_kv_indptr
.
cpu
,
pin_memory
=
self
.
pin_memory
)
# Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
self
.
paged_kv_indices
=
self
.
_make_buffer
(
max_num_pages
)
self
.
paged_kv_last_page_len
=
self
.
_make_buffer
(
max_num_reqs
)
if
self
.
head_dim
==
256
and
current_platform
.
is_device_capability_family
(
100
):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
...
...
@@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"passing --block-size 32 or --block-size 64."
)
def
_make_buffer
(
self
,
*
size
:
int
|
torch
.
SymInt
,
dtype
:
torch
.
dtype
=
torch
.
int32
)
->
CpuGpuBuffer
:
return
CpuGpuBuffer
(
*
size
,
dtype
=
dtype
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
with_numpy
=
True
,
)
@
override
# type: ignore[misc]
@
classmethod
def
get_cudagraph_support
(
cls
:
type
[
"FlashInferMetadataBuilder"
],
...
...
@@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
,
)
->
BatchPrefillWithPagedKVCacheWrapper
|
BatchDCPPrefillWrapper
:
if
self
.
_prefill_wrapper
is
None
:
if
self
.
dcp_world_size
>
1
:
if
self
.
use_dcp
:
self
.
_prefill_wrapper
=
BatchDCPPrefillWrapper
(
workspace_buffer
=
self
.
_get_workspace_buffer
(),
)
...
...
@@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if
decode_wrapper
is
None
:
if
use_cudagraph
:
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
batch_size
+
1
]
paged_kv_indices
=
self
.
paged_kv_indices
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
batch_size
]
paged_kv_indptr
=
self
.
paged_kv_indptr
.
gpu
[:
batch_size
+
1
]
paged_kv_indices
=
self
.
paged_kv_indices
.
gpu
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
gpu
[:
batch_size
]
else
:
paged_kv_indptr
=
None
paged_kv_indices
=
None
...
...
@@ -661,99 +718,43 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
return
self
.
_cascade_wrapper
def
build
(
def
_compute_flashinfer_kv_metadata
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
FlashInferMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
True
,
)
)
page_size
=
self
.
page_size
max_q_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
common_attn_metadata
.
max_seq_len
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
qo_indptr_cpu
=
common_attn_metadata
.
query_start_loc_cpu
if
self
.
dcp_world_size
>
1
:
if
num_prefills
>
0
:
qo_indptr_prefill_cpu
=
(
qo_indptr_cpu
[
num_decodes
:]
-
qo_indptr_cpu
[
num_decodes
]
)
query_lens_prefill_cpu
=
(
qo_indptr_prefill_cpu
[
1
:]
-
qo_indptr_prefill_cpu
[:
-
1
]
)
seq_lens_cpu
[
num_decodes
:]
=
(
seq_lens_cpu
[
num_decodes
:]
-
query_lens_prefill_cpu
)
seq_lens_cpu
=
get_dcp_local_seq_lens
(
seq_lens_cpu
,
self
.
dcp_world_size
,
self
.
dcp_rank
,
self
.
dcp_kv_cache_interleave_size
,
)
seq_lens_np
=
seq_lens_cpu
.
numpy
()
num_blocks_np
=
(
seq_lens_np
+
(
page_size
-
1
))
//
page_size
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# Grab the blocks of the shared prefix from the first request.
assert
common_prefix_len
%
page_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_common_kv_blocks
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indices_cpu
=
block_table_tensor
[
0
,
:
num_common_kv_blocks
]
shared_kv_last_page_len_cpu
=
torch
.
tensor
(
[
page_size
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
num_blocks_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table_tensor
:
torch
.
Tensor
,
num_reqs
:
int
,
page_size
:
int
,
)
->
torch
.
Tensor
:
"""
Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
attention.
# Remove the blocks of the shared prefix from all requests.
block_table_tensor
=
block_table_tensor
[:,
num_common_kv_blocks
:]
num_blocks_np
-=
num_common_kv_blocks
else
:
shared_qo_indptr_cpu
=
None
shared_kv_page_indptr_cpu
=
None
shared_kv_page_indices_cpu
=
None
shared_kv_last_page_len_cpu
=
None
Results are stored in self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len buffers.
Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
"""
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np
.
cumsum
(
num_blocks_np
,
dtype
=
np
.
int32
,
out
=
self
.
paged_kv_indptr
_
np
[
1
:
num_reqs
+
1
],
out
=
self
.
paged_kv_indptr
.
np
[
1
:
num_reqs
+
1
],
)
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition.
self
.
paged_kv_indptr_buffer
[:
num_reqs
+
1
]
=
self
.
paged_kv_indptr
_
cpu
[
self
.
paged_kv_indptr_
cpu_
buffer
[:
num_reqs
+
1
]
=
self
.
paged_kv_indptr
.
cpu
[
:
num_reqs
+
1
]
paged_kv_indptr
=
self
.
paged_kv_indptr
[:
num_reqs
+
1
]
paged_kv_indptr
=
self
.
paged_kv_indptr
.
gpu
[:
num_reqs
+
1
]
paged_kv_indptr
.
copy_
(
self
.
paged_kv_indptr_buffer
[:
num_reqs
+
1
],
non_blocking
=
True
self
.
paged_kv_indptr_
cpu_
buffer
[:
num_reqs
+
1
],
non_blocking
=
True
)
# write self.paged_kv_indices inplace
num_actual_pages
=
self
.
paged_kv_indptr
_
np
[
num_reqs
]
paged_kv_indices
=
self
.
paged_kv_indices
[:
num_actual_pages
]
num_actual_pages
=
self
.
paged_kv_indptr
.
np
[
num_reqs
]
paged_kv_indices
=
self
.
paged_kv_indices
.
gpu
[:
num_actual_pages
]
_copy_page_indices_kernel
[(
num_reqs
,)](
paged_kv_indices
,
block_table_tensor
,
...
...
@@ -764,12 +765,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np
=
seq_lens_np
%
page_size
self
.
paged_kv_last_page_len
_
np
[:
num_reqs
]
=
np
.
where
(
self
.
paged_kv_last_page_len
.
np
[:
num_reqs
]
=
np
.
where
(
(
paged_kv_last_page_len_np
==
0
)
&
(
seq_lens_np
!=
0
),
page_size
,
paged_kv_last_page_len_np
,
)
return
paged_kv_indices
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
FlashInferMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
True
,
)
)
page_size
=
self
.
page_size
max_seq_len
=
common_attn_metadata
.
max_seq_len
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
qo_indptr
=
common_attn_metadata
.
query_start_loc
qo_indptr_cpu
=
common_attn_metadata
.
query_start_loc_cpu
# Step 1: Decide which dispatch modes to use:
# - Cascade attention (distinct mode)
# - Prefill (FI native or TRTLLM)
# - Decode (FI native or TRTLLM)
use_cascade
=
common_prefix_len
>
0
uses_spec_reorder
=
self
.
reorder_batch_threshold
>
1
prefill_use_trtllm
=
use_trtllm_attention
(
self
.
num_qo_heads
,
...
...
@@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
use_trtllm_decode_attention
and
self
.
dcp_world_size
<=
1
)
if
not
(
prefill_use_trtllm
and
decode_use_trtllm
):
all_uses_trtllm
=
(
num_prefills
==
0
or
prefill_use_trtllm
)
and
(
num_decodes
==
0
or
decode_use_trtllm
)
is_only_trtllm_decode
=
num_prefills
==
0
and
(
num_decodes
>
0
and
decode_use_trtllm
)
if
not
all_uses_trtllm
:
if
self
.
has_sinks
:
raise
NotImplementedError
(
"FlashInfer backend currently does not support attention "
...
...
@@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# fall back to model dtype.
self
.
q_data_type
=
self
.
model_config
.
dtype
# Step 2: Initialize the output metadata
# Leave prefill/decode/cascade_wrapper empty, to be populated
# case by case depending on the batch contents and backend selection.
attn_metadata
=
FlashInferMetadata
(
num_actual_tokens
=
num_actual_tokens
,
q_data_type
=
self
.
q_data_type
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
max_q_len
=
max_q_len
,
max_q_len_prefill
=
max_q_len
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table_tensor
=
block_table_tensor
,
prefill_use_trtllm
=
prefill_use_trtllm
,
decode_use_trtllm
=
decode_use_trtllm
,
q_data_type
=
self
.
q_data_type
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
use_cascade
=
use_cascade
,
prefill
=
None
,
decode
=
None
,
cascade_wrapper
=
None
,
)
paged_kv_indptr_cpu
=
self
.
paged_kv_indptr_cpu
[:
1
+
num_reqs
]
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len_cpu
[:
num_reqs
]
# Guard access to seq_lens_cpu, which may not always be needed
# and can be expensive to retrieve in async mode.
needs_seq_lens_cpu
=
self
.
use_dcp
or
use_cascade
or
not
is_only_trtllm_decode
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
if
needs_seq_lens_cpu
else
None
seq_lens_np
=
seq_lens_cpu
.
numpy
()
if
seq_lens_cpu
is
not
None
else
None
num_blocks_np
=
(
(
seq_lens_np
+
(
page_size
-
1
))
//
page_size
if
seq_lens_np
is
not
None
else
None
)
# Adjust seq_lens_cpu for DCP
if
self
.
use_dcp
:
assert
seq_lens_cpu
is
not
None
if
num_prefills
>
0
:
qo_indptr_prefill_cpu
=
(
qo_indptr_cpu
[
num_decodes
:]
-
qo_indptr_cpu
[
num_decodes
]
)
query_lens_prefill_cpu
=
(
qo_indptr_prefill_cpu
[
1
:]
-
qo_indptr_prefill_cpu
[:
-
1
]
)
seq_lens_cpu
[
num_decodes
:]
=
(
seq_lens_cpu
[
num_decodes
:]
-
query_lens_prefill_cpu
)
seq_lens_cpu
=
get_dcp_local_seq_lens
(
seq_lens_cpu
,
self
.
dcp_world_size
,
self
.
dcp_rank
,
self
.
dcp_kv_cache_interleave_size
,
)
# Adjust num_block_np for cascade attention
if
use_cascade
:
assert
num_blocks_np
is
not
None
assert
common_prefix_len
%
page_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
page_size
num_blocks_np
-=
num_common_kv_blocks
# Compute paged_kv_indices if necessary
needs_paged_kv_indices
=
use_cascade
or
not
is_only_trtllm_decode
if
needs_paged_kv_indices
:
assert
num_blocks_np
is
not
None
assert
seq_lens_np
is
not
None
paged_kv_indices
=
self
.
_compute_flashinfer_kv_metadata
(
num_blocks_np
,
seq_lens_np
,
block_table_tensor
,
num_reqs
,
page_size
,
)
else
:
paged_kv_indices
=
None
# Early-out for cascade attention
if
use_cascade
:
# Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks
=
common_prefix_len
//
page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indptr_cpu
=
torch
.
tensor
(
[
0
,
num_common_kv_blocks
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
shared_kv_page_indices_cpu
=
block_table_tensor
[
0
,
:
num_common_kv_blocks
]
shared_kv_last_page_len_cpu
=
torch
.
tensor
(
[
page_size
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor
=
block_table_tensor
[:,
num_common_kv_blocks
:]
num_blocks_np
-=
num_common_kv_blocks
assert
paged_kv_indices
is
not
None
paged_kv_indptr_cpu
=
self
.
paged_kv_indptr
.
cpu
[:
1
+
num_reqs
]
paged_kv_last_page_len_cpu
=
self
.
paged_kv_last_page_len
.
cpu
[:
num_reqs
]
if
attn_metadata
.
use_cascade
:
attn_metadata
.
cascade_wrapper
=
self
.
_get_cascade_wrapper
()
attn_metadata
.
cascade_wrapper
.
plan
(
[
shared_qo_indptr_cpu
,
qo_indptr_cpu
],
...
...
@@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
)
else
:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back.
num_prefills
=
attn_metadata
.
num_prefills
num_decodes
=
attn_metadata
.
num_decodes
if
num_prefills
>
0
:
# Decodes are first so prefills start after the last decode
prefill_start
=
num_decodes
attn_metadata
.
prefill_wrapper
=
self
.
_get_prefill_wrapper
()
assert
qo_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
paged_kv_indptr_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
+
1
assert
(
paged_kv_last_page_len_cpu
[
prefill_start
:].
shape
[
0
]
==
num_prefills
return
attn_metadata
# Step 3: Handle prefill and decode pathways case by case
## PREFILL PATHWAY
if
num_prefills
>
0
:
# Slices for shared prefill metadata
prefill_start
=
num_decodes
qo_indptr_prefill_cpu
=
(
qo_indptr_cpu
[
prefill_start
:]
-
qo_indptr_cpu
[
prefill_start
]
)
assert
qo_indptr_prefill_cpu
.
shape
[
0
]
==
num_prefills
+
1
if
prefill_use_trtllm
:
# Create GPU versions
qo_indptr_prefill_gpu
=
(
qo_indptr
[
prefill_start
:]
-
qo_indptr
[
prefill_start
]
)
paged_kv_indptr_prefill_gpu
=
self
.
paged_kv_indptr
.
gpu
[
prefill_start
:
num_reqs
+
1
]
# Compute max_q_len for prefill requests
query_lens_prefill_cpu
=
(
qo_indptr_prefill_cpu
[
1
:]
-
qo_indptr_prefill_cpu
[:
-
1
]
)
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr_cpu
=
(
qo_indptr_cpu
[
prefill_start
:]
-
qo_indptr_cpu
[
prefill_start
]
max_q_len_prefill
=
int
(
query_lens_prefill_cpu
.
max
().
item
())
attn_metadata
.
prefill
=
TRTLLMPrefill
(
block_tables
=
block_table_tensor
[
prefill_start
:],
seq_lens
=
seq_lens
[
prefill_start
:],
cum_seq_lens_q
=
qo_indptr_prefill_gpu
,
cum_seq_lens_kv
=
paged_kv_indptr_prefill_gpu
,
max_q_len
=
max_q_len_prefill
,
max_seq_len
=
max_seq_len
,
)
paged_kv_indptr_cpu
=
paged_kv_indptr_cpu
[
prefill_start
:]
# Recompute max_q_len for the slice of requests we are using
# for prefills. This can be different from max_q_len when
# we have a non-uniform batch with some short decodes offloaded
# to the prefill pathway
query_lens_prefill
=
qo_indptr_cpu
[
1
:]
-
qo_indptr_cpu
[:
-
1
]
attn_metadata
.
max_q_len_prefill
=
int
(
query_lens_prefill
.
max
().
item
())
if
not
attn_metadata
.
prefill_use_trtllm
:
if
self
.
dcp_world_size
>
1
:
assert
isinstance
(
attn_metadata
.
prefill_wrapper
,
BatchDCPPrefillWrapper
)
attn_metadata
.
prefill_wrapper
.
plan
(
qo_indptr_cpu
=
qo_indptr_cpu
,
paged_kv_indptr_cpu
=
paged_kv_indptr_cpu
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len_cpu
=
paged_kv_last_page_len_cpu
,
prefill_start
=
prefill_start
,
page_size
=
self
.
page_size
,
num_qo_heads
=
self
.
num_qo_heads
,
dcp_world_size
=
self
.
dcp_world_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_dim
=
self
.
head_dim
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
prefill_fixed_split_size
=
self
.
prefill_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
else
:
assert
isinstance
(
attn_metadata
.
prefill_wrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
)
attn_metadata
.
prefill_wrapper
.
plan
(
qo_indptr_cpu
,
paged_kv_indptr_cpu
,
paged_kv_indices
,
paged_kv_last_page_len_cpu
[
prefill_start
:],
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
fixed_split_size
=
self
.
prefill_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
else
:
prefill_wrapper
=
self
.
_get_prefill_wrapper
()
# Slicing CPU buffers that are only needed for FI native prefills
paged_kv_last_page_len_prefill_cpu
=
self
.
paged_kv_last_page_len
.
cpu
[
prefill_start
:
num_reqs
]
assert
paged_kv_last_page_len_prefill_cpu
.
shape
[
0
]
==
num_prefills
paged_kv_indptr_prefill_cpu
=
self
.
paged_kv_indptr
.
cpu
[
prefill_start
:
num_reqs
+
1
]
assert
paged_kv_indptr_prefill_cpu
.
shape
[
0
]
==
num_prefills
+
1
if
self
.
use_dcp
:
assert
isinstance
(
prefill_wrapper
,
BatchDCPPrefillWrapper
)
prefill_wrapper
.
plan
(
qo_indptr_cpu
=
qo_indptr_prefill_cpu
,
paged_kv_indptr_cpu
=
paged_kv_indptr_prefill_cpu
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len_cpu
=
paged_kv_last_page_len_prefill_cpu
,
page_size
=
self
.
page_size
,
num_qo_heads
=
self
.
num_qo_heads
,
dcp_world_size
=
self
.
dcp_world_size
,
num_kv_heads
=
self
.
num_kv_heads
,
head_dim
=
self
.
head_dim
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
prefill_fixed_split_size
=
self
.
prefill_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
else
:
attn_metadata
.
qo_indptr_gpu
=
qo_indptr_cpu
.
to
(
self
.
device
,
non_blocking
=
True
assert
isinstance
(
prefill_wrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
)
attn_metadata
.
paged_kv_indptr_gpu
=
paged_kv_indptr_cpu
.
to
(
self
.
device
,
non_blocking
=
True
prefill_wrapper
.
plan
(
qo_indptr_prefill_cpu
,
paged_kv_indptr_prefill_cpu
,
paged_kv_indices
,
paged_kv_last_page_len_prefill_cpu
,
self
.
num_qo_heads
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
causal
=
True
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv_data_type
=
self
.
kv_cache_dtype
,
fixed_split_size
=
self
.
prefill_fixed_split_size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
attn_metadata
.
prefill
=
FIPrefill
(
wrapper
=
prefill_wrapper
)
if
num_decodes
>
0
:
## DECODE PATHWAY
if
num_decodes
>
0
:
if
decode_use_trtllm
:
assert
num_decode_tokens
%
num_decodes
==
0
,
(
"TRTLLM decode requires uniform query lengths per request."
)
attn_metadata
.
decode
=
TRTLLMDecode
(
block_tables
=
block_table_tensor
[:
num_decodes
],
seq_lens
=
seq_lens
[:
num_decodes
],
max_seq_len
=
max_seq_len
,
)
else
:
pure_decode
=
num_prefills
==
0
use_cudagraph
=
(
self
.
enable_cuda_graph
...
...
@@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
num_input_tokens
=
num_decode_tokens
attn_metadata
.
decode_wrapper
=
self
.
_get_decode_wrapper
(
decode_wrapper
=
self
.
_get_decode_wrapper
(
num_input_tokens
,
use_cudagraph
)
if
not
attn_metadata
.
decode_use_trtllm
:
# Use the persistent buffer with padding length,
# instead of the sa
me
ad
dress but chunked version
# in atten_metadata when using cudagraph.
fast_plan_decode
(
attn_metadata
.
decode_wrapper
,
self
.
paged_kv_ind
ptr_cpu
[:
num_input_tokens
+
1
]
,
paged_kv_
indices
,
self
.
paged_kv_last_page
_len_cpu
[:
num_input_tokens
],
seq_lens_cpu
[:
num_input_tokens
]
,
self
.
num_
qo
_heads
*
self
.
dcp_world_size
,
self
.
num_kv_heads
,
self
.
head_dim
,
self
.
page_size
,
# Disable flashinfer's
pos
encoding
and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q
_data_type
=
self
.
q_data_
type
,
kv_data_type
=
self
.
kv_cache_dtyp
e
,
fixed
_split_
size
=
self
.
d
ecode_fixed
_split_
size
,
disable_split_kv
=
self
.
disable_split_kv
,
)
# Use the persistent buffer with padding length,
# instead of the same address but chunked version
# in atten_
me
t
ad
ata when using cudagraph.
fast_plan_decode
(
decode_wrapper
,
self
.
paged_kv_indptr
.
cpu
[:
num_input_tokens
+
1
]
,
paged_kv_ind
ices
,
self
.
paged_kv_
last_page_len
.
cpu
[:
num_input_tokens
]
,
seq
_len
s
_cpu
[:
num_input_tokens
],
self
.
num_qo_heads
*
self
.
dcp_world_size
,
self
.
num_
kv
_heads
,
self
.
head_dim
,
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos
_
encoding
_mode
=
"NONE"
,
sm_scale
=
self
.
sm_scale
,
window_left
=
self
.
window_left
,
logits_soft_cap
=
self
.
logits_soft_cap
,
q_data_type
=
self
.
q_data_type
,
kv
_data_type
=
self
.
kv_cache_d
type
,
fixed_split_size
=
self
.
decode_fixed_split_siz
e
,
disable
_split_
kv
=
self
.
d
isable
_split_
kv
,
)
attn_metadata
.
decode
=
FIDecode
(
wrapper
=
decode_wrapper
)
return
attn_metadata
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
...
...
@@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
if
self
.
bmm2_scale
is
None
:
self
.
bmm2_scale
=
layer
.
_v_scale_float
prefill_use_trtllm
=
isinstance
(
attn_metadata
.
prefill
,
TRTLLMPrefill
)
decode_use_trtllm
=
isinstance
(
attn_metadata
.
decode
,
TRTLLMDecode
)
# The attn+quant fusion happens when output_scale is provided.
if
output_scale
is
None
:
assert
output_block_scale
is
None
,
(
...
...
@@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
assert
attn_metadata
.
q_data_type
==
FP8_DTYPE
,
(
"Query must be FP8 when attn+quant fusion happened."
)
assert
(
attn_metadata
.
prefill_use_trtllm
and
attn_metadata
.
decode_use_trtllm
assert
(
attn_metadata
.
num_prefills
==
0
or
prefill_use_trtllm
)
and
(
attn_metadata
.
num_decodes
==
0
or
decode_use_trtllm
),
"Must use TRT-LLM attn"
if
output
.
dtype
==
FP8_DTYPE
:
...
...
@@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
# When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token.
num_decodes
=
attn_metadata
.
num_decodes
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
kv_cache_permute
=
kv_cache
.
permute
(
*
stride_order
)
use_dcp
=
self
.
dcp_world_size
>
1
# Regular attention (common case).
# Decodes are at the front and prefills are at the back.
if
num_prefill_tokens
>
0
:
prefill_wrapper
=
attn_metadata
.
prefill_wrapper
prefill_query
=
query
[
num_decode_tokens
:]
assert
prefill_query
.
shape
[
0
]
==
num_prefill_tokens
assert
prefill_wrapper
is
not
None
if
not
attn_metadata
.
prefill_use_trtllm
:
if
self
.
dcp_world_size
>
1
:
if
not
prefill_use_trtllm
:
assert
isinstance
(
attn_metadata
.
prefill
,
FIPrefill
)
prefill_wrapper
=
attn_metadata
.
prefill
.
wrapper
assert
prefill_wrapper
is
not
None
if
use_dcp
:
assert
isinstance
(
prefill_wrapper
,
BatchDCPPrefillWrapper
)
assert
prefill_wrapper
.
_context
.
_window_left
==
self
.
window_left
assert
prefill_wrapper
.
_context
.
_logits_soft_cap
==
(
...
...
@@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
out
=
output
[
num_decode_tokens
:],
)
else
:
assert
isinstance
(
attn_metadata
.
prefill
,
TRTLLMPrefill
)
# prefill_query may be non-contiguous
prefill_query
=
prefill_query
.
contiguous
()
workspace_buffer
=
_get_trtllm_gen_workspace_buffer
()
block_tables_prefill
=
attn_metadata
.
block_table
_tensor
[
num_decodes
:]
seq_lens_prefill
=
attn_metadata
.
seq_lens
[
num_decodes
:]
block_tables_prefill
=
attn_metadata
.
prefill
.
block_table
s
seq_lens_prefill
=
attn_metadata
.
prefill
.
seq_lens
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert
get_kv_cache_layout
()
==
"HND"
...
...
@@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer
=
workspace_buffer
,
block_tables
=
mock_block_table
,
seq_lens
=
seq_lens_prefill
,
max_q_len
=
attn_metadata
.
max_q_len
_prefill
,
max_kv_len
=
attn_metadata
.
max_seq_len
,
max_q_len
=
attn_metadata
.
prefill
.
max_q_len
,
max_kv_len
=
attn_metadata
.
prefill
.
max_seq_len
,
bmm1_scale
=
self
.
bmm1_scale
,
bmm2_scale
=
self
.
bmm2_scale
,
batch_size
=
attn_metadata
.
num_prefills
,
cum_seq_lens_q
=
attn_metadata
.
qo_indptr_gpu
,
cum_seq_lens_kv
=
attn_metadata
.
p
aged_kv_indptr_gpu
,
cum_seq_lens_q
=
attn_metadata
.
prefill
.
cum_seq_lens_q
,
cum_seq_lens_kv
=
attn_metadata
.
p
refill
.
cum_seq_lens_kv
,
window_left
=
self
.
window_left
,
sinks
=
self
.
sinks
,
o_sf_scale
=
self
.
o_sf_scale
,
...
...
@@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
)
if
num_decode_tokens
>
0
:
decode_wrapper
=
attn_metadata
.
decode_wrapper
decode_query
=
query
[:
num_decode_tokens
]
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_wrapper
is
not
None
if
not
attn_metadata
.
decode_use_trtllm
:
if
not
decode_use_trtllm
:
assert
isinstance
(
attn_metadata
.
decode
,
FIDecode
)
decode_wrapper
=
attn_metadata
.
decode
.
wrapper
assert
decode_wrapper
is
not
None
assert
decode_wrapper
.
_window_left
==
self
.
window_left
assert
decode_wrapper
.
_logits_soft_cap
==
(
self
.
logits_soft_cap
or
0.0
)
assert
decode_wrapper
.
_sm_scale
==
self
.
scale
if
se
lf
.
dcp_world_size
>
1
:
if
u
se
_dcp
:
decode_query
=
get_dcp_group
().
all_gather
(
decode_query
.
contiguous
(),
dim
=-
2
)
...
...
@@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
)
else
:
# decode_query may be non-contiguous
assert
isinstance
(
attn_metadata
.
decode
,
TRTLLMDecode
)
decode_query
=
decode_query
.
contiguous
()
workspace_buffer
=
_get_trtllm_gen_workspace_buffer
()
block_tables_decode
=
attn_metadata
.
block_table_tensor
[
:
num_decode_tokens
]
seq_lens_decode
=
attn_metadata
.
seq_lens
[:
num_decode_tokens
]
block_tables_decode
=
attn_metadata
.
decode
.
block_tables
seq_lens_decode
=
attn_metadata
.
decode
.
seq_lens
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert
get_kv_cache_layout
()
==
"HND"
...
...
@@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables_decode
,
seq_lens
=
seq_lens_decode
,
max_seq_len
=
attn_metadata
.
max_seq_len
,
max_seq_len
=
attn_metadata
.
decode
.
max_seq_len
,
bmm1_scale
=
self
.
bmm1_scale
,
bmm2_scale
=
self
.
bmm2_scale
,
window_left
=
self
.
window_left
,
...
...
vllm/v1/attention/backends/mamba2_attn.py
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
itertools
from
dataclasses
import
dataclass
...
...
@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
class
Mamba2AttentionMetadataBuilder
(
BaseMambaAttentionMetadataBuilder
[
Mamba2AttentionMetadata
]
):
supports_update_block_table
:
bool
=
True
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
...
...
@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_p
=
num_computed_tokens_p
,
)
return
attn_metadata
def
update_block_table
(
self
,
metadata
:
Mamba2AttentionMetadata
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
Mamba2AttentionMetadata
:
new_metadata
=
copy
.
copy
(
metadata
)
prefix_caching
=
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
state_indices_t
=
blk_table
if
prefix_caching
else
blk_table
[:,
0
]
num_reqs
=
blk_table
.
shape
[
0
]
# For CUDA graphs, copy to persistent buffer
if
(
metadata
.
num_prefills
==
0
and
num_reqs
<=
self
.
decode_cudagraph_max_bs
and
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
()
):
persistent_state_indices_t
=
self
.
state_indices_tensor
[:
num_reqs
]
persistent_state_indices_t
.
copy_
(
state_indices_t
,
non_blocking
=
True
)
state_indices_t
=
persistent_state_indices_t
new_metadata
.
state_indices_tensor
=
state_indices_t
return
new_metadata
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
a810671a
...
...
@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
QueryLenSupport
,
)
from
vllm.v1.attention.backends.utils
import
AttentionCGSupport
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
qo_indptr
:
torch
.
Tensor
|
None
=
None
# The dtype of MLA out tensor
attn_out_dtype
:
torch
.
dtype
=
torch
.
bfloat16
# The max query output length: int
max_qo_len
:
int
|
None
=
None
class
AiterMLAMetadata
(
MLACommonMetadata
[
AiterMLADecodeMetadata
]):
...
...
@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
# TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
)
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
AttentionCGSupport
.
UNIFORM_BATCH
query_len_support
:
ClassVar
[
QueryLenSupport
]
=
QueryLenSupport
.
UNIFORM
def
__init__
(
self
,
...
...
@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
self
.
qo_indptr
=
torch
.
arange
(
0
,
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
self
.
qo_indptr
=
torch
.
zeros
(
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
def
_build_decode
(
...
...
@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens_device
.
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
),
]
)
qo_len
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
max_qo_len
=
qo_len
.
max
().
item
()
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
num_actual_pages
=
paged_kv_indices
.
size
(
0
)
...
...
@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self
.
paged_kv_last_page_len
[
num_reqs
:].
fill_
(
1
)
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
[:
num_reqs
]
self
.
qo_indptr
[:
1
+
num_reqs
].
copy_
(
query_start_loc_device
,
non_blocking
=
True
)
self
.
qo_indptr
[
1
+
num_reqs
:]
=
query_start_loc_device
[
-
1
]
qo_indptr
=
self
.
qo_indptr
[:
1
+
num_reqs
]
else
:
...
...
@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len
=
paged_kv_last_page_len
,
qo_indptr
=
qo_indptr
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
max_qo_len
=
max_qo_len
,
attn_out_dtype
=
self
.
decode_attn_out_dtype
,
)
...
...
@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo
=
1
rocm_aiter_ops
.
mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
attn_metadata
.
decode
.
qo_indptr
,
max_seqlen_qo
,
attn_metadata
.
decode
.
max_qo_len
,
attn_metadata
.
decode
.
paged_kv_indptr
,
attn_metadata
.
decode
.
paged_kv_indices
,
attn_metadata
.
decode
.
paged_kv_last_page_len
,
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
a810671a
...
...
@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class
RocmAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
,
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
...
...
@@ -165,7 +169,7 @@ class RocmAttentionBackend(AttentionBackend):
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
f
"Supported head sizes are:
{
cls
.
get_supported_head_sizes
()
}
. "
"Set --attention-
config.
backend=FLEX_ATTENTION to use "
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
...
...
vllm/v1/attention/backends/utils.py
View file @
a810671a
...
...
@@ -4,6 +4,7 @@ import abc
import
enum
import
functools
from
abc
import
abstractmethod
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
,
field
,
fields
,
make_dataclass
from
typing
import
(
TYPE_CHECKING
,
...
...
@@ -201,10 +202,11 @@ def _make_metadata_with_slice(
)
# NOTE: last token can be outside of the last request if we have CG padding.
# If the "middle" request has tokens in both ubatches, we have to split it.
# If ubatch_slice is the first ubatch then we will be splitting the last
# request. If it's the second microbatch, then we will be splitting the
# first request
# If the request is split across ubatches, we have to adjust the metadata.
# splits_first_request: The first request in this slice is the continuation of
# a request that started in a previous slice.
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request
=
first_tok
>
start_locs
[
first_req
]
splits_last_request
=
last_tok
<
start_locs
[
last_req
+
1
]
-
1
...
...
@@ -225,7 +227,10 @@ def _make_metadata_with_slice(
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
if
splits_last_request
:
tokens_skipped
=
query_start_loc_cpu
[
-
1
]
-
token_slice
.
stop
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped
=
start_locs
[
last_req
+
1
]
-
token_slice
.
stop
query_start_loc
[
-
1
]
-=
tokens_skipped
query_start_loc_cpu
[
-
1
]
-=
tokens_skipped
...
...
@@ -313,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold
:
int
|
None
=
None
# Does this backend/builder support updating the block table in existing
# metadata
supports_update_block_table
:
bool
=
False
@
abstractmethod
def
__init__
(
...
...
@@ -383,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise
NotImplementedError
def
update_block_table
(
self
,
metadata
:
M
,
blk_table
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
M
:
"""
Update the block table for the attention metadata.
Faster when theres multiple kv-cache groups that create virtually the
same metadata but just with different block tables.
Only needs to be implemented if supports_update_block_table is True.
"""
raise
NotImplementedError
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
...
...
@@ -599,7 +622,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
block_size
:
int
=
0
,
)
->
CommonAttentionMetadata
:
)
->
tuple
[
CommonAttentionMetadata
,
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]]
:
query_start_loc_np
=
common_attn_metadata
.
query_start_loc_cpu
.
numpy
()
seq_lens_np
=
common_attn_metadata
.
seq_lens_cpu
.
numpy
()
block_table
=
common_attn_metadata
.
block_table_tensor
...
...
@@ -711,9 +734,12 @@ def make_local_attention_virtual_batches(
# tensor first, which recovers perf.
batch_indices_torch
=
torch
.
from_numpy
(
batch_indices
)
block_indices_torch
=
torch
.
from_numpy
(
block_indices
)
block_table_local
=
block_table
[
batch_indices_torch
,
block_indices_torch
].
view
(
virtual_batches
,
-
1
)
# Save as a lambda so we can return this for update_block_table
make_block_table
=
lambda
block_table
:
block_table
[
batch_indices_torch
,
block_indices_torch
].
view
(
virtual_batches
,
-
1
)
block_table_local
=
make_block_table
(
block_table
)
query_start_loc_cpu
=
torch
.
from_numpy
(
cu_seqlens_q_local
)
seq_lens_cpu
=
torch
.
from_numpy
(
seqlens_k_local
)
...
...
@@ -732,7 +758,7 @@ def make_local_attention_virtual_batches(
causal
=
True
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
torch
.
from_numpy
(
num_computed_tokens_local
),
)
)
,
make_block_table
def
make_kv_sharing_fast_prefill_common_attn_metadata
(
...
...
vllm/v1/core/sched/scheduler.py
View file @
a810671a
...
...
@@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from
vllm.v1.core.sched.utils
import
check_stop
,
remove_all
from
vllm.v1.engine
import
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.perf
import
ModelMetrics
,
PerfStats
from
vllm.v1.metrics.stats
import
(
PrefixCacheStats
,
SchedulerStats
,
...
...
@@ -187,6 +188,12 @@ class Scheduler(SchedulerInterface):
if
self
.
is_encoder_decoder
else
EncoderCacheManager
(
cache_size
=
encoder_cache_size
)
)
# For encoder-decoder models, allocate the maximum number of tokens for Cross
# Attn blocks, as for Whisper its input is always padded to the maximum length.
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
self
.
_num_encoder_max_input_tokens
=
(
MULTIMODAL_REGISTRY
.
get_encdec_max_encoder_len
(
vllm_config
.
model_config
)
)
speculative_config
=
vllm_config
.
speculative_config
self
.
use_eagle
=
False
...
...
@@ -213,6 +220,10 @@ class Scheduler(SchedulerInterface):
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
self
.
use_v2_model_runner
=
envs
.
VLLM_USE_V2_MODEL_RUNNER
self
.
perf_metrics
:
ModelMetrics
|
None
=
None
if
self
.
log_stats
and
vllm_config
.
observability_config
.
enable_mfu_metrics
:
self
.
perf_metrics
=
ModelMetrics
(
vllm_config
)
def
schedule
(
self
)
->
SchedulerOutput
:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
...
...
@@ -568,17 +579,11 @@ class Scheduler(SchedulerInterface):
0
if
request
.
num_computed_tokens
==
0
else
self
.
num_lookahead_tokens
)
# Determine if we need to allocate cross-attention blocks.
if
self
.
is_encoder_decoder
and
request
.
has_encoder_inputs
:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens
=
(
self
.
scheduler_config
.
max_num_encoder_input_tokens
)
else
:
num_encoder_tokens
=
0
num_encoder_tokens
=
(
self
.
_num_encoder_max_input_tokens
if
self
.
is_encoder_decoder
and
request
.
has_encoder_inputs
else
0
)
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
request
,
...
...
@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface):
kv_connector_output
=
model_runner_output
.
kv_connector_output
cudagraph_stats
=
model_runner_output
.
cudagraph_stats
perf_stats
:
PerfStats
|
None
=
None
if
self
.
perf_metrics
and
self
.
perf_metrics
.
is_enabled
():
perf_stats
=
self
.
perf_metrics
.
get_step_perf_stats_per_gpu
(
scheduler_output
)
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
kv_connector_stats
:
KVConnectorStats
|
None
=
(
...
...
@@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface):
if
(
stats
:
=
self
.
make_stats
(
spec_decoding_stats
,
kv_connector_stats
,
cudagraph_stats
spec_decoding_stats
,
kv_connector_stats
,
cudagraph_stats
,
perf_stats
)
)
is
not
None
:
# Return stats to only one of the front-ends.
...
...
@@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
,
kv_connector_stats
:
KVConnectorStats
|
None
=
None
,
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
,
perf_stats
:
PerfStats
|
None
=
None
,
)
->
SchedulerStats
|
None
:
if
not
self
.
log_stats
:
return
None
...
...
@@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats
=
spec_stats
,
kv_connector_stats
=
connector_stats_payload
,
cudagraph_stats
=
cudagraph_stats
,
perf_stats
=
perf_stats
,
)
def
make_spec_decoding_stats
(
...
...
vllm/v1/engine/core.py
View file @
a810671a
...
...
@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import (
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
FinishReason
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
UtilityOutput
,
...
...
@@ -923,6 +925,13 @@ class EngineCoreProc(EngineCore):
# Post-step hook.
self
.
post_step
(
model_executed
)
# If no model execution happened but there are waiting requests
# (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
# background threads (like NIXL handshake) to make progress.
# Without this, the tight polling loop can starve background threads.
if
not
model_executed
and
self
.
scheduler
.
has_unfinished_requests
():
time
.
sleep
(
0.001
)
return
model_executed
def
_handle_client_request
(
...
...
@@ -1048,9 +1057,14 @@ class EngineCoreProc(EngineCore):
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
request
:
Any
if
request_type
==
EngineCoreRequestType
.
ADD
:
request
=
add_request_decoder
.
decode
(
data_frames
)
request
=
self
.
preprocess_add_request
(
request
)
req
:
EngineCoreRequest
=
add_request_decoder
.
decode
(
data_frames
)
try
:
request
=
self
.
preprocess_add_request
(
req
)
except
Exception
:
self
.
_handle_request_preproc_error
(
req
)
continue
else
:
request
=
generic_decoder
.
decode
(
data_frames
)
...
...
@@ -1134,6 +1148,30 @@ class EngineCoreProc(EngineCore):
# Limit the number of buffers to reuse.
reuse_buffers
.
append
(
buffer
)
def
_handle_request_preproc_error
(
self
,
request
:
EngineCoreRequest
)
->
None
:
"""Log and return a request-scoped error response for exceptions raised
from the add request preprocessing in the input socket processing thread.
"""
logger
.
exception
(
"Unexpected error pre-processing request %s"
,
request
.
request_id
)
self
.
output_queue
.
put_nowait
(
(
request
.
client_index
,
EngineCoreOutputs
(
engine_index
=
self
.
engine_index
,
finished_requests
=
{
request
.
request_id
},
outputs
=
[
EngineCoreOutput
(
request_id
=
request
.
request_id
,
new_token_ids
=
[],
finish_reason
=
FinishReason
.
ERROR
,
)
],
),
)
)
class
DPEngineCoreProc
(
EngineCoreProc
):
"""ZMQ-wrapper for running EngineCore in background process
...
...
vllm/v1/engine/core_client.py
View file @
a810671a
...
...
@@ -269,7 +269,8 @@ class InprocClient(EngineCoreClient):
self
.
engine_core
=
EngineCore
(
*
args
,
**
kwargs
)
def
get_output
(
self
)
->
EngineCoreOutputs
:
outputs
,
_
=
self
.
engine_core
.
step_fn
()
outputs
,
model_executed
=
self
.
engine_core
.
step_fn
()
self
.
engine_core
.
post_step
(
model_executed
=
model_executed
)
return
outputs
and
outputs
.
get
(
0
)
or
EngineCoreOutputs
()
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
...
...
vllm/v1/engine/input_processor.py
View file @
a810671a
...
...
@@ -24,7 +24,10 @@ from vllm.tokenizers.mistral import MistralTokenizer
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.metrics.stats
import
MultiModalCacheStats
from
vllm.v1.structured_output.backend_guidance
import
validate_guidance_grammar
from
vllm.v1.structured_output.backend_guidance
import
(
has_guidance_unsupported_json_features
,
validate_guidance_grammar
,
)
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
validate_structured_output_request_lm_format_enforcer
,
)
...
...
@@ -340,8 +343,22 @@ class InputProcessor:
# The request either failed validation
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
if
isinstance
(
self
.
tokenizer
,
MistralTokenizer
):
# Check if schema has features unsupported by guidance
so_params
=
params
.
structured_outputs
skip_guidance
=
False
if
so_params
.
json
:
if
isinstance
(
so_params
.
json
,
str
):
import
json
schema
=
json
.
loads
(
so_params
.
json
)
else
:
schema
=
so_params
.
json
skip_guidance
=
has_guidance_unsupported_json_features
(
schema
)
if
isinstance
(
self
.
tokenizer
,
MistralTokenizer
)
or
skip_guidance
:
# Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
validate_structured_output_request_outlines
(
params
)
params
.
structured_outputs
.
_backend
=
"outlines"
else
:
...
...
vllm/v1/engine/output_processor.py
View file @
a810671a
...
...
@@ -8,6 +8,7 @@ from typing import Any, cast
import
torch
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
(
CompletionOutput
,
PoolingOutput
,
...
...
@@ -93,7 +94,7 @@ class RequestState:
request_id
:
str
,
parent_req
:
ParentRequest
|
None
,
request_index
:
int
,
lora_
name
:
st
r
|
None
,
lora_
request
:
LoRAReque
st
|
None
,
output_kind
:
RequestOutputKind
,
prompt
:
str
|
None
,
prompt_token_ids
:
list
[
int
]
|
None
,
...
...
@@ -112,7 +113,8 @@ class RequestState:
self
.
request_id
=
request_id
self
.
parent_req
=
parent_req
self
.
request_index
=
request_index
self
.
lora_name
=
lora_name
self
.
lora_request
=
lora_request
self
.
lora_name
=
lora_request
.
lora_name
if
lora_request
is
not
None
else
None
self
.
output_kind
=
output_kind
self
.
prompt
=
prompt
self
.
prompt_token_ids
=
prompt_token_ids
...
...
@@ -178,9 +180,7 @@ class RequestState:
request_id
=
request
.
request_id
,
parent_req
=
parent_req
,
request_index
=
request_index
,
lora_name
=
(
request
.
lora_request
.
name
if
request
.
lora_request
is
not
None
else
None
),
lora_request
=
request
.
lora_request
,
output_kind
=
output_kind
,
prompt
=
prompt
,
prompt_token_ids
=
request
.
prompt_token_ids
,
...
...
@@ -289,6 +289,7 @@ class RequestState:
return
RequestOutput
(
request_id
=
request_id
,
lora_request
=
self
.
lora_request
,
prompt
=
self
.
prompt
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
prompt_logprobs
,
...
...
vllm/v1/metrics/loggers.py
View file @
a810671a
...
...
@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
from
vllm.logger
import
init_logger
from
vllm.plugins
import
STAT_LOGGER_PLUGINS_GROUP
,
load_plugins_by_group
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.perf
import
PerfMetricsLogging
from
vllm.v1.metrics.prometheus
import
unregister_vllm_metrics
from
vllm.v1.metrics.stats
import
(
CachingMetrics
,
...
...
@@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
engine_is_idle
=
False
self
.
aggregated
=
False
if
self
.
_enable_perf_stats
():
self
.
perf_metrics_logging
=
PerfMetricsLogging
(
vllm_config
)
def
_reset
(
self
,
now
):
self
.
last_log_time
=
now
...
...
@@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase):
self
.
num_corrupted_reqs
:
int
=
0
self
.
num_preemptions
:
int
=
0
def
_enable_perf_stats
(
self
)
->
bool
:
return
self
.
vllm_config
.
observability_config
.
enable_mfu_metrics
def
_track_iteration_stats
(
self
,
iteration_stats
:
IterationStats
):
# Save tracked stats for token counters.
self
.
num_prompt_tokens
+=
iteration_stats
.
num_prompt_tokens
...
...
@@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase):
self
.
cudagraph_logging
.
observe
(
scheduler_stats
.
cudagraph_stats
)
if
not
self
.
aggregated
:
self
.
last_scheduler_stats
=
scheduler_stats
if
(
perf_stats
:
=
scheduler_stats
.
perf_stats
)
and
self
.
_enable_perf_stats
():
self
.
perf_metrics_logging
.
observe
(
perf_stats
)
if
mm_cache_stats
:
self
.
mm_caching_metrics
.
observe
(
mm_cache_stats
)
...
...
@@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase):
"Running: %d reqs"
,
"Waiting: %d reqs"
,
]
log_args
=
[
log_args
:
list
[
int
|
float
|
str
]
=
[
self
.
last_prompt_throughput
,
self
.
last_generation_throughput
,
self
.
last_scheduler_stats
.
num_running_reqs
,
...
...
@@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase):
self
.
kv_connector_logging
.
log
(
log_fn
=
log_fn
)
if
self
.
cudagraph_logging
is
not
None
:
self
.
cudagraph_logging
.
log
(
log_fn
=
log_fn
)
if
self
.
_enable_perf_stats
():
self
.
perf_metrics_logging
.
log
(
log_fn
=
log_fn
,
log_prefix
=
self
.
log_prefix
)
def
log_engine_initialized
(
self
):
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
...
...
@@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
def
log_prefix
(
self
):
return
"{} Engines Aggregated: "
.
format
(
len
(
self
.
engine_indexes
))
def
_enable_perf_stats
(
self
)
->
bool
:
# Adding per_gpu perf stats across engines can lead to misleading numbers.
return
False
def
record
(
self
,
scheduler_stats
:
SchedulerStats
|
None
,
...
...
vllm/v1/metrics/perf.py
0 → 100644
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Analytic flops/memory estimation module for transformer components,
to help derive MFU (Model Flops Utilization) stats for a running model.
"""
import
json
import
time
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Any
,
Protocol
import
torch
from
pydantic
import
BaseModel
,
Field
,
ValidationError
,
model_validator
from
typing_extensions
import
Self
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
,
get_kv_cache_torch_dtype
,
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
logger
=
init_logger
(
__name__
)
class
InvalidComponent
(
Exception
):
"""
Custom exception to indicate that a certain ComponentMetric is not
applicable to the given VllmConfig.
"""
pass
#### Basic Data Types ####
@
dataclass
class
DebugPerfStats
:
## Stats for debugging the metrics calculation
calc_duration
:
float
=
0.0
# time spent calculating these stats
num_prefill_requests
:
int
=
0
num_decode_requests
:
int
=
0
context_breakdown
:
dict
[
str
,
int
]
|
None
=
None
num_flops_per_gpu_breakdown
:
dict
[
str
,
int
]
|
None
=
None
num_read_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
|
None
=
None
num_write_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
|
None
=
None
@
dataclass
class
PerfStats
:
num_flops_per_gpu
:
int
=
0
num_read_bytes_per_gpu
:
int
=
0
num_write_bytes_per_gpu
:
int
=
0
debug_stats
:
DebugPerfStats
|
None
=
None
@
dataclass
class
ExecutionContext
:
"""
Represents an execution context for a batch of requests.
This class aggregates statistics across multiple requests in a batch,
separately tracking prefill and decode phases.
Example)
- Batch with one full prefill (2048 tokens) and one decode (1 token, 8192 context):
ctx = ExecutionContext()
ctx.add(2048, 2048, is_prefill=True)
ctx.add(1, 8192, is_prefill=False)
"""
# Prefill phase statistics
num_prefill_requests
:
int
=
0
prefill_num_tokens
:
int
=
0
# sum of num_tokens for prefill requests
prefill_context_len
:
int
=
0
# sum of context_len for prefill requests
prefill_token_context_product
:
int
=
0
# sum of (num_tokens * context_len)
# Decode phase statistics
num_decode_requests
:
int
=
0
decode_num_tokens
:
int
=
0
# sum of num_tokens for decode requests
decode_context_len
:
int
=
0
# sum of context_len for decode requests
decode_token_context_product
:
int
=
0
# sum of (num_tokens * context_len)
def
add
(
self
,
num_tokens
:
int
,
context_len
:
int
,
is_prefill
:
bool
)
->
None
:
"""Add a single request's statistics to this batch context."""
if
is_prefill
:
self
.
num_prefill_requests
+=
1
self
.
prefill_num_tokens
+=
num_tokens
self
.
prefill_context_len
+=
context_len
self
.
prefill_token_context_product
+=
num_tokens
*
context_len
else
:
self
.
num_decode_requests
+=
1
self
.
decode_num_tokens
+=
num_tokens
self
.
decode_context_len
+=
context_len
self
.
decode_token_context_product
+=
num_tokens
*
context_len
def
total_num_tokens
(
self
)
->
int
:
"""Total number of tokens across all requests in the batch."""
return
self
.
prefill_num_tokens
+
self
.
decode_num_tokens
def
total_token_context_product
(
self
)
->
int
:
"""Total sum of (num_tokens * context_len) across all requests."""
return
self
.
prefill_token_context_product
+
self
.
decode_token_context_product
@
classmethod
def
from_single_request
(
cls
,
num_tokens
:
int
,
context_len
:
int
,
is_prefill
:
bool
)
->
"ExecutionContext"
:
"""Create an ExecutionContext from a single request.
This is a convenience method primarily for testing.
"""
ctx
=
cls
()
ctx
.
add
(
num_tokens
,
context_len
,
is_prefill
)
return
ctx
class
ParsedArgs
:
"""
Syntactic sugar so that Parsers can use dot notations
to access/update the parsed arguments.
e.g.)
args = ParsedArgs()
args.x = 3
args.y = args.x + 1
"""
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
raise
AttributeError
(
f
"'
{
type
(
self
).
__name__
}
' has no attribute '
{
name
}
'"
)
def
__setattr__
(
self
,
name
:
str
,
value
:
Any
)
->
None
:
object
.
__setattr__
(
self
,
name
,
value
)
def
model_dump
(
self
)
->
dict
[
str
,
Any
]:
return
vars
(
self
).
copy
()
#### Abstract ####
class
Parser
(
Protocol
):
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
"""
Parse the vllm config and update the current ParsedArgs and pass it on.
If the parser isn't applicable to the vllm_config, it will do nothing.
"""
...
class
ParserChain
:
"""
Applies chain of parser in a sequential order.
Later parsers might overwrite results from previous parsers,
so parsers should be chained in the appropriate order if they
are not mutually exclusive.
"""
def
__init__
(
self
,
*
parsers
:
Parser
)
->
None
:
self
.
parsers
=
list
(
parsers
)
def
add_parser
(
self
,
parser
:
Parser
)
->
None
:
self
.
parsers
.
append
(
parser
)
def
parse
(
self
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
args
=
ParsedArgs
()
for
parser
in
self
.
parsers
:
args
=
parser
.
parse
(
args
,
vllm_config
)
return
args
_COMPONENT_METRICS_REGISTRY
:
dict
[
str
,
type
[
"ComponentMetrics"
]]
=
{}
class
ComponentMetrics
(
BaseModel
,
ABC
):
"""
Each concrete ComponentMetrics class is associated with:
- fields that are required for metric derivation
(fields are specified/validated through pydantic model)
- parser to parse VllmConfig into fields
- metric methods that derive flops/bytes for a given execution context
"""
@
classmethod
@
abstractmethod
def
component_type
(
cls
)
->
str
:
...
@
classmethod
@
abstractmethod
def
get_parser
(
cls
)
->
ParserChain
:
"""
Return a ParserChain that provides values for all required fields.
The returned parser chain must populate ParsedArgs with values for every
field defined on this ComponentMetrics class. Missing fields will cause
a ValidationError when from_vllm_config() is called.
See individual Parser docstrings for which args they provide, and field
comments on ComponentMetrics subclasses for which parser provides each field.
"""
...
def
__init_subclass__
(
cls
):
_COMPONENT_METRICS_REGISTRY
[
cls
.
component_type
()]
=
cls
@
classmethod
def
from_vllm_config
(
cls
,
vllm_config
:
VllmConfig
)
->
Self
:
"""
Instantiate this class from VllmConfig.
Raises ValidationError if parsing fails.
"""
parser
=
cls
.
get_parser
()
parsed_args
=
parser
.
parse
(
vllm_config
)
try
:
return
cls
.
model_validate
(
parsed_args
.
model_dump
())
except
ValidationError
as
e
:
raise
InvalidComponent
(
f
"Invalid
{
cls
.
component_type
()
}
config:
{
e
}
"
)
from
e
@
classmethod
def
registered_metrics
(
cls
)
->
Iterable
[
type
[
"ComponentMetrics"
]]:
return
iter
(
_COMPONENT_METRICS_REGISTRY
.
values
())
@
abstractmethod
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
...
@
abstractmethod
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
...
@
abstractmethod
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
...
def
get_num_flops
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
self
.
get_num_flops_breakdown
(
ctx
,
per_gpu
).
values
())
def
get_read_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
self
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
).
values
())
def
get_write_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
self
.
get_write_bytes_breakdown
(
ctx
,
per_gpu
).
values
())
#### parsers ####
class
BaseConfigParser
(
Parser
):
"""
Parses base model configuration.
Provides: vocab_size, hidden_size, num_attention_heads, num_hidden_layers,
weight_byte_size, activation_byte_size, dp_size, tp_size, pp_size, enable_ep
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
model_config
=
vllm_config
.
model_config
args
.
vocab_size
=
model_config
.
get_vocab_size
()
args
.
hidden_size
=
model_config
.
get_hidden_size
()
# NOTE: model_config.get_attention_heads() divide by TP
# so we access field manually here to get total num_heads
args
.
num_attention_heads
=
get_required
(
model_config
.
hf_text_config
,
"num_attention_heads"
)
args
.
num_hidden_layers
=
get_required
(
model_config
.
hf_text_config
,
"num_hidden_layers"
)
model_dtype
=
vllm_config
.
model_config
.
dtype
if
isinstance
(
model_dtype
,
torch
.
dtype
):
torch_dtype
=
model_dtype
elif
isinstance
(
model_dtype
,
str
)
and
model_dtype
in
STR_DTYPE_TO_TORCH_DTYPE
:
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
model_dtype
]
else
:
# FIXME: handle this better
logger
.
warning
(
"Unknown model_dtype %s, defaulting to bfloat16"
,
model_dtype
,
)
torch_dtype
=
torch
.
bfloat16
args
.
weight_byte_size
=
get_dtype_size
(
torch_dtype
)
# FIXME: handle this better by parsing whether activations use
# bf16, fp32, etc...
args
.
activation_byte_size
=
2
args
.
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
args
.
tp_size
=
vllm_config
.
parallel_config
.
tensor_parallel_size
args
.
pp_size
=
vllm_config
.
parallel_config
.
pipeline_parallel_size
args
.
enable_ep
=
vllm_config
.
parallel_config
.
enable_expert_parallel
return
args
#### Attention ####
class
BaseAttentionConfigParser
(
Parser
):
"""
Parses attention-specific configuration.
Provides: num_key_value_heads, head_dim, cache_byte_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
model_config
=
vllm_config
.
model_config
args
.
num_key_value_heads
=
model_config
.
get_total_num_kv_heads
()
args
.
head_dim
=
model_config
.
get_head_size
()
model_dtype
=
vllm_config
.
model_config
.
dtype
cache_dtype
=
vllm_config
.
cache_config
.
cache_dtype
kv_cache_torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
args
.
cache_byte_size
=
get_dtype_size
(
kv_cache_torch_dtype
)
return
args
class
AttentionQuantizationConfigParser
(
Parser
):
"""
Parses quantization configuration for attention layers.
Overrides: weight_byte_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
quant_config
if
cfg
is
None
:
return
args
quant_method
=
cfg
.
get_name
()
if
quant_method
in
[
"fp8"
,
"fbgemm_fp8"
]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args
.
weight_byte_size
=
1
elif
quant_method
==
"mxfp4"
:
# FIXME: Also has "ignored layers" issue above
args
.
weight_byte_size
=
0.5
else
:
# FIXME: Add more parsing logic for different quant methods.
raise
InvalidComponent
return
args
class
AttentionMetrics
(
ComponentMetrics
):
# From BaseConfigParser
num_hidden_layers
:
int
=
Field
(...,
gt
=
0
)
hidden_size
:
int
=
Field
(...,
gt
=
0
)
num_attention_heads
:
int
=
Field
(...,
gt
=
0
)
activation_byte_size
:
int
=
Field
(...,
gt
=
0
)
tp_size
:
int
=
Field
(...,
gt
=
0
)
pp_size
:
int
=
Field
(...,
gt
=
0
)
# From BaseAttentionConfigParser
num_key_value_heads
:
int
=
Field
(...,
gt
=
0
)
head_dim
:
int
=
Field
(...,
gt
=
0
)
cache_byte_size
:
int
=
Field
(...,
gt
=
0
)
# From BaseConfig Parser, overridden by AttentionQuantizationConfigParser
weight_byte_size
:
int
|
float
=
Field
(...,
gt
=
0
)
# TODO: discern cases where we have mixture of different attention layer types
# such as SWA, MLA, etc.
@
classmethod
def
component_type
(
cls
)
->
str
:
return
"attn"
@
classmethod
def
get_parser
(
cls
)
->
ParserChain
:
return
ParserChain
(
BaseConfigParser
(),
BaseAttentionConfigParser
(),
AttentionQuantizationConfigParser
(),
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
L
,
D
,
q
,
kv
,
d
=
(
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
num_attention_heads
,
self
.
num_key_value_heads
,
self
.
head_dim
,
)
T
=
ctx
.
total_num_tokens
()
TC
=
ctx
.
total_token_context_product
()
if
per_gpu
:
L
//=
self
.
pp_size
# tensor parallel along heads
q
=
max
(
1
,
q
//
self
.
tp_size
)
kv
=
max
(
1
,
kv
//
self
.
tp_size
)
return
{
"qkv_proj"
:
2
*
T
*
D
*
(
q
+
2
*
kv
)
*
d
*
L
,
"attn_qk"
:
2
*
q
*
TC
*
d
*
L
,
"attn_av"
:
2
*
q
*
TC
*
d
*
L
,
"out_proj"
:
2
*
T
*
D
*
q
*
d
*
L
,
}
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
L
,
D
,
q
,
kv
,
d
=
(
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
num_attention_heads
,
self
.
num_key_value_heads
,
self
.
head_dim
,
)
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
L
//=
self
.
pp_size
# tensor parallel along heads
q
=
max
(
1
,
q
//
self
.
tp_size
)
kv
=
max
(
1
,
kv
//
self
.
tp_size
)
read_bytes
=
{}
read_bytes
[
"qkv_input"
]
=
T
*
D
*
self
.
activation_byte_size
*
L
read_bytes
[
"qkv_weight"
]
=
int
(
D
*
(
q
+
2
*
kv
)
*
d
*
self
.
weight_byte_size
*
L
)
# Attention input reads differ between prefill and decode
# Prefill: read Q, K, V activations (all in activation_byte_size)
if
ctx
.
prefill_num_tokens
>
0
:
read_bytes
[
"attn_input"
]
=
(
(
ctx
.
prefill_num_tokens
*
q
+
2
*
ctx
.
prefill_context_len
*
kv
)
*
d
*
self
.
activation_byte_size
*
L
)
# Decode: read Q activations + read K, V from cache (in cache_byte_size)
if
ctx
.
decode_num_tokens
>
0
:
read_bytes
[
"attn_input"
]
=
read_bytes
.
get
(
"attn_input"
,
0
)
+
(
ctx
.
decode_num_tokens
*
q
*
d
*
self
.
activation_byte_size
*
L
+
2
*
ctx
.
decode_context_len
*
kv
*
d
*
self
.
cache_byte_size
*
L
)
read_bytes
[
"out_input"
]
=
T
*
q
*
d
*
self
.
activation_byte_size
*
L
read_bytes
[
"out_weight"
]
=
int
(
q
*
d
*
D
*
self
.
weight_byte_size
*
L
)
return
read_bytes
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate write memory traffic for attention layers."""
L
,
D
,
q
,
kv
,
d
=
(
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
num_attention_heads
,
self
.
num_key_value_heads
,
self
.
head_dim
,
)
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
L
//=
self
.
pp_size
# tensor parallel along heads
q
=
max
(
1
,
q
//
self
.
tp_size
)
kv
=
max
(
1
,
kv
//
self
.
tp_size
)
return
{
"qkv_output"
:
T
*
(
q
+
2
*
kv
)
*
d
*
self
.
activation_byte_size
*
L
,
"kv_cache"
:
2
*
T
*
kv
*
d
*
self
.
cache_byte_size
*
L
,
"out_output"
:
T
*
D
*
self
.
activation_byte_size
*
L
,
}
#### Ffn ####
class
BaseFfnConfigParser
(
Parser
):
"""
Parses FFN and MoE configuration.
Provides: intermediate_size, num_experts, num_experts_per_tok,
moe_intermediate_size, num_shared_experts, num_moe_layers
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
model_config
.
hf_config
if
hasattr
(
cfg
,
"text_config"
)
and
cfg
.
text_config
is
not
None
:
cfg
=
cfg
.
text_config
args
.
intermediate_size
=
getattr
(
cfg
,
"intermediate_size"
,
args
.
hidden_size
*
4
)
# Try different naming conventions.
args
.
num_experts
=
vllm_config
.
model_config
.
get_num_experts
()
args
.
num_experts_per_tok
=
getattr_from_list
(
cfg
,
[
"num_experts_per_tok"
,
"moe_topk"
],
0
)
args
.
moe_intermediate_size
=
getattr_from_list
(
cfg
,
[
"moe_intermediate_size"
,
"intermediate_size"
],
0
)
args
.
num_shared_experts
=
getattr_from_list
(
cfg
,
[
"n_shared_experts"
,
"num_shared_experts"
],
0
)
is_moe
=
args
.
num_experts
!=
0
# Assume all MoE layers by default
args
.
num_moe_layers
=
args
.
num_hidden_layers
if
is_moe
else
0
return
args
class
FfnParallelParser
(
Parser
):
"""
Parses FFN parallelism configuration.
Provides: ffn_tp_size, ffn_ep_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
# NOTE: ffn tp_size does not equal the tp_size parameter directly.
# e.g.) If we use DP2TP4, ffn will use TP8 (or EP8 if EP is enabled.)
if
args
.
enable_ep
:
ffn_tp_size
,
ffn_ep_size
=
1
,
args
.
dp_size
*
args
.
tp_size
else
:
ffn_tp_size
,
ffn_ep_size
=
args
.
dp_size
*
args
.
tp_size
,
1
args
.
ffn_tp_size
=
ffn_tp_size
args
.
ffn_ep_size
=
ffn_ep_size
return
args
class
InterleaveMoeLayerStepParser
(
Parser
):
"""
Parses interleave_moe_layer_step field for models like Llama4.
Overrides: num_moe_layers
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
model_config
.
hf_config
if
hasattr
(
cfg
,
"text_config"
)
and
cfg
.
text_config
is
not
None
:
cfg
=
cfg
.
text_config
if
(
hasattr
(
cfg
,
"interleave_moe_layer_step"
)
and
cfg
.
interleave_moe_layer_step
>
0
):
args
.
num_moe_layers
=
len
(
[
layer
for
layer
in
range
(
args
.
num_hidden_layers
)
if
(
layer
+
1
)
%
cfg
.
interleave_moe_layer_step
==
0
]
)
return
args
class
MoeLayerFreqParser
(
Parser
):
"""
Parses moe_layer_freq and first_k_dense_replace fields for models like Deepseek.
Overrides: num_moe_layers
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
model_config
.
hf_config
if
hasattr
(
cfg
,
"text_config"
)
and
cfg
.
text_config
is
not
None
:
cfg
=
cfg
.
text_config
if
hasattr
(
cfg
,
"moe_layer_freq"
)
and
hasattr
(
cfg
,
"first_k_dense_replace"
):
args
.
num_moe_layers
=
len
(
[
layer
for
layer
in
range
(
args
.
num_hidden_layers
)
if
layer
>=
cfg
.
first_k_dense_replace
and
layer
%
cfg
.
moe_layer_freq
==
0
]
)
return
args
class
FfnQuantizationConfigParser
(
Parser
):
"""
Parses quantization configuration for FFN layers.
Overrides: weight_byte_size
"""
def
parse
(
self
,
args
:
ParsedArgs
,
vllm_config
:
VllmConfig
)
->
ParsedArgs
:
cfg
=
vllm_config
.
quant_config
if
cfg
is
None
:
return
args
quant_method
=
cfg
.
get_name
()
if
quant_method
in
[
"fp8"
,
"fbgemm_fp8"
]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# (there might be more quantization methods for fp8).
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args
.
weight_byte_size
=
1
pass
elif
quant_method
==
"mxfp4"
:
# FIXME: Also has "ignored layers" issue above
args
.
weight_byte_size
=
0.5
else
:
# FIXME: Add more parsing logic for different quant methods.
raise
InvalidComponent
return
args
class
FfnMetrics
(
ComponentMetrics
):
# From BaseConfigParser
num_hidden_layers
:
int
=
Field
(...,
gt
=
0
)
hidden_size
:
int
=
Field
(...,
gt
=
0
)
activation_byte_size
:
int
=
Field
(...,
gt
=
0
)
pp_size
:
int
=
Field
(...,
gt
=
0
)
# From FfnParallelParser
ffn_tp_size
:
int
=
Field
(...,
gt
=
0
)
ffn_ep_size
:
int
=
Field
(...,
gt
=
0
)
# From BaseFfnConfigParser
intermediate_size
:
int
=
Field
(...,
gt
=
0
)
num_experts
:
int
=
Field
(
0
)
num_experts_per_tok
:
int
=
Field
(
1
)
moe_intermediate_size
:
int
=
Field
(
0
)
num_shared_experts
:
int
=
Field
(
0
)
# From BaseConfigParser, can be overridden InterleaveMoeLayerStep or MoeLayerFreq
num_moe_layers
:
int
=
Field
(...,
ge
=
0
)
# FIXME: might have to make this more granular
# (i.e. dense_weight_byte_size, moe_routed_weight_byte_size,
# moe_shared_weight_byte_size)
# since it can differ from byte size of other components (e.g. attn)
# and can differ even from each other.
# From BaseConfigParser, can be overridden by FfnQuantizationConfigParser
weight_byte_size
:
int
|
float
=
Field
(...,
gt
=
0
)
@
model_validator
(
mode
=
"after"
)
def
validate_moe_fields
(
self
)
->
Self
:
"""Validate that MoE-related fields are properly set when num_moe_layers > 0."""
if
self
.
num_moe_layers
>
0
:
assert
self
.
num_experts
,
f
"
{
self
.
num_experts
=
}
"
assert
self
.
num_experts_per_tok
,
f
"
{
self
.
num_experts_per_tok
=
}
"
assert
self
.
moe_intermediate_size
,
f
"
{
self
.
moe_intermediate_size
=
}
"
return
self
@
classmethod
def
component_type
(
cls
)
->
str
:
return
"ffn"
@
classmethod
def
get_parser
(
cls
)
->
ParserChain
:
return
ParserChain
(
BaseConfigParser
(),
FfnParallelParser
(),
BaseFfnConfigParser
(),
InterleaveMoeLayerStepParser
(),
MoeLayerFreqParser
(),
FfnQuantizationConfigParser
(),
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate flops breakdown for FFN layers."""
L
,
D
,
DI
=
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
intermediate_size
Lm
,
E
,
MI
,
S
=
(
self
.
num_moe_layers
,
self
.
num_experts_per_tok
,
self
.
moe_intermediate_size
,
self
.
num_shared_experts
,
)
T
=
ctx
.
total_num_tokens
()
Ld
=
L
-
Lm
num_activated_tokens
=
T
*
E
if
E
else
0
if
per_gpu
:
Ld
//=
self
.
pp_size
Lm
//=
self
.
pp_size
DI
//=
self
.
ffn_tp_size
if
MI
is
not
None
:
MI
//=
self
.
ffn_tp_size
if
E
:
num_activated_tokens
//=
self
.
ffn_ep_size
flops
=
{}
# Dense FFN layers (SwiGLU: 3 linear layers: up, gate, down)
if
Ld
:
flops
[
"dense_ffn"
]
=
2
*
D
*
3
*
DI
*
T
*
Ld
# MoE routed experts (each token activates E experts)
if
Lm
and
E
:
flops
[
"routed_ffn"
]
=
2
*
D
*
3
*
MI
*
num_activated_tokens
*
Lm
# MoE shared experts (all S shared experts run for every token)
if
Lm
and
S
:
flops
[
"shared_ffn"
]
=
2
*
D
*
3
*
MI
*
S
*
T
*
Lm
return
flops
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate read memory traffic for FFN layers."""
L
,
D
,
DI
=
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
intermediate_size
Lm
,
E
,
MI
,
S
=
(
self
.
num_moe_layers
,
self
.
num_experts_per_tok
,
self
.
moe_intermediate_size
,
self
.
num_shared_experts
,
)
T
=
ctx
.
total_num_tokens
()
num_experts
=
self
.
num_experts
Ld
=
L
-
Lm
num_activated_tokens
=
T
*
E
if
E
else
0
if
per_gpu
:
Ld
//=
self
.
pp_size
Lm
//=
self
.
pp_size
DI
//=
self
.
ffn_tp_size
if
MI
is
not
None
:
MI
//=
self
.
ffn_tp_size
if
E
:
num_activated_tokens
//=
self
.
ffn_ep_size
if
num_experts
is
not
None
:
num_experts
//=
self
.
ffn_ep_size
read_bytes
=
{}
# Dense FFN layers (3 GEMMs: up, gate, down projections + SiLU activation)
if
Ld
:
read_bytes
[
"dense_up_gate_input"
]
=
int
(
T
*
D
*
self
.
activation_byte_size
*
Ld
)
read_bytes
[
"dense_up_gate_weights"
]
=
int
(
2
*
D
*
DI
*
self
.
weight_byte_size
*
Ld
)
read_bytes
[
"dense_silu_input"
]
=
int
(
2
*
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
read_bytes
[
"dense_down_input"
]
=
int
(
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
read_bytes
[
"dense_down_weights"
]
=
int
(
D
*
DI
*
self
.
weight_byte_size
*
Ld
)
if
Lm
:
# MoE routed expert reads
if
E
:
# FIXME: Assume perfect load balancing for now.
num_activated_experts
=
min
(
num_activated_tokens
,
num_experts
)
read_bytes
[
"routed_up_gate_input"
]
=
int
(
num_activated_tokens
*
D
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"routed_up_gate_weights"
]
=
int
(
2
*
D
*
MI
*
num_activated_experts
*
self
.
weight_byte_size
*
Lm
)
read_bytes
[
"routed_silu_input"
]
=
int
(
2
*
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"routed_down_input"
]
=
int
(
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"routed_down_weights"
]
=
int
(
D
*
MI
*
num_activated_experts
*
self
.
weight_byte_size
*
Lm
)
# MoE shared expert reads
if
S
:
read_bytes
[
"shared_up_gate_input"
]
=
int
(
T
*
D
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"shared_up_gate_weights"
]
=
int
(
2
*
D
*
MI
*
S
*
self
.
weight_byte_size
*
Lm
)
read_bytes
[
"shared_silu_input"
]
=
int
(
2
*
T
*
MI
*
S
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"shared_down_input"
]
=
int
(
T
*
MI
*
self
.
activation_byte_size
*
Lm
)
read_bytes
[
"shared_down_weights"
]
=
int
(
D
*
MI
*
S
*
self
.
weight_byte_size
*
Lm
)
return
read_bytes
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate write memory traffic for FFN layers."""
L
,
D
,
DI
=
self
.
num_hidden_layers
,
self
.
hidden_size
,
self
.
intermediate_size
Lm
,
E
,
MI
,
S
=
(
self
.
num_moe_layers
,
self
.
num_experts_per_tok
,
self
.
moe_intermediate_size
,
self
.
num_shared_experts
,
)
T
=
ctx
.
total_num_tokens
()
Ld
=
L
-
Lm
num_activated_tokens
=
T
*
E
if
E
else
0
if
per_gpu
:
Ld
//=
self
.
pp_size
Lm
//=
self
.
pp_size
DI
//=
self
.
ffn_tp_size
if
MI
is
not
None
:
MI
//=
self
.
ffn_tp_size
if
E
:
num_activated_tokens
//=
self
.
ffn_ep_size
write_bytes
=
{}
# Dense FFN layers
if
Ld
:
write_bytes
[
"dense_up_gate_output"
]
=
int
(
2
*
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
write_bytes
[
"dense_silu_output"
]
=
int
(
T
*
DI
*
self
.
activation_byte_size
*
Ld
)
write_bytes
[
"dense_down_output"
]
=
int
(
T
*
D
*
self
.
activation_byte_size
*
Ld
)
# MoE outputs
if
Lm
:
if
E
:
write_bytes
[
"routed_up_gate_output"
]
=
int
(
2
*
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"routed_silu_output"
]
=
int
(
num_activated_tokens
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"routed_down_output"
]
=
int
(
num_activated_tokens
*
D
*
self
.
activation_byte_size
*
Lm
)
if
S
:
write_bytes
[
"shared_up_gate_output"
]
=
int
(
2
*
T
*
S
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"shared_silu_output"
]
=
int
(
T
*
S
*
MI
*
self
.
activation_byte_size
*
Lm
)
write_bytes
[
"shared_down_output"
]
=
int
(
T
*
S
*
D
*
self
.
activation_byte_size
*
Lm
)
return
write_bytes
#### Unembed ####
class
UnembedMetrics
(
ComponentMetrics
):
# From BaseConfigParser
hidden_size
:
int
=
Field
(...,
gt
=
0
)
vocab_size
:
int
=
Field
(...,
gt
=
0
)
weight_byte_size
:
int
=
Field
(...,
gt
=
0
)
activation_byte_size
:
int
=
Field
(...,
gt
=
0
)
tp_size
:
int
@
classmethod
def
component_type
(
cls
)
->
str
:
return
"unembed"
@
classmethod
def
get_parser
(
cls
)
->
ParserChain
:
return
ParserChain
(
BaseConfigParser
(),
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate flops breakdown for unembedding layer."""
D
,
V
=
self
.
hidden_size
,
self
.
vocab_size
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
V
//=
self
.
tp_size
return
{
"unembed"
:
2
*
T
*
D
*
V
,
}
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate read memory traffic for unembedding layer."""
D
,
V
=
self
.
hidden_size
,
self
.
vocab_size
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
V
//=
self
.
tp_size
return
{
"input"
:
T
*
D
*
self
.
activation_byte_size
,
"weight"
:
D
*
V
*
self
.
weight_byte_size
,
}
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
"""Calculate write memory traffic for unembedding layer."""
V
=
self
.
vocab_size
T
=
ctx
.
total_num_tokens
()
if
per_gpu
:
V
//=
self
.
tp_size
return
{
"output"
:
T
*
V
*
self
.
activation_byte_size
,
}
#### ModelMetrics ####
class
ModelMetrics
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
"""
Parse vllm_config to instantiate metrics for each component.
is_enabled() will return False if no component metrics could be instantiated.
"""
self
.
vllm_config
=
vllm_config
self
.
metrics
:
list
[
ComponentMetrics
]
=
[]
for
metric_cls
in
ComponentMetrics
.
registered_metrics
():
try
:
metric
=
metric_cls
.
from_vllm_config
(
vllm_config
)
self
.
metrics
.
append
(
metric
)
logger
.
info
(
"Instantiated ComponentMetrics [%s] with (%s)"
,
metric
.
component_type
(),
str
(
metric
),
)
except
InvalidComponent
as
e
:
logger
.
debug
(
"Failed to instantiate %s from %s"
,
metric_cls
.
component_type
(),
str
(
e
),
)
def
is_enabled
(
self
)
->
bool
:
return
len
(
self
.
metrics
)
>
0
def
get_num_flops
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
metric
.
get_num_flops
(
ctx
,
per_gpu
)
for
metric
in
self
.
metrics
)
def
get_read_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
metric
.
get_read_bytes
(
ctx
,
per_gpu
)
for
metric
in
self
.
metrics
)
def
get_write_bytes
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
int
:
return
sum
(
metric
.
get_write_bytes
(
ctx
,
per_gpu
)
for
metric
in
self
.
metrics
)
def
get_num_flops_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
total
=
{}
for
metric
in
self
.
metrics
:
breakdown
=
metric
.
get_num_flops_breakdown
(
ctx
,
per_gpu
)
component
=
metric
.
component_type
()
prefixed
=
{
f
"
{
component
}
.
{
key
}
"
:
val
for
key
,
val
in
breakdown
.
items
()}
total
.
update
(
prefixed
)
return
total
def
get_read_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
total
=
{}
for
metric
in
self
.
metrics
:
breakdown
=
metric
.
get_read_bytes_breakdown
(
ctx
,
per_gpu
)
component
=
metric
.
component_type
()
prefixed
=
{
f
"
{
component
}
.
{
key
}
"
:
val
for
key
,
val
in
breakdown
.
items
()}
total
.
update
(
prefixed
)
return
total
def
get_write_bytes_breakdown
(
self
,
ctx
:
ExecutionContext
,
per_gpu
:
bool
=
True
)
->
dict
[
str
,
int
]:
total
=
{}
for
metric
in
self
.
metrics
:
breakdown
=
metric
.
get_write_bytes_breakdown
(
ctx
,
per_gpu
)
component
=
metric
.
component_type
()
prefixed
=
{
f
"
{
component
}
.
{
key
}
"
:
val
for
key
,
val
in
breakdown
.
items
()}
total
.
update
(
prefixed
)
return
total
def
get_step_perf_stats_per_gpu
(
self
,
scheduler_output
:
SchedulerOutput
)
->
PerfStats
:
"""
Calculate perf stats for the current step based on scheduled tokens.
"""
t0
=
time
.
monotonic
()
# Build a single batch context
ctx
=
ExecutionContext
()
# Process new requests (these are in prefill phase)
for
new_req
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req
.
req_id
num_tokens
=
scheduler_output
.
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens
==
0
:
continue
# For new requests, context_len = num_computed_tokens + num_tokens
# num_computed_tokens represents previously computed tokens in the sequence
context_len
=
new_req
.
num_computed_tokens
+
num_tokens
ctx
.
add
(
num_tokens
,
context_len
,
is_prefill
=
True
)
# Process cached requests (continuing requests)
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
num_tokens
=
scheduler_output
.
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens
==
0
:
continue
# For cached requests, we have the current num_computed_tokens
num_computed_tokens
=
cached_reqs
.
num_computed_tokens
[
i
]
context_len
=
num_computed_tokens
+
num_tokens
# Cached requests are typically in decode phase (num_tokens == 1)
# unless they're doing chunked prefill (num_tokens > 1)
is_prefill
=
num_tokens
>
1
ctx
.
add
(
num_tokens
,
context_len
,
is_prefill
)
num_flops_breakdown
=
self
.
get_num_flops_breakdown
(
ctx
,
True
)
read_bytes_breakdown
=
self
.
get_read_bytes_breakdown
(
ctx
,
True
)
write_bytes_breakdown
=
self
.
get_write_bytes_breakdown
(
ctx
,
True
)
perf_stats
=
PerfStats
(
sum
(
num_flops_breakdown
.
values
()),
sum
(
read_bytes_breakdown
.
values
()),
sum
(
write_bytes_breakdown
.
values
()),
)
if
envs
.
VLLM_DEBUG_MFU_METRICS
:
perf_stats
.
debug_stats
=
DebugPerfStats
(
time
.
monotonic
()
-
t0
,
ctx
.
num_prefill_requests
,
ctx
.
num_decode_requests
,
asdict
(
ctx
),
num_flops_breakdown
,
read_bytes_breakdown
,
write_bytes_breakdown
,
)
return
perf_stats
#### Logging ####
class
PerfMetricsDebugLogging
:
def
__init__
(
self
):
self
.
reset
()
def
reset
(
self
):
self
.
total_calc_duration
:
float
=
0.0
self
.
total_num_prefill_requests
:
int
=
0
self
.
total_num_decode_requests
:
int
=
0
self
.
total_num_batches
:
int
=
0
self
.
total_context_breakdown
:
dict
[
str
,
int
]
=
{}
self
.
total_num_flops_per_gpu_breakdown
:
dict
[
str
,
int
]
=
{}
self
.
total_read_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
=
{}
self
.
total_write_bytes_per_gpu_breakdown
:
dict
[
str
,
int
]
=
{}
def
observe
(
self
,
debug_stats
:
DebugPerfStats
)
->
None
:
self
.
total_calc_duration
+=
debug_stats
.
calc_duration
self
.
total_num_prefill_requests
+=
debug_stats
.
num_prefill_requests
self
.
total_num_decode_requests
+=
debug_stats
.
num_decode_requests
self
.
total_num_batches
+=
1
for
dst
,
src
in
zip
(
[
self
.
total_context_breakdown
,
self
.
total_num_flops_per_gpu_breakdown
,
self
.
total_read_bytes_per_gpu_breakdown
,
self
.
total_write_bytes_per_gpu_breakdown
,
],
[
debug_stats
.
context_breakdown
,
debug_stats
.
num_flops_per_gpu_breakdown
,
debug_stats
.
num_read_bytes_per_gpu_breakdown
,
debug_stats
.
num_write_bytes_per_gpu_breakdown
,
],
):
assert
isinstance
(
src
,
dict
)
for
key
,
val
in
src
.
items
():
dst
[
key
]
=
dst
.
get
(
key
,
0
)
+
val
def
log
(
self
,
log_fn
,
log_prefix
:
str
,
delta_time
:
float
):
# pretty print breakdowns
total_num_flops_per_gpu_breakdown
=
{
k
:
f
"
{
v
/
1e12
:.
1
f
}
TF"
for
k
,
v
in
self
.
total_num_flops_per_gpu_breakdown
.
items
()
}
total_read_bytes_per_gpu_breakdown
=
{
k
:
f
"
{
v
/
1e9
:.
1
f
}
GB"
for
k
,
v
in
self
.
total_read_bytes_per_gpu_breakdown
.
items
()
}
total_write_bytes_per_gpu_breakdown
=
{
k
:
f
"
{
v
/
1e9
:.
1
f
}
GB"
for
k
,
v
in
self
.
total_write_bytes_per_gpu_breakdown
.
items
()
}
logger
.
debug
(
"%sMFU details: %s"
,
log_prefix
,
json
.
dumps
(
{
"prefill_reqs"
:
self
.
total_num_prefill_requests
,
"decode_reqs"
:
self
.
total_num_decode_requests
,
"num_batches"
:
self
.
total_num_batches
,
"context_breakdown"
:
self
.
total_context_breakdown
,
"flops_breakdown"
:
total_num_flops_per_gpu_breakdown
,
"num_read_bytes_breakdown"
:
total_read_bytes_per_gpu_breakdown
,
"num_write_bytes_breakdown"
:
(
total_write_bytes_per_gpu_breakdown
),
"duration"
:
f
"
{
delta_time
:.
1
f
}
s"
,
"mfu_calc_overhead"
:
(
f
"
{
self
.
total_calc_duration
/
delta_time
:.
1
%
}
"
),
},
indent
=
2
,
),
)
class
PerfMetricsLogging
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
):
self
.
vllm_config
=
vllm_config
self
.
pp_size
=
vllm_config
.
parallel_config
.
pipeline_parallel_size
self
.
debug_logging
:
PerfMetricsDebugLogging
|
None
=
None
if
envs
.
VLLM_DEBUG_MFU_METRICS
:
self
.
debug_logging
=
PerfMetricsDebugLogging
()
self
.
reset
()
def
reset
(
self
):
self
.
last_log_time
=
time
.
monotonic
()
self
.
total_num_flops_per_gpu
:
int
=
0
self
.
total_read_bytes_per_gpu
:
int
=
0
self
.
total_write_bytes_per_gpu
:
int
=
0
if
self
.
debug_logging
:
self
.
debug_logging
.
reset
()
def
observe
(
self
,
perf_stats
:
PerfStats
)
->
None
:
self
.
total_num_flops_per_gpu
+=
perf_stats
.
num_flops_per_gpu
self
.
total_read_bytes_per_gpu
+=
perf_stats
.
num_read_bytes_per_gpu
self
.
total_write_bytes_per_gpu
+=
perf_stats
.
num_write_bytes_per_gpu
if
self
.
debug_logging
:
assert
perf_stats
.
debug_stats
is
not
None
self
.
debug_logging
.
observe
(
perf_stats
.
debug_stats
)
def
log
(
self
,
log_fn
=
logger
.
info
,
log_prefix
:
str
=
""
)
->
None
:
if
not
(
self
.
total_num_flops_per_gpu
or
self
.
total_read_bytes_per_gpu
or
self
.
total_write_bytes_per_gpu
):
return
now
=
time
.
monotonic
()
delta_time
=
now
-
self
.
last_log_time
if
delta_time
<=
0.0
:
avg_tflops_per_gpu
=
0.0
avg_gbps_per_gpu
=
0.0
else
:
avg_tflops_per_gpu
=
self
.
total_num_flops_per_gpu
/
delta_time
/
1e12
avg_gbps_per_gpu
=
(
(
self
.
total_read_bytes_per_gpu
+
self
.
total_write_bytes_per_gpu
)
/
delta_time
/
1e9
)
log_fn
(
"%sMFU: %.1f TF/s/GPU %.1f GB/s/GPU"
,
log_prefix
,
avg_tflops_per_gpu
,
avg_gbps_per_gpu
,
)
if
self
.
debug_logging
:
self
.
debug_logging
.
log
(
log_fn
,
log_prefix
,
delta_time
)
self
.
reset
()
## util functions
def
get_required
(
obj
:
object
,
attr
:
str
):
"""Get an attr from an object, or throw a InvalidComponentError if it's not set."""
if
not
hasattr
(
obj
,
attr
):
raise
InvalidComponent
(
f
"Missing required attr
{
attr
}
in config"
)
return
getattr
(
obj
,
attr
)
def
getattr_from_list
(
obj
:
object
,
attrs
:
list
[
str
],
default
:
object
=
None
):
"""Try to get the first attr that exists in the object
from a list of attrs. Otherwise return None."""
for
attr
in
attrs
:
if
hasattr
(
obj
,
attr
):
return
getattr
(
obj
,
attr
)
return
default
vllm/v1/metrics/stats.py
View file @
a810671a
...
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
import
vllm.envs
as
envs
from
vllm.compilation.cuda_graph
import
CUDAGraphStat
from
vllm.v1.metrics.perf
import
PerfStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
if
TYPE_CHECKING
:
...
...
@@ -186,6 +187,8 @@ class SchedulerStats:
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
perf_stats
:
PerfStats
|
None
=
None
@
dataclass
class
RequestStateStats
:
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
a810671a
...
...
@@ -44,6 +44,32 @@ def _walk_json_for_additional_properties(data: object):
_walk_json_for_additional_properties
(
item
)
def
has_guidance_unsupported_json_features
(
schema
:
dict
[
str
,
Any
])
->
bool
:
"""Check if JSON schema contains features unsupported by guidance/llguidance."""
def
check_object
(
obj
:
dict
[
str
,
Any
])
->
bool
:
if
not
isinstance
(
obj
,
dict
):
return
False
# patternProperties is not supported by llguidance
if
"patternProperties"
in
obj
:
return
True
# Recursively check all nested objects and arrays
for
value
in
obj
.
values
():
if
isinstance
(
value
,
dict
):
if
check_object
(
value
):
return
True
elif
isinstance
(
value
,
list
):
for
item
in
value
:
if
isinstance
(
item
,
dict
)
and
check_object
(
item
):
return
True
return
False
return
check_object
(
schema
)
def
process_for_additional_properties
(
guide_json
:
str
|
dict
[
str
,
Any
],
)
->
dict
[
str
,
Any
]:
...
...
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment