Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -114,7 +114,8 @@ class Glm4MoeModelToolParser(ToolParser): ...@@ -114,7 +114,8 @@ class Glm4MoeModelToolParser(ToolParser):
ToolCall( ToolCall(
type="function", type="function",
function=FunctionCall( function=FunctionCall(
name=tc_name, arguments=json.dumps(arg_dct) name=tc_name,
arguments=json.dumps(arg_dct, ensure_ascii=False),
), ),
) )
) )
......
...@@ -122,6 +122,8 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -122,6 +122,8 @@ class MinimaxM2ToolParser(ToolParser):
self.streaming_request = None self.streaming_request = None
# Clear previous tool call history to avoid state pollution # Clear previous tool call history to avoid state pollution
self.prev_tool_call_arr.clear() self.prev_tool_call_arr.clear()
# Reset streamed args tracking
self.streamed_args_for_tool.clear()
def _extract_name(self, name_str: str) -> str: def _extract_name(self, name_str: str) -> str:
"""Extract name from quoted string.""" """Extract name from quoted string."""
...@@ -421,9 +423,12 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -421,9 +423,12 @@ class MinimaxM2ToolParser(ToolParser):
self.prev_tool_call_arr.append( self.prev_tool_call_arr.append(
{ {
"name": self.current_function_name, "name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later "arguments": {}, # Placeholder, will be updated later
} }
) )
# Initialize streamed_args_for_tool for this tool call
if len(self.streamed_args_for_tool) <= self.current_tool_index:
self.streamed_args_for_tool.append("")
# Send header with function info # Send header with function info
return DeltaMessage( return DeltaMessage(
...@@ -445,6 +450,9 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -445,6 +450,9 @@ class MinimaxM2ToolParser(ToolParser):
# Send opening brace if not sent yet # Send opening brace if not sent yet
if self.in_function and not self.json_started: if self.in_function and not self.json_started:
self.json_started = True self.json_started = True
# Update streamed_args_for_tool for opening brace
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
...@@ -493,7 +501,7 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -493,7 +501,7 @@ class MinimaxM2ToolParser(ToolParser):
args = parsed_tool.function.arguments args = parsed_tool.function.arguments
self.prev_tool_call_arr[self.current_tool_index][ self.prev_tool_call_arr[self.current_tool_index][
"arguments" "arguments"
] = args ] = json.loads(args)
except Exception: except Exception:
pass # Ignore parsing errors during streaming pass # Ignore parsing errors during streaming
...@@ -505,7 +513,9 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -505,7 +513,9 @@ class MinimaxM2ToolParser(ToolParser):
) )
] ]
) )
# Update streamed_args_for_tool for closing brace
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += "}"
# Reset state for next tool # Reset state for next tool
self.json_closed = True self.json_closed = True
self.in_function = False self.in_function = False
...@@ -630,7 +640,11 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -630,7 +640,11 @@ class MinimaxM2ToolParser(ToolParser):
) )
self.param_count += 1 self.param_count += 1
# Update streamed_args_for_tool for this tool call
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += (
json_fragment
)
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
......
...@@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool: ...@@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
) )
@functools.cache
def has_flashinfer_trtllm_fused_moe() -> bool:
"""Return `True` if FlashInfer TRTLLM fused MoE is available."""
if not has_flashinfer_moe():
return False
required_functions = [
("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
@functools.cache @functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool: def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return `True` if FlashInfer CUTLASS fused MoE is available.""" """Return `True` if FlashInfer CUTLASS fused MoE is available."""
...@@ -288,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: ...@@ -288,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
if force_use_trtllm_attention() is False: if force_use_trtllm_attention() is False:
return False return False
has_trtllm = supports_trtllm_attention() has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0) # num_kv_heads=1 is not supported due to TMA descriptor building limitations.
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if has_trtllm and num_kv_heads == 1:
logger.warning_once(
"TRTLLM attention does not support num_kv_heads=1. "
"This configuration causes TMA descriptor building to fail due to "
"degenerate tensor strides. Falling back to FlashInfer attention."
)
return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)
def use_trtllm_attention( def use_trtllm_attention(
...@@ -338,6 +366,15 @@ def use_trtllm_attention( ...@@ -338,6 +366,15 @@ def use_trtllm_attention(
) )
return False return False
# num_kv_heads=1 is not supported
if num_kv_heads == 1:
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention does not support num_kv_heads=1, "
"but --attention-config.use_trtllm_attention is set to 1"
)
return False
if has_spec and not is_prefill: if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes # Speculative decoding requires TRTLLM attention for decodes
logger.info_once("Using TRTLLM attention (enabled for speculative decoding).") logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
......
...@@ -66,27 +66,43 @@ class MemorySnapshot: ...@@ -66,27 +66,43 @@ class MemorySnapshot:
torch_memory: int = 0 torch_memory: int = 0
non_torch_memory: int = 0 non_torch_memory: int = 0
timestamp: float = 0.0 timestamp: float = 0.0
device: torch.types.Device = None
auto_measure: bool = True auto_measure: bool = True
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.device is None:
from vllm.platforms import current_platform
device_fn = current_platform.current_device
assert device_fn is not None
self.device_ = torch.device(device_fn())
else:
self.device_ = torch.device(self.device)
if self.auto_measure: if self.auto_measure:
self.measure() self.measure()
def measure(self) -> None: def measure(self) -> None:
from vllm.platforms import current_platform from vllm.platforms import current_platform
device = self.device_
# we measure the torch peak memory usage via allocated_bytes, # we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` . # rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`, # After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink # `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens. # when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) self.torch_peak = torch.cuda.memory_stats(device).get(
"allocated_bytes.all.peak", 0
)
self.free_memory, self.total_memory = torch.cuda.mem_get_info() self.free_memory, self.total_memory = torch.cuda.mem_get_info(device)
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
if ( if (
current_platform.is_cuda() current_platform.is_cuda()
and current_platform.get_device_capability() in shared_sysmem_device_mem_sms and current_platform.get_device_capability(device.index)
in shared_sysmem_device_mem_sms
): ):
# On UMA (Orin, Thor and Spark) platform, # On UMA (Orin, Thor and Spark) platform,
# where both CPU and GPU rely on system memory, # where both CPU and GPU rely on system memory,
...@@ -106,12 +122,18 @@ class MemorySnapshot: ...@@ -106,12 +122,18 @@ class MemorySnapshot:
# torch.cuda.memory_reserved() is how many bytes # torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.) # PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage # this is used to measure the non-torch memory usage
self.torch_memory = torch.cuda.memory_reserved() self.torch_memory = torch.cuda.memory_reserved(device)
self.non_torch_memory = self.cuda_memory - self.torch_memory self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time() self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
if self.device_ != other.device_:
raise ValueError(
"The two snapshots should be from the same device! "
f"Found: {self.device_} vs. {other.device_}"
)
return MemorySnapshot( return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak, torch_peak=self.torch_peak - other.torch_peak,
free_memory=self.free_memory - other.free_memory, free_memory=self.free_memory - other.free_memory,
...@@ -120,6 +142,7 @@ class MemorySnapshot: ...@@ -120,6 +142,7 @@ class MemorySnapshot:
torch_memory=self.torch_memory - other.torch_memory, torch_memory=self.torch_memory - other.torch_memory,
non_torch_memory=self.non_torch_memory - other.non_torch_memory, non_torch_memory=self.non_torch_memory - other.non_torch_memory,
timestamp=self.timestamp - other.timestamp, timestamp=self.timestamp - other.timestamp,
device=self.device_,
auto_measure=False, auto_measure=False,
) )
......
...@@ -24,6 +24,10 @@ else: ...@@ -24,6 +24,10 @@ else:
ModelConfig = object ModelConfig = object
IntermediateTensors = object IntermediateTensors = object
import logging
logger = logging.getLogger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32, "float32": torch.float32,
...@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = { ...@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
} }
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
# TODO: Add more modelopt kv cache dtype
# mappings here when it supported by some attention backend
# (for example supports nvfp4).
"fp8": "fp8_e4m3",
}
T = TypeVar("T") T = TypeVar("T")
...@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype( ...@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
return torch_dtype return torch_dtype
def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None:
"""Get the KV cache quantization algorithm string from the quantization config.
Maps various FP8 format names to vLLM's standard cache dtype strings.
Returns None if no kv_cache_quant_algo is specified.
Returns "auto" if the value is not recognized/supported.
"""
# Mapping from model config values to vLLM cache_dtype strings
quant_method = quant_cfg.get("quant_method", "")
if quant_method.startswith("modelopt"):
quantization_inner = quant_cfg.get("quantization", quant_cfg)
# Check if quant config is specified and use kv cache quant algo
kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
"kv_cache_quant_algo"
)
if isinstance(kv_algo, str):
kv_algo_lower = kv_algo.lower()
# Try to map to vLLM's standard format
if kv_algo_lower in MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP:
return MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP[kv_algo_lower]
else:
# Unknown/unsupported format - return "auto" as safe fallback
logger.warning(
"WARNING: Unknown kv_cache_quant_algo '%s' in model "
"config. Supported values: %s. Falling back to 'auto'.",
kv_algo,
list(MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP.keys()),
)
return "auto"
return None
def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
"""Get the KV cache quantization algorithm dtype from the quantization config."""
kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg)
if kv_algo_str is not None and kv_algo_str != "auto":
# Only convert if we have a valid dtype string (not "auto" fallback)
return STR_DTYPE_TO_TORCH_DTYPE[kv_algo_str]
return None
def resolve_kv_cache_dtype_string(
kv_cache_dtype: str, model_config: ModelConfig
) -> str:
"""Resolve 'auto' kv_cache_dtype to the actual string value from model config.
Returns the resolved cache_dtype string.
"""
if kv_cache_dtype != "auto":
return kv_cache_dtype
hf_cfg = getattr(model_config, "hf_config", None)
if hf_cfg is not None:
quant_cfg = getattr(hf_cfg, "quantization_config", None)
if quant_cfg is not None:
kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg)
if kv_algo_str is not None:
return kv_algo_str
# Default to auto (will be handled by downstream code)
return "auto"
def kv_cache_dtype_str_to_dtype( def kv_cache_dtype_str_to_dtype(
kv_cache_dtype: str, model_config: ModelConfig kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype: ) -> torch.dtype:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
import copy
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar from typing import ClassVar
...@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -250,6 +251,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
if get_flash_attn_version() == 3 if get_flash_attn_version() == 3
else AttentionCGSupport.UNIFORM_BATCH else AttentionCGSupport.UNIFORM_BATCH
) )
supports_update_block_table: bool = True
def __init__( def __init__(
self, self,
...@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -493,6 +495,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
) )
return attn_metadata return attn_metadata
def update_block_table(
self,
metadata: FlashAttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> FlashAttentionMetadata:
new_metadata = copy.copy(metadata)
new_metadata.block_table = blk_table
new_metadata.slot_mapping = slot_mapping
return new_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs) return use_cascade_attention(*args, **kwargs)
......
...@@ -16,6 +16,7 @@ from flashinfer import ( ...@@ -16,6 +16,7 @@ from flashinfer import (
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
...@@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -59,6 +60,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.utils import CpuGpuBuffer
FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
...@@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper: ...@@ -181,7 +183,6 @@ class BatchDCPPrefillWrapper:
paged_kv_indptr_cpu: torch.Tensor, paged_kv_indptr_cpu: torch.Tensor,
paged_kv_indices: torch.Tensor, paged_kv_indices: torch.Tensor,
paged_kv_last_page_len_cpu: torch.Tensor, paged_kv_last_page_len_cpu: torch.Tensor,
prefill_start: int,
page_size: int, page_size: int,
num_qo_heads: int, num_qo_heads: int,
dcp_world_size: int, dcp_world_size: int,
...@@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper: ...@@ -200,7 +201,7 @@ class BatchDCPPrefillWrapper:
qo_indptr_cpu, qo_indptr_cpu,
paged_kv_indptr_cpu, paged_kv_indptr_cpu,
paged_kv_indices, paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:], paged_kv_last_page_len_cpu,
num_qo_heads * dcp_world_size, num_qo_heads * dcp_world_size,
num_kv_heads, num_kv_heads,
head_dim, head_dim,
...@@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend): ...@@ -380,40 +381,103 @@ class FlashInferBackend(AttentionBackend):
@dataclass @dataclass
class FlashInferMetadata: class FIPrefill:
num_actual_tokens: int # Number of tokens excluding padding. """Metadata for the native FlashInfer prefill pathway (non-TRTLLM)."""
# The data type of the query wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
q_data_type: torch.dtype
slot_mapping: torch.Tensor
# For flashinfer trtllm batch decode @dataclass
class FIDecode:
"""Metadata for the native FlashInfer decode pathway (non-TRTLLM)."""
wrapper: BatchDecodeWithPagedKVCacheWrapper
@dataclass
class TRTLLMPrefill:
"""Metadata for the TRTLLM prefill pathway."""
block_tables: torch.Tensor
"""
The slice of the block table tensor corresponding *only* to prefill requests.
Shape: [num_prefills, max_num_blocks_per_seq]
"""
seq_lens: torch.Tensor
"""
The slice of the sequence lengths tensor corresponding *only* to prefill requests.
Shape: [num_prefills]
"""
cum_seq_lens_q: torch.Tensor
cum_seq_lens_kv: torch.Tensor
max_q_len: int max_q_len: int
max_q_len_prefill: int """
The maximum query length *among prefill requests*.
"""
max_seq_len: int max_seq_len: int
"""The maximum sequence length for KV Cache."""
@dataclass
class TRTLLMDecode:
"""Metadata for the TRTLLM decode pathway."""
block_tables: torch.Tensor
"""
The slice of the block table tensor corresponding *only* to decode requests.
Shape: [num_decodes, max_num_blocks_per_seq]
"""
seq_lens: torch.Tensor seq_lens: torch.Tensor
block_table_tensor: torch.Tensor """
prefill_use_trtllm: bool The slice of the sequence lengths tensor corresponding *only* to decode requests.
decode_use_trtllm: bool Shape: [num_decodes]
"""
max_seq_len: int
"""The maximum sequence length for KV Cache."""
@dataclass
class FlashInferMetadata:
num_actual_tokens: int
"""Total number of tokens in the batch (excluding padding)."""
slot_mapping: torch.Tensor
"""Tensor for writing K/V to the cache. Shape: [num_actual_tokens]"""
q_data_type: torch.dtype
# For handling prefill decode split
num_decodes: int num_decodes: int
num_decode_tokens: int num_decode_tokens: int
num_prefills: int num_prefills: int
num_prefill_tokens: int num_prefill_tokens: int
# For cascade attention (CPU for planning). prefill: FIPrefill | TRTLLMPrefill | None
use_cascade: bool """
Holds the metadata for the prefill portion of the batch.
Will be `None` if `num_prefill_tokens == 0`.
"""
decode: FIDecode | TRTLLMDecode | None
"""
Holds the metadata for the decode portion of the batch.
Will be `None` if `num_decode_tokens == 0`.
"""
prefill_wrapper: ( # --- Special Case: Cascade Attention ---
BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
) = None
decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
qo_indptr_gpu: torch.Tensor | None = None use_cascade: bool
paged_kv_indptr_gpu: torch.Tensor | None = None """
If True, the entire batch is a cascade attention call, and the
`prefill` and `decode` fields will both be None.
"""
cascade_wrapper: MultiLevelCascadeAttentionWrapper | None
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...@@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -482,6 +546,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.dcp_world_size = 1 self.dcp_world_size = 1
self.dcp_rank = 0 self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = 1 self.dcp_kv_cache_interleave_size = 1
self.use_dcp = self.dcp_world_size > 1
self.num_qo_heads = self.model_config.get_num_attention_heads( self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config self.vllm_config.parallel_config
...@@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -535,34 +600,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"sinks, please use trtllm on blackwell or flash attention on " "sinks, please use trtllm on blackwell or flash attention on "
"earlier GPUs." "earlier GPUs."
) )
# Preparing persistent buffers (device-side) # Preparing persistent buffers
self.paged_kv_indptr = torch.zeros( self.pin_memory = is_pin_memory_available()
max_num_reqs + 1, dtype=torch.int32, device=self.device self.paged_kv_indptr = self._make_buffer(max_num_reqs + 1)
) self.paged_kv_indptr_cpu_buffer = torch.zeros_like(
self.paged_kv_indices = torch.zeros( self.paged_kv_indptr.cpu, pin_memory=self.pin_memory
max_num_pages, # max num pages possible ) # Extra buffer for mutable paged_kv_indptr.cpu in cuda graph mode
dtype=torch.int32, self.paged_kv_indices = self._make_buffer(max_num_pages)
device=self.device, self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)
)
self.paged_kv_last_page_len = torch.zeros(
max_num_reqs, dtype=torch.int32, device=self.device
)
# host-side buffer
pin_memory = is_pin_memory_available()
self.paged_kv_indptr_cpu = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indptr_buffer = torch.zeros_like(
self.paged_kv_indptr_cpu, pin_memory=pin_memory
)
self.paged_kv_indices_cpu = torch.zeros(
max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_last_page_len_cpu = torch.zeros(
max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
if self.head_dim == 256 and current_platform.is_device_capability_family(100): if self.head_dim == 256 and current_platform.is_device_capability_family(100):
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
...@@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -573,6 +618,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
"passing --block-size 32 or --block-size 64." "passing --block-size 32 or --block-size 64."
) )
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32
) -> CpuGpuBuffer:
return CpuGpuBuffer(
*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory,
with_numpy=True,
)
@override # type: ignore[misc]
@classmethod @classmethod
def get_cudagraph_support( def get_cudagraph_support(
cls: type["FlashInferMetadataBuilder"], cls: type["FlashInferMetadataBuilder"],
...@@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -607,7 +664,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self, self,
) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper: ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
if self._prefill_wrapper is None: if self._prefill_wrapper is None:
if self.dcp_world_size > 1: if self.use_dcp:
self._prefill_wrapper = BatchDCPPrefillWrapper( self._prefill_wrapper = BatchDCPPrefillWrapper(
workspace_buffer=self._get_workspace_buffer(), workspace_buffer=self._get_workspace_buffer(),
) )
...@@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -626,9 +683,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if decode_wrapper is None: if decode_wrapper is None:
if use_cudagraph: if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] paged_kv_indptr = self.paged_kv_indptr.gpu[: batch_size + 1]
paged_kv_indices = self.paged_kv_indices paged_kv_indices = self.paged_kv_indices.gpu
paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] paged_kv_last_page_len = self.paged_kv_last_page_len.gpu[:batch_size]
else: else:
paged_kv_indptr = None paged_kv_indptr = None
paged_kv_indices = None paged_kv_indices = None
...@@ -661,99 +718,43 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -661,99 +718,43 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
return self._cascade_wrapper return self._cascade_wrapper
def build( def _compute_flashinfer_kv_metadata(
self, self,
common_prefix_len: int, num_blocks_np: np.ndarray,
common_attn_metadata: CommonAttentionMetadata, seq_lens_np: np.ndarray,
fast_build: bool = False, block_table_tensor: torch.Tensor,
) -> FlashInferMetadata: num_reqs: int,
num_reqs = common_attn_metadata.num_reqs page_size: int,
num_actual_tokens = common_attn_metadata.num_actual_tokens ) -> torch.Tensor:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( """
split_decodes_and_prefills( Compute paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len for FlashInfer
common_attn_metadata, attention.
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
if self.dcp_world_size > 1:
if num_prefills > 0:
qo_indptr_prefill_cpu = (
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
)
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
seq_lens_cpu[num_decodes:] = (
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
)
seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
seq_lens_np = seq_lens_cpu.numpy()
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
use_cascade = common_prefix_len > 0
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
)
shared_kv_page_indptr_cpu = torch.tensor(
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
)
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
shared_kv_last_page_len_cpu = torch.tensor(
[page_size], dtype=torch.int32, device="cpu"
)
# Remove the blocks of the shared prefix from all requests. Results are stored in self.paged_kv_indptr,
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] self.paged_kv_indices, self.paged_kv_last_page_len buffers.
num_blocks_np -= num_common_kv_blocks
else:
shared_qo_indptr_cpu = None
shared_kv_page_indptr_cpu = None
shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None
Returns paged_kv_indices, a GPU tensor with shape [num_actual_pages].
"""
# write self.paged_kv_indptr_cpu inplace (0-index is always 0) # write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np.cumsum( np.cumsum(
num_blocks_np, num_blocks_np,
dtype=np.int32, dtype=np.int32,
out=self.paged_kv_indptr_np[1 : num_reqs + 1], out=self.paged_kv_indptr.np[1 : num_reqs + 1],
) )
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to # after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition. # self.paged_kv_indptr_buffer to avoid race condition.
self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ self.paged_kv_indptr_cpu_buffer[: num_reqs + 1] = self.paged_kv_indptr.cpu[
: num_reqs + 1 : num_reqs + 1
] ]
paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1]
paged_kv_indptr.copy_( paged_kv_indptr.copy_(
self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True self.paged_kv_indptr_cpu_buffer[: num_reqs + 1], non_blocking=True
) )
# write self.paged_kv_indices inplace # write self.paged_kv_indices inplace
num_actual_pages = self.paged_kv_indptr_np[num_reqs] num_actual_pages = self.paged_kv_indptr.np[num_reqs]
paged_kv_indices = self.paged_kv_indices[:num_actual_pages] paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs,)]( _copy_page_indices_kernel[(num_reqs,)](
paged_kv_indices, paged_kv_indices,
block_table_tensor, block_table_tensor,
...@@ -764,12 +765,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -764,12 +765,41 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# write self.paged_kv_last_page_len_cpu inplace # write self.paged_kv_last_page_len_cpu inplace
paged_kv_last_page_len_np = seq_lens_np % page_size paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len_np[:num_reqs] = np.where( self.paged_kv_last_page_len.np[:num_reqs] = np.where(
(paged_kv_last_page_len_np == 0) & (seq_lens_np != 0), (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
page_size, page_size,
paged_kv_last_page_len_np, paged_kv_last_page_len_np,
) )
return paged_kv_indices
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
page_size = self.page_size
max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
qo_indptr = common_attn_metadata.query_start_loc
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
# Step 1: Decide which dispatch modes to use:
# - Cascade attention (distinct mode)
# - Prefill (FI native or TRTLLM)
# - Decode (FI native or TRTLLM)
use_cascade = common_prefix_len > 0
uses_spec_reorder = self.reorder_batch_threshold > 1 uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention( prefill_use_trtllm = use_trtllm_attention(
self.num_qo_heads, self.num_qo_heads,
...@@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -788,7 +818,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.use_trtllm_decode_attention and self.dcp_world_size <= 1 self.use_trtllm_decode_attention and self.dcp_world_size <= 1
) )
if not (prefill_use_trtllm and decode_use_trtllm): all_uses_trtllm = (num_prefills == 0 or prefill_use_trtllm) and (
num_decodes == 0 or decode_use_trtllm
)
is_only_trtllm_decode = num_prefills == 0 and (
num_decodes > 0 and decode_use_trtllm
)
if not all_uses_trtllm:
if self.has_sinks: if self.has_sinks:
raise NotImplementedError( raise NotImplementedError(
"FlashInfer backend currently does not support attention " "FlashInfer backend currently does not support attention "
...@@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -813,28 +850,102 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# fall back to model dtype. # fall back to model dtype.
self.q_data_type = self.model_config.dtype self.q_data_type = self.model_config.dtype
# Step 2: Initialize the output metadata
# Leave prefill/decode/cascade_wrapper empty, to be populated
# case by case depending on the batch contents and backend selection.
attn_metadata = FlashInferMetadata( attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
q_data_type=self.q_data_type,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len, q_data_type=self.q_data_type,
max_q_len_prefill=max_q_len,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
prefill_use_trtllm=prefill_use_trtllm,
decode_use_trtllm=decode_use_trtllm,
num_decodes=num_decodes, num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills, num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade, use_cascade=use_cascade,
prefill=None,
decode=None,
cascade_wrapper=None,
) )
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] # Guard access to seq_lens_cpu, which may not always be needed
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] # and can be expensive to retrieve in async mode.
needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode
seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None
seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None
num_blocks_np = (
(seq_lens_np + (page_size - 1)) // page_size
if seq_lens_np is not None
else None
)
# Adjust seq_lens_cpu for DCP
if self.use_dcp:
assert seq_lens_cpu is not None
if num_prefills > 0:
qo_indptr_prefill_cpu = (
qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
)
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
)
seq_lens_cpu[num_decodes:] = (
seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
)
seq_lens_cpu = get_dcp_local_seq_lens(
seq_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
# Adjust num_block_np for cascade attention
if use_cascade:
assert num_blocks_np is not None
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
num_blocks_np -= num_common_kv_blocks
# Compute paged_kv_indices if necessary
needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode
if needs_paged_kv_indices:
assert num_blocks_np is not None
assert seq_lens_np is not None
paged_kv_indices = self._compute_flashinfer_kv_metadata(
num_blocks_np,
seq_lens_np,
block_table_tensor,
num_reqs,
page_size,
)
else:
paged_kv_indices = None
# Early-out for cascade attention
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
num_common_kv_blocks = common_prefix_len // page_size
# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu = torch.tensor(
[0, num_actual_tokens], dtype=torch.int32, device="cpu"
)
shared_kv_page_indptr_cpu = torch.tensor(
[0, num_common_kv_blocks], dtype=torch.int32, device="cpu"
)
shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks]
shared_kv_last_page_len_cpu = torch.tensor(
[page_size], dtype=torch.int32, device="cpu"
)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
num_blocks_np -= num_common_kv_blocks
assert paged_kv_indices is not None
paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[: 1 + num_reqs]
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan( attn_metadata.cascade_wrapper.plan(
[shared_qo_indptr_cpu, qo_indptr_cpu], [shared_qo_indptr_cpu, qo_indptr_cpu],
...@@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -852,91 +963,107 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type=self.q_data_type, q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype, kv_data_type=self.kv_cache_dtype,
) )
else: return attn_metadata
# Regular attention (common case).
# Decodes are at the front and prefills are at the back. # Step 3: Handle prefill and decode pathways case by case
num_prefills = attn_metadata.num_prefills ## PREFILL PATHWAY
num_decodes = attn_metadata.num_decodes if num_prefills > 0:
if num_prefills > 0: # Slices for shared prefill metadata
# Decodes are first so prefills start after the last decode prefill_start = num_decodes
prefill_start = num_decodes qo_indptr_prefill_cpu = (
attn_metadata.prefill_wrapper = self._get_prefill_wrapper() qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 )
assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 assert qo_indptr_prefill_cpu.shape[0] == num_prefills + 1
assert (
paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills if prefill_use_trtllm:
# Create GPU versions
qo_indptr_prefill_gpu = (
qo_indptr[prefill_start:] - qo_indptr[prefill_start]
)
paged_kv_indptr_prefill_gpu = self.paged_kv_indptr.gpu[
prefill_start : num_reqs + 1
]
# Compute max_q_len for prefill requests
query_lens_prefill_cpu = (
qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
) )
# Since prefill_wrapper.run() will be called with max_q_len_prefill = int(query_lens_prefill_cpu.max().item())
# query[num_decode_tokens:] we need to adjust the qo_indptr attn_metadata.prefill = TRTLLMPrefill(
# to be relative to the start of the prefill queries. block_tables=block_table_tensor[prefill_start:],
qo_indptr_cpu = ( seq_lens=seq_lens[prefill_start:],
qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] cum_seq_lens_q=qo_indptr_prefill_gpu,
cum_seq_lens_kv=paged_kv_indptr_prefill_gpu,
max_q_len=max_q_len_prefill,
max_seq_len=max_seq_len,
) )
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] else:
prefill_wrapper = self._get_prefill_wrapper()
# Recompute max_q_len for the slice of requests we are using # Slicing CPU buffers that are only needed for FI native prefills
# for prefills. This can be different from max_q_len when paged_kv_last_page_len_prefill_cpu = self.paged_kv_last_page_len.cpu[
# we have a non-uniform batch with some short decodes offloaded prefill_start:num_reqs
# to the prefill pathway ]
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] assert paged_kv_last_page_len_prefill_cpu.shape[0] == num_prefills
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) paged_kv_indptr_prefill_cpu = self.paged_kv_indptr.cpu[
prefill_start : num_reqs + 1
if not attn_metadata.prefill_use_trtllm: ]
if self.dcp_world_size > 1: assert paged_kv_indptr_prefill_cpu.shape[0] == num_prefills + 1
assert isinstance( if self.use_dcp:
attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
) prefill_wrapper.plan(
attn_metadata.prefill_wrapper.plan( qo_indptr_cpu=qo_indptr_prefill_cpu,
qo_indptr_cpu=qo_indptr_cpu, paged_kv_indptr_cpu=paged_kv_indptr_prefill_cpu,
paged_kv_indptr_cpu=paged_kv_indptr_cpu, paged_kv_indices=paged_kv_indices,
paged_kv_indices=paged_kv_indices, paged_kv_last_page_len_cpu=paged_kv_last_page_len_prefill_cpu,
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu, page_size=self.page_size,
prefill_start=prefill_start, num_qo_heads=self.num_qo_heads,
page_size=self.page_size, dcp_world_size=self.dcp_world_size,
num_qo_heads=self.num_qo_heads, num_kv_heads=self.num_kv_heads,
dcp_world_size=self.dcp_world_size, head_dim=self.head_dim,
num_kv_heads=self.num_kv_heads, sm_scale=self.sm_scale,
head_dim=self.head_dim, window_left=self.window_left,
sm_scale=self.sm_scale, logits_soft_cap=self.logits_soft_cap,
window_left=self.window_left, q_data_type=self.q_data_type,
logits_soft_cap=self.logits_soft_cap, kv_cache_dtype=self.kv_cache_dtype,
q_data_type=self.q_data_type, prefill_fixed_split_size=self.prefill_fixed_split_size,
kv_cache_dtype=self.kv_cache_dtype, disable_split_kv=self.disable_split_kv,
prefill_fixed_split_size=self.prefill_fixed_split_size, )
disable_split_kv=self.disable_split_kv,
)
else:
assert isinstance(
attn_metadata.prefill_wrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu,
paged_kv_indptr_cpu,
paged_kv_indices,
paged_kv_last_page_len_cpu[prefill_start:],
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
else: else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( assert isinstance(
self.device, non_blocking=True prefill_wrapper,
BatchPrefillWithPagedKVCacheWrapper,
) )
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( prefill_wrapper.plan(
self.device, non_blocking=True qo_indptr_prefill_cpu,
paged_kv_indptr_prefill_cpu,
paged_kv_indices,
paged_kv_last_page_len_prefill_cpu,
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
causal=True,
sm_scale=self.sm_scale,
window_left=self.window_left,
logits_soft_cap=self.logits_soft_cap,
q_data_type=self.q_data_type,
kv_data_type=self.kv_cache_dtype,
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
) )
attn_metadata.prefill = FIPrefill(wrapper=prefill_wrapper)
if num_decodes > 0: ## DECODE PATHWAY
if num_decodes > 0:
if decode_use_trtllm:
assert num_decode_tokens % num_decodes == 0, (
"TRTLLM decode requires uniform query lengths per request."
)
attn_metadata.decode = TRTLLMDecode(
block_tables=block_table_tensor[:num_decodes],
seq_lens=seq_lens[:num_decodes],
max_seq_len=max_seq_len,
)
else:
pure_decode = num_prefills == 0 pure_decode = num_prefills == 0
use_cudagraph = ( use_cudagraph = (
self.enable_cuda_graph self.enable_cuda_graph
...@@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -945,33 +1072,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
num_input_tokens = num_decode_tokens num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper( decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph num_input_tokens, use_cudagraph
) )
if not attn_metadata.decode_use_trtllm: # Use the persistent buffer with padding length,
# Use the persistent buffer with padding length, # instead of the same address but chunked version
# instead of the same address but chunked version # in atten_metadata when using cudagraph.
# in atten_metadata when using cudagraph. fast_plan_decode(
fast_plan_decode( decode_wrapper,
attn_metadata.decode_wrapper, self.paged_kv_indptr.cpu[: num_input_tokens + 1],
self.paged_kv_indptr_cpu[: num_input_tokens + 1], paged_kv_indices,
paged_kv_indices, self.paged_kv_last_page_len.cpu[:num_input_tokens],
self.paged_kv_last_page_len_cpu[:num_input_tokens], seq_lens_cpu[:num_input_tokens],
seq_lens_cpu[:num_input_tokens], self.num_qo_heads * self.dcp_world_size,
self.num_qo_heads * self.dcp_world_size, self.num_kv_heads,
self.num_kv_heads, self.head_dim,
self.head_dim, self.page_size,
self.page_size, # Disable flashinfer's pos encoding and use vllm's rope.
# Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE",
pos_encoding_mode="NONE", sm_scale=self.sm_scale,
sm_scale=self.sm_scale, window_left=self.window_left,
window_left=self.window_left, logits_soft_cap=self.logits_soft_cap,
logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type,
q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype,
kv_data_type=self.kv_cache_dtype, fixed_split_size=self.decode_fixed_split_size,
fixed_split_size=self.decode_fixed_split_size, disable_split_kv=self.disable_split_kv,
disable_split_kv=self.disable_split_kv, )
) attn_metadata.decode = FIDecode(wrapper=decode_wrapper)
return attn_metadata return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
...@@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl): ...@@ -1104,6 +1231,9 @@ class FlashInferImpl(AttentionImpl):
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float self.bmm2_scale = layer._v_scale_float
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
# The attn+quant fusion happens when output_scale is provided. # The attn+quant fusion happens when output_scale is provided.
if output_scale is None: if output_scale is None:
assert output_block_scale is None, ( assert output_block_scale is None, (
...@@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -1113,8 +1243,8 @@ class FlashInferImpl(AttentionImpl):
assert attn_metadata.q_data_type == FP8_DTYPE, ( assert attn_metadata.q_data_type == FP8_DTYPE, (
"Query must be FP8 when attn+quant fusion happened." "Query must be FP8 when attn+quant fusion happened."
) )
assert ( assert (attn_metadata.num_prefills == 0 or prefill_use_trtllm) and (
attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm attn_metadata.num_decodes == 0 or decode_use_trtllm
), "Must use TRT-LLM attn" ), "Must use TRT-LLM attn"
if output.dtype == FP8_DTYPE: if output.dtype == FP8_DTYPE:
...@@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl): ...@@ -1191,22 +1321,25 @@ class FlashInferImpl(AttentionImpl):
# When using spec decoding, num_decodes can be < num_decode_tokens # When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token. # because some decode requests may have more than one query token.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
stride_order = FlashInferBackend.get_kv_cache_stride_order() stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order) kv_cache_permute = kv_cache.permute(*stride_order)
use_dcp = self.dcp_world_size > 1
# Regular attention (common case). # Regular attention (common case).
# Decodes are at the front and prefills are at the back. # Decodes are at the front and prefills are at the back.
if num_prefill_tokens > 0: if num_prefill_tokens > 0:
prefill_wrapper = attn_metadata.prefill_wrapper
prefill_query = query[num_decode_tokens:] prefill_query = query[num_decode_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens assert prefill_query.shape[0] == num_prefill_tokens
assert prefill_wrapper is not None
if not attn_metadata.prefill_use_trtllm: if not prefill_use_trtllm:
if self.dcp_world_size > 1: assert isinstance(attn_metadata.prefill, FIPrefill)
prefill_wrapper = attn_metadata.prefill.wrapper
assert prefill_wrapper is not None
if use_dcp:
assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper) assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
assert prefill_wrapper._context._window_left == self.window_left assert prefill_wrapper._context._window_left == self.window_left
assert prefill_wrapper._context._logits_soft_cap == ( assert prefill_wrapper._context._logits_soft_cap == (
...@@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl): ...@@ -1247,11 +1380,12 @@ class FlashInferImpl(AttentionImpl):
out=output[num_decode_tokens:], out=output[num_decode_tokens:],
) )
else: else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous # prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous() prefill_query = prefill_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer() workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] block_tables_prefill = attn_metadata.prefill.block_tables
seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] seq_lens_prefill = attn_metadata.prefill.seq_lens
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND" assert get_kv_cache_layout() == "HND"
...@@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl): ...@@ -1298,13 +1432,13 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=mock_block_table, block_tables=mock_block_table,
seq_lens=seq_lens_prefill, seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len_prefill, max_q_len=attn_metadata.prefill.max_q_len,
max_kv_len=attn_metadata.max_seq_len, max_kv_len=attn_metadata.prefill.max_seq_len,
bmm1_scale=self.bmm1_scale, bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale, bmm2_scale=self.bmm2_scale,
batch_size=attn_metadata.num_prefills, batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu, cum_seq_lens_q=attn_metadata.prefill.cum_seq_lens_q,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, cum_seq_lens_kv=attn_metadata.prefill.cum_seq_lens_kv,
window_left=self.window_left, window_left=self.window_left,
sinks=self.sinks, sinks=self.sinks,
o_sf_scale=self.o_sf_scale, o_sf_scale=self.o_sf_scale,
...@@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl): ...@@ -1312,17 +1446,18 @@ class FlashInferImpl(AttentionImpl):
) )
if num_decode_tokens > 0: if num_decode_tokens > 0:
decode_wrapper = attn_metadata.decode_wrapper
decode_query = query[:num_decode_tokens] decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None
if not attn_metadata.decode_use_trtllm: if not decode_use_trtllm:
assert isinstance(attn_metadata.decode, FIDecode)
decode_wrapper = attn_metadata.decode.wrapper
assert decode_wrapper is not None
assert decode_wrapper._window_left == self.window_left assert decode_wrapper._window_left == self.window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
assert decode_wrapper._sm_scale == self.scale assert decode_wrapper._sm_scale == self.scale
if self.dcp_world_size > 1: if use_dcp:
decode_query = get_dcp_group().all_gather( decode_query = get_dcp_group().all_gather(
decode_query.contiguous(), dim=-2 decode_query.contiguous(), dim=-2
) )
...@@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl): ...@@ -1357,12 +1492,11 @@ class FlashInferImpl(AttentionImpl):
) )
else: else:
# decode_query may be non-contiguous # decode_query may be non-contiguous
assert isinstance(attn_metadata.decode, TRTLLMDecode)
decode_query = decode_query.contiguous() decode_query = decode_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer() workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.block_table_tensor[ block_tables_decode = attn_metadata.decode.block_tables
:num_decode_tokens seq_lens_decode = attn_metadata.decode.seq_lens
]
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND" assert get_kv_cache_layout() == "HND"
...@@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1397,7 +1531,7 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=block_tables_decode, block_tables=block_tables_decode,
seq_lens=seq_lens_decode, seq_lens=seq_lens_decode,
max_seq_len=attn_metadata.max_seq_len, max_seq_len=attn_metadata.decode.max_seq_len,
bmm1_scale=self.bmm1_scale, bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale, bmm2_scale=self.bmm2_scale,
window_left=self.window_left, window_left=self.window_left,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass
...@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata: ...@@ -134,6 +135,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder( class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
): ):
supports_update_block_table: bool = True
def __init__( def __init__(
self, self,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
...@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder( ...@@ -346,3 +349,27 @@ class Mamba2AttentionMetadataBuilder(
num_computed_tokens_p=num_computed_tokens_p, num_computed_tokens_p=num_computed_tokens_p,
) )
return attn_metadata return attn_metadata
def update_block_table(
self,
metadata: Mamba2AttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> Mamba2AttentionMetadata:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata
...@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -15,6 +15,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): ...@@ -51,6 +52,8 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
qo_indptr: torch.Tensor | None = None qo_indptr: torch.Tensor | None = None
# The dtype of MLA out tensor # The dtype of MLA out tensor
attn_out_dtype: torch.dtype = torch.bfloat16 attn_out_dtype: torch.dtype = torch.bfloat16
# The max query output length: int
max_qo_len: int | None = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
...@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): ...@@ -60,9 +63,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
# TODO(luka, lucas): audit this as part of: # TODO(luka, lucas): audit this as part of:
# https://github.com/vllm-project/vllm/issues/22945 # https://github.com/vllm-project/vllm/issues/22945
_cudagraph_support: ClassVar[AttentionCGSupport] = ( _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
)
def __init__( def __init__(
self, self,
...@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -97,8 +99,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_reqs, dtype=torch.int32, device=device max_num_reqs, dtype=torch.int32, device=device
) )
self.qo_indptr = torch.arange( self.qo_indptr = torch.zeros(
0, max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
) )
def _build_decode( def _build_decode(
...@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -128,6 +130,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
seq_lens_device.cumsum(dim=0, dtype=torch.int32), seq_lens_device.cumsum(dim=0, dtype=torch.int32),
] ]
) )
qo_len = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_qo_len = qo_len.max().item()
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0) num_actual_pages = paged_kv_indices.size(0)
...@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -150,6 +154,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
self.paged_kv_last_page_len[num_reqs:].fill_(1) self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
self.qo_indptr[: 1 + num_reqs].copy_(
query_start_loc_device, non_blocking=True
)
self.qo_indptr[1 + num_reqs :] = query_start_loc_device[-1]
qo_indptr = self.qo_indptr[: 1 + num_reqs] qo_indptr = self.qo_indptr[: 1 + num_reqs]
else: else:
...@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -165,6 +173,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
paged_kv_last_page_len=paged_kv_last_page_len, paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr, qo_indptr=qo_indptr,
dcp_tot_seq_lens=dcp_tot_seq_lens_device, dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_qo_len=max_qo_len,
attn_out_dtype=self.decode_attn_out_dtype, attn_out_dtype=self.decode_attn_out_dtype,
) )
...@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -255,16 +264,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
rocm_aiter_ops.mla_decode_fwd( rocm_aiter_ops.mla_decode_fwd(
q, q,
kv_buffer, kv_buffer,
o, o,
self.scale, self.scale,
attn_metadata.decode.qo_indptr, attn_metadata.decode.qo_indptr,
max_seqlen_qo, attn_metadata.decode.max_qo_len,
attn_metadata.decode.paged_kv_indptr, attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices, attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len, attn_metadata.decode.paged_kv_last_page_len,
......
...@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat ...@@ -152,7 +152,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class RocmAttentionBackend(AttentionBackend): class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
...@@ -165,7 +169,7 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -165,7 +169,7 @@ class RocmAttentionBackend(AttentionBackend):
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. " f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. " f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-config.backend=FLEX_ATTENTION to use " "Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes." "FlexAttention backend which supports all head sizes."
) )
......
...@@ -4,6 +4,7 @@ import abc ...@@ -4,6 +4,7 @@ import abc
import enum import enum
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass from dataclasses import dataclass, field, fields, make_dataclass
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
...@@ -201,10 +202,11 @@ def _make_metadata_with_slice( ...@@ -201,10 +202,11 @@ def _make_metadata_with_slice(
) )
# NOTE: last token can be outside of the last request if we have CG padding. # NOTE: last token can be outside of the last request if we have CG padding.
# If the "middle" request has tokens in both ubatches, we have to split it. # If the request is split across ubatches, we have to adjust the metadata.
# If ubatch_slice is the first ubatch then we will be splitting the last # splits_first_request: The first request in this slice is the continuation of
# request. If it's the second microbatch, then we will be splitting the # a request that started in a previous slice.
# first request # splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request = first_tok > start_locs[first_req] splits_first_request = first_tok > start_locs[first_req]
splits_last_request = last_tok < start_locs[last_req + 1] - 1 splits_last_request = last_tok < start_locs[last_req + 1] - 1
...@@ -225,7 +227,10 @@ def _make_metadata_with_slice( ...@@ -225,7 +227,10 @@ def _make_metadata_with_slice(
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
if splits_last_request: if splits_last_request:
tokens_skipped = query_start_loc_cpu[-1] - token_slice.stop # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
query_start_loc[-1] -= tokens_skipped query_start_loc[-1] -= tokens_skipped
query_start_loc_cpu[-1] -= tokens_skipped query_start_loc_cpu[-1] -= tokens_skipped
...@@ -313,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -313,6 +318,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch. # length that will be pulled into the front of the batch.
reorder_batch_threshold: int | None = None reorder_batch_threshold: int | None = None
# Does this backend/builder support updating the block table in existing
# metadata
supports_update_block_table: bool = False
@abstractmethod @abstractmethod
def __init__( def __init__(
...@@ -383,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -383,6 +391,21 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
""" """
raise NotImplementedError raise NotImplementedError
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
"""
Update the block table for the attention metadata.
Faster when theres multiple kv-cache groups that create virtually the
same metadata but just with different block tables.
Only needs to be implemented if supports_update_block_table is True.
"""
raise NotImplementedError
def build_for_cudagraph_capture( def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata self, common_attn_metadata: CommonAttentionMetadata
) -> M: ) -> M:
...@@ -599,7 +622,7 @@ def make_local_attention_virtual_batches( ...@@ -599,7 +622,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size: int, attn_chunk_size: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
block_size: int = 0, block_size: int = 0,
) -> CommonAttentionMetadata: ) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy() query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy() seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
...@@ -711,9 +734,12 @@ def make_local_attention_virtual_batches( ...@@ -711,9 +734,12 @@ def make_local_attention_virtual_batches(
# tensor first, which recovers perf. # tensor first, which recovers perf.
batch_indices_torch = torch.from_numpy(batch_indices) batch_indices_torch = torch.from_numpy(batch_indices)
block_indices_torch = torch.from_numpy(block_indices) block_indices_torch = torch.from_numpy(block_indices)
block_table_local = block_table[batch_indices_torch, block_indices_torch].view(
virtual_batches, -1 # Save as a lambda so we can return this for update_block_table
) make_block_table = lambda block_table: block_table[
batch_indices_torch, block_indices_torch
].view(virtual_batches, -1)
block_table_local = make_block_table(block_table)
query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
seq_lens_cpu = torch.from_numpy(seqlens_k_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local)
...@@ -732,7 +758,7 @@ def make_local_attention_virtual_batches( ...@@ -732,7 +758,7 @@ def make_local_attention_virtual_batches(
causal=True, causal=True,
_seq_lens_cpu=seq_lens_cpu, _seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), _num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
) ), make_block_table
def make_kv_sharing_fast_prefill_common_attn_metadata( def make_kv_sharing_fast_prefill_common_attn_metadata(
......
...@@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu ...@@ -43,6 +43,7 @@ from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_qu
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import ( from vllm.v1.metrics.stats import (
PrefixCacheStats, PrefixCacheStats,
SchedulerStats, SchedulerStats,
...@@ -187,6 +188,12 @@ class Scheduler(SchedulerInterface): ...@@ -187,6 +188,12 @@ class Scheduler(SchedulerInterface):
if self.is_encoder_decoder if self.is_encoder_decoder
else EncoderCacheManager(cache_size=encoder_cache_size) else EncoderCacheManager(cache_size=encoder_cache_size)
) )
# For encoder-decoder models, allocate the maximum number of tokens for Cross
# Attn blocks, as for Whisper its input is always padded to the maximum length.
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
self._num_encoder_max_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
)
speculative_config = vllm_config.speculative_config speculative_config = vllm_config.speculative_config
self.use_eagle = False self.use_eagle = False
...@@ -213,6 +220,10 @@ class Scheduler(SchedulerInterface): ...@@ -213,6 +220,10 @@ class Scheduler(SchedulerInterface):
self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
self.perf_metrics: ModelMetrics | None = None
if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
self.perf_metrics = ModelMetrics(vllm_config)
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm: # NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler. # There's no "decoding phase" nor "prefill phase" in the scheduler.
...@@ -568,17 +579,11 @@ class Scheduler(SchedulerInterface): ...@@ -568,17 +579,11 @@ class Scheduler(SchedulerInterface):
0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens
) )
# Determine if we need to allocate cross-attention blocks. num_encoder_tokens = (
if self.is_encoder_decoder and request.has_encoder_inputs: self._num_encoder_max_input_tokens
# TODO(russellb): For Whisper, we know that the input is if self.is_encoder_decoder and request.has_encoder_inputs
# always padded to the maximum length. If we support other else 0
# encoder-decoder models, this will need to be updated if we )
# want to only allocate what is needed.
num_encoder_tokens = (
self.scheduler_config.max_num_encoder_input_tokens
)
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
...@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface): ...@@ -1066,6 +1071,10 @@ class Scheduler(SchedulerInterface):
kv_connector_output = model_runner_output.kv_connector_output kv_connector_output = model_runner_output.kv_connector_output
cudagraph_stats = model_runner_output.cudagraph_stats cudagraph_stats = model_runner_output.cudagraph_stats
perf_stats: PerfStats | None = None
if self.perf_metrics and self.perf_metrics.is_enabled():
perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output)
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: KVConnectorStats | None = ( kv_connector_stats: KVConnectorStats | None = (
...@@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface): ...@@ -1262,7 +1271,7 @@ class Scheduler(SchedulerInterface):
if ( if (
stats := self.make_stats( stats := self.make_stats(
spec_decoding_stats, kv_connector_stats, cudagraph_stats spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats
) )
) is not None: ) is not None:
# Return stats to only one of the front-ends. # Return stats to only one of the front-ends.
...@@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface): ...@@ -1485,6 +1494,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats: SpecDecodingStats | None = None, spec_decoding_stats: SpecDecodingStats | None = None,
kv_connector_stats: KVConnectorStats | None = None, kv_connector_stats: KVConnectorStats | None = None,
cudagraph_stats: CUDAGraphStat | None = None, cudagraph_stats: CUDAGraphStat | None = None,
perf_stats: PerfStats | None = None,
) -> SchedulerStats | None: ) -> SchedulerStats | None:
if not self.log_stats: if not self.log_stats:
return None return None
...@@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface): ...@@ -1510,6 +1520,7 @@ class Scheduler(SchedulerInterface):
spec_decoding_stats=spec_stats, spec_decoding_stats=spec_stats,
kv_connector_stats=connector_stats_payload, kv_connector_stats=connector_stats_payload,
cudagraph_stats=cudagraph_stats, cudagraph_stats=cudagraph_stats,
perf_stats=perf_stats,
) )
def make_spec_decoding_stats( def make_spec_decoding_stats(
......
...@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import (
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ( from vllm.v1.engine import (
EngineCoreOutput,
EngineCoreOutputs, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestType,
FinishReason,
ReconfigureDistributedRequest, ReconfigureDistributedRequest,
ReconfigureRankType, ReconfigureRankType,
UtilityOutput, UtilityOutput,
...@@ -923,6 +925,13 @@ class EngineCoreProc(EngineCore): ...@@ -923,6 +925,13 @@ class EngineCoreProc(EngineCore):
# Post-step hook. # Post-step hook.
self.post_step(model_executed) self.post_step(model_executed)
# If no model execution happened but there are waiting requests
# (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
# background threads (like NIXL handshake) to make progress.
# Without this, the tight polling loop can starve background threads.
if not model_executed and self.scheduler.has_unfinished_requests():
time.sleep(0.001)
return model_executed return model_executed
def _handle_client_request( def _handle_client_request(
...@@ -1048,9 +1057,14 @@ class EngineCoreProc(EngineCore): ...@@ -1048,9 +1057,14 @@ class EngineCoreProc(EngineCore):
request_type = EngineCoreRequestType(bytes(type_frame.buffer)) request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data. # Deserialize the request data.
request: Any
if request_type == EngineCoreRequestType.ADD: if request_type == EngineCoreRequestType.ADD:
request = add_request_decoder.decode(data_frames) req: EngineCoreRequest = add_request_decoder.decode(data_frames)
request = self.preprocess_add_request(request) try:
request = self.preprocess_add_request(req)
except Exception:
self._handle_request_preproc_error(req)
continue
else: else:
request = generic_decoder.decode(data_frames) request = generic_decoder.decode(data_frames)
...@@ -1134,6 +1148,30 @@ class EngineCoreProc(EngineCore): ...@@ -1134,6 +1148,30 @@ class EngineCoreProc(EngineCore):
# Limit the number of buffers to reuse. # Limit the number of buffers to reuse.
reuse_buffers.append(buffer) reuse_buffers.append(buffer)
def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
"""Log and return a request-scoped error response for exceptions raised
from the add request preprocessing in the input socket processing thread.
"""
logger.exception(
"Unexpected error pre-processing request %s", request.request_id
)
self.output_queue.put_nowait(
(
request.client_index,
EngineCoreOutputs(
engine_index=self.engine_index,
finished_requests={request.request_id},
outputs=[
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=FinishReason.ERROR,
)
],
),
)
)
class DPEngineCoreProc(EngineCoreProc): class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process """ZMQ-wrapper for running EngineCore in background process
......
...@@ -269,7 +269,8 @@ class InprocClient(EngineCoreClient): ...@@ -269,7 +269,8 @@ class InprocClient(EngineCoreClient):
self.engine_core = EngineCore(*args, **kwargs) self.engine_core = EngineCore(*args, **kwargs)
def get_output(self) -> EngineCoreOutputs: def get_output(self) -> EngineCoreOutputs:
outputs, _ = self.engine_core.step_fn() outputs, model_executed = self.engine_core.step_fn()
self.engine_core.post_step(model_executed=model_executed)
return outputs and outputs.get(0) or EngineCoreOutputs() return outputs and outputs.get(0) or EngineCoreOutputs()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
......
...@@ -24,7 +24,10 @@ from vllm.tokenizers.mistral import MistralTokenizer ...@@ -24,7 +24,10 @@ from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar from vllm.v1.structured_output.backend_guidance import (
has_guidance_unsupported_json_features,
validate_guidance_grammar,
)
from vllm.v1.structured_output.backend_lm_format_enforcer import ( from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer, validate_structured_output_request_lm_format_enforcer,
) )
...@@ -340,8 +343,22 @@ class InputProcessor: ...@@ -340,8 +343,22 @@ class InputProcessor:
# The request either failed validation # The request either failed validation
# or includes some jsonschema feature(s) that # or includes some jsonschema feature(s) that
# are not supported in xgrammar. # are not supported in xgrammar.
if isinstance(self.tokenizer, MistralTokenizer):
# Check if schema has features unsupported by guidance
so_params = params.structured_outputs
skip_guidance = False
if so_params.json:
if isinstance(so_params.json, str):
import json
schema = json.loads(so_params.json)
else:
schema = so_params.json
skip_guidance = has_guidance_unsupported_json_features(schema)
if isinstance(self.tokenizer, MistralTokenizer) or skip_guidance:
# Fall back to outlines if the tokenizer is Mistral # Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance
validate_structured_output_request_outlines(params) validate_structured_output_request_outlines(params)
params.structured_outputs._backend = "outlines" params.structured_outputs._backend = "outlines"
else: else:
......
...@@ -8,6 +8,7 @@ from typing import Any, cast ...@@ -8,6 +8,7 @@ from typing import Any, cast
import torch import torch
from vllm.lora.request import LoRARequest
from vllm.outputs import ( from vllm.outputs import (
CompletionOutput, CompletionOutput,
PoolingOutput, PoolingOutput,
...@@ -93,7 +94,7 @@ class RequestState: ...@@ -93,7 +94,7 @@ class RequestState:
request_id: str, request_id: str,
parent_req: ParentRequest | None, parent_req: ParentRequest | None,
request_index: int, request_index: int,
lora_name: str | None, lora_request: LoRARequest | None,
output_kind: RequestOutputKind, output_kind: RequestOutputKind,
prompt: str | None, prompt: str | None,
prompt_token_ids: list[int] | None, prompt_token_ids: list[int] | None,
...@@ -112,7 +113,8 @@ class RequestState: ...@@ -112,7 +113,8 @@ class RequestState:
self.request_id = request_id self.request_id = request_id
self.parent_req = parent_req self.parent_req = parent_req
self.request_index = request_index self.request_index = request_index
self.lora_name = lora_name self.lora_request = lora_request
self.lora_name = lora_request.lora_name if lora_request is not None else None
self.output_kind = output_kind self.output_kind = output_kind
self.prompt = prompt self.prompt = prompt
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
...@@ -178,9 +180,7 @@ class RequestState: ...@@ -178,9 +180,7 @@ class RequestState:
request_id=request.request_id, request_id=request.request_id,
parent_req=parent_req, parent_req=parent_req,
request_index=request_index, request_index=request_index,
lora_name=( lora_request=request.lora_request,
request.lora_request.name if request.lora_request is not None else None
),
output_kind=output_kind, output_kind=output_kind,
prompt=prompt, prompt=prompt,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
...@@ -289,6 +289,7 @@ class RequestState: ...@@ -289,6 +289,7 @@ class RequestState:
return RequestOutput( return RequestOutput(
request_id=request_id, request_id=request_id,
lora_request=self.lora_request,
prompt=self.prompt, prompt=self.prompt,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
......
...@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( ...@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.perf import PerfMetricsLogging
from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import ( from vllm.v1.metrics.stats import (
CachingMetrics, CachingMetrics,
...@@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -118,6 +119,9 @@ class LoggingStatLogger(StatLoggerBase):
self.engine_is_idle = False self.engine_is_idle = False
self.aggregated = False self.aggregated = False
if self._enable_perf_stats():
self.perf_metrics_logging = PerfMetricsLogging(vllm_config)
def _reset(self, now): def _reset(self, now):
self.last_log_time = now self.last_log_time = now
...@@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -127,6 +131,9 @@ class LoggingStatLogger(StatLoggerBase):
self.num_corrupted_reqs: int = 0 self.num_corrupted_reqs: int = 0
self.num_preemptions: int = 0 self.num_preemptions: int = 0
def _enable_perf_stats(self) -> bool:
return self.vllm_config.observability_config.enable_mfu_metrics
def _track_iteration_stats(self, iteration_stats: IterationStats): def _track_iteration_stats(self, iteration_stats: IterationStats):
# Save tracked stats for token counters. # Save tracked stats for token counters.
self.num_prompt_tokens += iteration_stats.num_prompt_tokens self.num_prompt_tokens += iteration_stats.num_prompt_tokens
...@@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -175,6 +182,8 @@ class LoggingStatLogger(StatLoggerBase):
self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats) self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
if not self.aggregated: if not self.aggregated:
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
if (perf_stats := scheduler_stats.perf_stats) and self._enable_perf_stats():
self.perf_metrics_logging.observe(perf_stats)
if mm_cache_stats: if mm_cache_stats:
self.mm_caching_metrics.observe(mm_cache_stats) self.mm_caching_metrics.observe(mm_cache_stats)
...@@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -211,7 +220,7 @@ class LoggingStatLogger(StatLoggerBase):
"Running: %d reqs", "Running: %d reqs",
"Waiting: %d reqs", "Waiting: %d reqs",
] ]
log_args = [ log_args: list[int | float | str] = [
self.last_prompt_throughput, self.last_prompt_throughput,
self.last_generation_throughput, self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs, self.last_scheduler_stats.num_running_reqs,
...@@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -254,6 +263,8 @@ class LoggingStatLogger(StatLoggerBase):
self.kv_connector_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn)
if self.cudagraph_logging is not None: if self.cudagraph_logging is not None:
self.cudagraph_logging.log(log_fn=log_fn) self.cudagraph_logging.log(log_fn=log_fn)
if self._enable_perf_stats():
self.perf_metrics_logging.log(log_fn=log_fn, log_prefix=self.log_prefix)
def log_engine_initialized(self): def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks: if self.vllm_config.cache_config.num_gpu_blocks:
...@@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase): ...@@ -282,6 +293,10 @@ class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase):
def log_prefix(self): def log_prefix(self):
return "{} Engines Aggregated: ".format(len(self.engine_indexes)) return "{} Engines Aggregated: ".format(len(self.engine_indexes))
def _enable_perf_stats(self) -> bool:
# Adding per_gpu perf stats across engines can lead to misleading numbers.
return False
def record( def record(
self, self,
scheduler_stats: SchedulerStats | None, scheduler_stats: SchedulerStats | None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Analytic flops/memory estimation module for transformer components,
to help derive MFU (Model Flops Utilization) stats for a running model.
"""
import json
import time
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import asdict, dataclass
from typing import Any, Protocol
import torch
from pydantic import BaseModel, Field, ValidationError, model_validator
from typing_extensions import Self
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_dtype_size,
get_kv_cache_torch_dtype,
)
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
class InvalidComponent(Exception):
"""
Custom exception to indicate that a certain ComponentMetric is not
applicable to the given VllmConfig.
"""
pass
#### Basic Data Types ####
@dataclass
class DebugPerfStats:
## Stats for debugging the metrics calculation
calc_duration: float = 0.0 # time spent calculating these stats
num_prefill_requests: int = 0
num_decode_requests: int = 0
context_breakdown: dict[str, int] | None = None
num_flops_per_gpu_breakdown: dict[str, int] | None = None
num_read_bytes_per_gpu_breakdown: dict[str, int] | None = None
num_write_bytes_per_gpu_breakdown: dict[str, int] | None = None
@dataclass
class PerfStats:
num_flops_per_gpu: int = 0
num_read_bytes_per_gpu: int = 0
num_write_bytes_per_gpu: int = 0
debug_stats: DebugPerfStats | None = None
@dataclass
class ExecutionContext:
"""
Represents an execution context for a batch of requests.
This class aggregates statistics across multiple requests in a batch,
separately tracking prefill and decode phases.
Example)
- Batch with one full prefill (2048 tokens) and one decode (1 token, 8192 context):
ctx = ExecutionContext()
ctx.add(2048, 2048, is_prefill=True)
ctx.add(1, 8192, is_prefill=False)
"""
# Prefill phase statistics
num_prefill_requests: int = 0
prefill_num_tokens: int = 0 # sum of num_tokens for prefill requests
prefill_context_len: int = 0 # sum of context_len for prefill requests
prefill_token_context_product: int = 0 # sum of (num_tokens * context_len)
# Decode phase statistics
num_decode_requests: int = 0
decode_num_tokens: int = 0 # sum of num_tokens for decode requests
decode_context_len: int = 0 # sum of context_len for decode requests
decode_token_context_product: int = 0 # sum of (num_tokens * context_len)
def add(self, num_tokens: int, context_len: int, is_prefill: bool) -> None:
"""Add a single request's statistics to this batch context."""
if is_prefill:
self.num_prefill_requests += 1
self.prefill_num_tokens += num_tokens
self.prefill_context_len += context_len
self.prefill_token_context_product += num_tokens * context_len
else:
self.num_decode_requests += 1
self.decode_num_tokens += num_tokens
self.decode_context_len += context_len
self.decode_token_context_product += num_tokens * context_len
def total_num_tokens(self) -> int:
"""Total number of tokens across all requests in the batch."""
return self.prefill_num_tokens + self.decode_num_tokens
def total_token_context_product(self) -> int:
"""Total sum of (num_tokens * context_len) across all requests."""
return self.prefill_token_context_product + self.decode_token_context_product
@classmethod
def from_single_request(
cls, num_tokens: int, context_len: int, is_prefill: bool
) -> "ExecutionContext":
"""Create an ExecutionContext from a single request.
This is a convenience method primarily for testing.
"""
ctx = cls()
ctx.add(num_tokens, context_len, is_prefill)
return ctx
class ParsedArgs:
"""
Syntactic sugar so that Parsers can use dot notations
to access/update the parsed arguments.
e.g.)
args = ParsedArgs()
args.x = 3
args.y = args.x + 1
"""
def __getattr__(self, name: str) -> Any:
raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
def __setattr__(self, name: str, value: Any) -> None:
object.__setattr__(self, name, value)
def model_dump(self) -> dict[str, Any]:
return vars(self).copy()
#### Abstract ####
class Parser(Protocol):
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
"""
Parse the vllm config and update the current ParsedArgs and pass it on.
If the parser isn't applicable to the vllm_config, it will do nothing.
"""
...
class ParserChain:
"""
Applies chain of parser in a sequential order.
Later parsers might overwrite results from previous parsers,
so parsers should be chained in the appropriate order if they
are not mutually exclusive.
"""
def __init__(self, *parsers: Parser) -> None:
self.parsers = list(parsers)
def add_parser(self, parser: Parser) -> None:
self.parsers.append(parser)
def parse(self, vllm_config: VllmConfig) -> ParsedArgs:
args = ParsedArgs()
for parser in self.parsers:
args = parser.parse(args, vllm_config)
return args
_COMPONENT_METRICS_REGISTRY: dict[str, type["ComponentMetrics"]] = {}
class ComponentMetrics(BaseModel, ABC):
"""
Each concrete ComponentMetrics class is associated with:
- fields that are required for metric derivation
(fields are specified/validated through pydantic model)
- parser to parse VllmConfig into fields
- metric methods that derive flops/bytes for a given execution context
"""
@classmethod
@abstractmethod
def component_type(cls) -> str: ...
@classmethod
@abstractmethod
def get_parser(cls) -> ParserChain:
"""
Return a ParserChain that provides values for all required fields.
The returned parser chain must populate ParsedArgs with values for every
field defined on this ComponentMetrics class. Missing fields will cause
a ValidationError when from_vllm_config() is called.
See individual Parser docstrings for which args they provide, and field
comments on ComponentMetrics subclasses for which parser provides each field.
"""
...
def __init_subclass__(cls):
_COMPONENT_METRICS_REGISTRY[cls.component_type()] = cls
@classmethod
def from_vllm_config(cls, vllm_config: VllmConfig) -> Self:
"""
Instantiate this class from VllmConfig.
Raises ValidationError if parsing fails.
"""
parser = cls.get_parser()
parsed_args = parser.parse(vllm_config)
try:
return cls.model_validate(parsed_args.model_dump())
except ValidationError as e:
raise InvalidComponent(f"Invalid {cls.component_type()} config: {e}") from e
@classmethod
def registered_metrics(cls) -> Iterable[type["ComponentMetrics"]]:
return iter(_COMPONENT_METRICS_REGISTRY.values())
@abstractmethod
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]: ...
@abstractmethod
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]: ...
@abstractmethod
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]: ...
def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(self.get_num_flops_breakdown(ctx, per_gpu).values())
def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(self.get_read_bytes_breakdown(ctx, per_gpu).values())
def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(self.get_write_bytes_breakdown(ctx, per_gpu).values())
#### parsers ####
class BaseConfigParser(Parser):
"""
Parses base model configuration.
Provides: vocab_size, hidden_size, num_attention_heads, num_hidden_layers,
weight_byte_size, activation_byte_size, dp_size, tp_size, pp_size, enable_ep
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
model_config = vllm_config.model_config
args.vocab_size = model_config.get_vocab_size()
args.hidden_size = model_config.get_hidden_size()
# NOTE: model_config.get_attention_heads() divide by TP
# so we access field manually here to get total num_heads
args.num_attention_heads = get_required(
model_config.hf_text_config, "num_attention_heads"
)
args.num_hidden_layers = get_required(
model_config.hf_text_config, "num_hidden_layers"
)
model_dtype = vllm_config.model_config.dtype
if isinstance(model_dtype, torch.dtype):
torch_dtype = model_dtype
elif isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else:
# FIXME: handle this better
logger.warning(
"Unknown model_dtype %s, defaulting to bfloat16",
model_dtype,
)
torch_dtype = torch.bfloat16
args.weight_byte_size = get_dtype_size(torch_dtype)
# FIXME: handle this better by parsing whether activations use
# bf16, fp32, etc...
args.activation_byte_size = 2
args.dp_size = vllm_config.parallel_config.data_parallel_size
args.tp_size = vllm_config.parallel_config.tensor_parallel_size
args.pp_size = vllm_config.parallel_config.pipeline_parallel_size
args.enable_ep = vllm_config.parallel_config.enable_expert_parallel
return args
#### Attention ####
class BaseAttentionConfigParser(Parser):
"""
Parses attention-specific configuration.
Provides: num_key_value_heads, head_dim, cache_byte_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
model_config = vllm_config.model_config
args.num_key_value_heads = model_config.get_total_num_kv_heads()
args.head_dim = model_config.get_head_size()
model_dtype = vllm_config.model_config.dtype
cache_dtype = vllm_config.cache_config.cache_dtype
kv_cache_torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
args.cache_byte_size = get_dtype_size(kv_cache_torch_dtype)
return args
class AttentionQuantizationConfigParser(Parser):
"""
Parses quantization configuration for attention layers.
Overrides: weight_byte_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.quant_config
if cfg is None:
return args
quant_method = cfg.get_name()
if quant_method in ["fp8", "fbgemm_fp8"]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args.weight_byte_size = 1
elif quant_method == "mxfp4":
# FIXME: Also has "ignored layers" issue above
args.weight_byte_size = 0.5
else:
# FIXME: Add more parsing logic for different quant methods.
raise InvalidComponent
return args
class AttentionMetrics(ComponentMetrics):
# From BaseConfigParser
num_hidden_layers: int = Field(..., gt=0)
hidden_size: int = Field(..., gt=0)
num_attention_heads: int = Field(..., gt=0)
activation_byte_size: int = Field(..., gt=0)
tp_size: int = Field(..., gt=0)
pp_size: int = Field(..., gt=0)
# From BaseAttentionConfigParser
num_key_value_heads: int = Field(..., gt=0)
head_dim: int = Field(..., gt=0)
cache_byte_size: int = Field(..., gt=0)
# From BaseConfig Parser, overridden by AttentionQuantizationConfigParser
weight_byte_size: int | float = Field(..., gt=0)
# TODO: discern cases where we have mixture of different attention layer types
# such as SWA, MLA, etc.
@classmethod
def component_type(cls) -> str:
return "attn"
@classmethod
def get_parser(cls) -> ParserChain:
return ParserChain(
BaseConfigParser(),
BaseAttentionConfigParser(),
AttentionQuantizationConfigParser(),
)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
L, D, q, kv, d = (
self.num_hidden_layers,
self.hidden_size,
self.num_attention_heads,
self.num_key_value_heads,
self.head_dim,
)
T = ctx.total_num_tokens()
TC = ctx.total_token_context_product()
if per_gpu:
L //= self.pp_size
# tensor parallel along heads
q = max(1, q // self.tp_size)
kv = max(1, kv // self.tp_size)
return {
"qkv_proj": 2 * T * D * (q + 2 * kv) * d * L,
"attn_qk": 2 * q * TC * d * L,
"attn_av": 2 * q * TC * d * L,
"out_proj": 2 * T * D * q * d * L,
}
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
L, D, q, kv, d = (
self.num_hidden_layers,
self.hidden_size,
self.num_attention_heads,
self.num_key_value_heads,
self.head_dim,
)
T = ctx.total_num_tokens()
if per_gpu:
L //= self.pp_size
# tensor parallel along heads
q = max(1, q // self.tp_size)
kv = max(1, kv // self.tp_size)
read_bytes = {}
read_bytes["qkv_input"] = T * D * self.activation_byte_size * L
read_bytes["qkv_weight"] = int(D * (q + 2 * kv) * d * self.weight_byte_size * L)
# Attention input reads differ between prefill and decode
# Prefill: read Q, K, V activations (all in activation_byte_size)
if ctx.prefill_num_tokens > 0:
read_bytes["attn_input"] = (
(ctx.prefill_num_tokens * q + 2 * ctx.prefill_context_len * kv)
* d
* self.activation_byte_size
* L
)
# Decode: read Q activations + read K, V from cache (in cache_byte_size)
if ctx.decode_num_tokens > 0:
read_bytes["attn_input"] = read_bytes.get("attn_input", 0) + (
ctx.decode_num_tokens * q * d * self.activation_byte_size * L
+ 2 * ctx.decode_context_len * kv * d * self.cache_byte_size * L
)
read_bytes["out_input"] = T * q * d * self.activation_byte_size * L
read_bytes["out_weight"] = int(q * d * D * self.weight_byte_size * L)
return read_bytes
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate write memory traffic for attention layers."""
L, D, q, kv, d = (
self.num_hidden_layers,
self.hidden_size,
self.num_attention_heads,
self.num_key_value_heads,
self.head_dim,
)
T = ctx.total_num_tokens()
if per_gpu:
L //= self.pp_size
# tensor parallel along heads
q = max(1, q // self.tp_size)
kv = max(1, kv // self.tp_size)
return {
"qkv_output": T * (q + 2 * kv) * d * self.activation_byte_size * L,
"kv_cache": 2 * T * kv * d * self.cache_byte_size * L,
"out_output": T * D * self.activation_byte_size * L,
}
#### Ffn ####
class BaseFfnConfigParser(Parser):
"""
Parses FFN and MoE configuration.
Provides: intermediate_size, num_experts, num_experts_per_tok,
moe_intermediate_size, num_shared_experts, num_moe_layers
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.model_config.hf_config
if hasattr(cfg, "text_config") and cfg.text_config is not None:
cfg = cfg.text_config
args.intermediate_size = getattr(cfg, "intermediate_size", args.hidden_size * 4)
# Try different naming conventions.
args.num_experts = vllm_config.model_config.get_num_experts()
args.num_experts_per_tok = getattr_from_list(
cfg, ["num_experts_per_tok", "moe_topk"], 0
)
args.moe_intermediate_size = getattr_from_list(
cfg, ["moe_intermediate_size", "intermediate_size"], 0
)
args.num_shared_experts = getattr_from_list(
cfg, ["n_shared_experts", "num_shared_experts"], 0
)
is_moe = args.num_experts != 0
# Assume all MoE layers by default
args.num_moe_layers = args.num_hidden_layers if is_moe else 0
return args
class FfnParallelParser(Parser):
"""
Parses FFN parallelism configuration.
Provides: ffn_tp_size, ffn_ep_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
# NOTE: ffn tp_size does not equal the tp_size parameter directly.
# e.g.) If we use DP2TP4, ffn will use TP8 (or EP8 if EP is enabled.)
if args.enable_ep:
ffn_tp_size, ffn_ep_size = 1, args.dp_size * args.tp_size
else:
ffn_tp_size, ffn_ep_size = args.dp_size * args.tp_size, 1
args.ffn_tp_size = ffn_tp_size
args.ffn_ep_size = ffn_ep_size
return args
class InterleaveMoeLayerStepParser(Parser):
"""
Parses interleave_moe_layer_step field for models like Llama4.
Overrides: num_moe_layers
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.model_config.hf_config
if hasattr(cfg, "text_config") and cfg.text_config is not None:
cfg = cfg.text_config
if (
hasattr(cfg, "interleave_moe_layer_step")
and cfg.interleave_moe_layer_step > 0
):
args.num_moe_layers = len(
[
layer
for layer in range(args.num_hidden_layers)
if (layer + 1) % cfg.interleave_moe_layer_step == 0
]
)
return args
class MoeLayerFreqParser(Parser):
"""
Parses moe_layer_freq and first_k_dense_replace fields for models like Deepseek.
Overrides: num_moe_layers
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.model_config.hf_config
if hasattr(cfg, "text_config") and cfg.text_config is not None:
cfg = cfg.text_config
if hasattr(cfg, "moe_layer_freq") and hasattr(cfg, "first_k_dense_replace"):
args.num_moe_layers = len(
[
layer
for layer in range(args.num_hidden_layers)
if layer >= cfg.first_k_dense_replace
and layer % cfg.moe_layer_freq == 0
]
)
return args
class FfnQuantizationConfigParser(Parser):
"""
Parses quantization configuration for FFN layers.
Overrides: weight_byte_size
"""
def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
cfg = vllm_config.quant_config
if cfg is None:
return args
quant_method = cfg.get_name()
if quant_method in ["fp8", "fbgemm_fp8"]:
# FIXME: This is a hacky coarse-grained fp8 quantization detection.
# (there might be more quantization methods for fp8).
# FIXME: These configs also have concept of "ignored layers" and we
# need to solve the same problem as above.
args.weight_byte_size = 1
pass
elif quant_method == "mxfp4":
# FIXME: Also has "ignored layers" issue above
args.weight_byte_size = 0.5
else:
# FIXME: Add more parsing logic for different quant methods.
raise InvalidComponent
return args
class FfnMetrics(ComponentMetrics):
# From BaseConfigParser
num_hidden_layers: int = Field(..., gt=0)
hidden_size: int = Field(..., gt=0)
activation_byte_size: int = Field(..., gt=0)
pp_size: int = Field(..., gt=0)
# From FfnParallelParser
ffn_tp_size: int = Field(..., gt=0)
ffn_ep_size: int = Field(..., gt=0)
# From BaseFfnConfigParser
intermediate_size: int = Field(..., gt=0)
num_experts: int = Field(0)
num_experts_per_tok: int = Field(1)
moe_intermediate_size: int = Field(0)
num_shared_experts: int = Field(0)
# From BaseConfigParser, can be overridden InterleaveMoeLayerStep or MoeLayerFreq
num_moe_layers: int = Field(..., ge=0)
# FIXME: might have to make this more granular
# (i.e. dense_weight_byte_size, moe_routed_weight_byte_size,
# moe_shared_weight_byte_size)
# since it can differ from byte size of other components (e.g. attn)
# and can differ even from each other.
# From BaseConfigParser, can be overridden by FfnQuantizationConfigParser
weight_byte_size: int | float = Field(..., gt=0)
@model_validator(mode="after")
def validate_moe_fields(self) -> Self:
"""Validate that MoE-related fields are properly set when num_moe_layers > 0."""
if self.num_moe_layers > 0:
assert self.num_experts, f"{self.num_experts=}"
assert self.num_experts_per_tok, f"{self.num_experts_per_tok=}"
assert self.moe_intermediate_size, f"{self.moe_intermediate_size=}"
return self
@classmethod
def component_type(cls) -> str:
return "ffn"
@classmethod
def get_parser(cls) -> ParserChain:
return ParserChain(
BaseConfigParser(),
FfnParallelParser(),
BaseFfnConfigParser(),
InterleaveMoeLayerStepParser(),
MoeLayerFreqParser(),
FfnQuantizationConfigParser(),
)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate flops breakdown for FFN layers."""
L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
Lm, E, MI, S = (
self.num_moe_layers,
self.num_experts_per_tok,
self.moe_intermediate_size,
self.num_shared_experts,
)
T = ctx.total_num_tokens()
Ld = L - Lm
num_activated_tokens = T * E if E else 0
if per_gpu:
Ld //= self.pp_size
Lm //= self.pp_size
DI //= self.ffn_tp_size
if MI is not None:
MI //= self.ffn_tp_size
if E:
num_activated_tokens //= self.ffn_ep_size
flops = {}
# Dense FFN layers (SwiGLU: 3 linear layers: up, gate, down)
if Ld:
flops["dense_ffn"] = 2 * D * 3 * DI * T * Ld
# MoE routed experts (each token activates E experts)
if Lm and E:
flops["routed_ffn"] = 2 * D * 3 * MI * num_activated_tokens * Lm
# MoE shared experts (all S shared experts run for every token)
if Lm and S:
flops["shared_ffn"] = 2 * D * 3 * MI * S * T * Lm
return flops
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate read memory traffic for FFN layers."""
L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
Lm, E, MI, S = (
self.num_moe_layers,
self.num_experts_per_tok,
self.moe_intermediate_size,
self.num_shared_experts,
)
T = ctx.total_num_tokens()
num_experts = self.num_experts
Ld = L - Lm
num_activated_tokens = T * E if E else 0
if per_gpu:
Ld //= self.pp_size
Lm //= self.pp_size
DI //= self.ffn_tp_size
if MI is not None:
MI //= self.ffn_tp_size
if E:
num_activated_tokens //= self.ffn_ep_size
if num_experts is not None:
num_experts //= self.ffn_ep_size
read_bytes = {}
# Dense FFN layers (3 GEMMs: up, gate, down projections + SiLU activation)
if Ld:
read_bytes["dense_up_gate_input"] = int(
T * D * self.activation_byte_size * Ld
)
read_bytes["dense_up_gate_weights"] = int(
2 * D * DI * self.weight_byte_size * Ld
)
read_bytes["dense_silu_input"] = int(
2 * T * DI * self.activation_byte_size * Ld
)
read_bytes["dense_down_input"] = int(
T * DI * self.activation_byte_size * Ld
)
read_bytes["dense_down_weights"] = int(D * DI * self.weight_byte_size * Ld)
if Lm:
# MoE routed expert reads
if E:
# FIXME: Assume perfect load balancing for now.
num_activated_experts = min(num_activated_tokens, num_experts)
read_bytes["routed_up_gate_input"] = int(
num_activated_tokens * D * self.activation_byte_size * Lm
)
read_bytes["routed_up_gate_weights"] = int(
2 * D * MI * num_activated_experts * self.weight_byte_size * Lm
)
read_bytes["routed_silu_input"] = int(
2 * num_activated_tokens * MI * self.activation_byte_size * Lm
)
read_bytes["routed_down_input"] = int(
num_activated_tokens * MI * self.activation_byte_size * Lm
)
read_bytes["routed_down_weights"] = int(
D * MI * num_activated_experts * self.weight_byte_size * Lm
)
# MoE shared expert reads
if S:
read_bytes["shared_up_gate_input"] = int(
T * D * self.activation_byte_size * Lm
)
read_bytes["shared_up_gate_weights"] = int(
2 * D * MI * S * self.weight_byte_size * Lm
)
read_bytes["shared_silu_input"] = int(
2 * T * MI * S * self.activation_byte_size * Lm
)
read_bytes["shared_down_input"] = int(
T * MI * self.activation_byte_size * Lm
)
read_bytes["shared_down_weights"] = int(
D * MI * S * self.weight_byte_size * Lm
)
return read_bytes
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate write memory traffic for FFN layers."""
L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
Lm, E, MI, S = (
self.num_moe_layers,
self.num_experts_per_tok,
self.moe_intermediate_size,
self.num_shared_experts,
)
T = ctx.total_num_tokens()
Ld = L - Lm
num_activated_tokens = T * E if E else 0
if per_gpu:
Ld //= self.pp_size
Lm //= self.pp_size
DI //= self.ffn_tp_size
if MI is not None:
MI //= self.ffn_tp_size
if E:
num_activated_tokens //= self.ffn_ep_size
write_bytes = {}
# Dense FFN layers
if Ld:
write_bytes["dense_up_gate_output"] = int(
2 * T * DI * self.activation_byte_size * Ld
)
write_bytes["dense_silu_output"] = int(
T * DI * self.activation_byte_size * Ld
)
write_bytes["dense_down_output"] = int(
T * D * self.activation_byte_size * Ld
)
# MoE outputs
if Lm:
if E:
write_bytes["routed_up_gate_output"] = int(
2 * num_activated_tokens * MI * self.activation_byte_size * Lm
)
write_bytes["routed_silu_output"] = int(
num_activated_tokens * MI * self.activation_byte_size * Lm
)
write_bytes["routed_down_output"] = int(
num_activated_tokens * D * self.activation_byte_size * Lm
)
if S:
write_bytes["shared_up_gate_output"] = int(
2 * T * S * MI * self.activation_byte_size * Lm
)
write_bytes["shared_silu_output"] = int(
T * S * MI * self.activation_byte_size * Lm
)
write_bytes["shared_down_output"] = int(
T * S * D * self.activation_byte_size * Lm
)
return write_bytes
#### Unembed ####
class UnembedMetrics(ComponentMetrics):
# From BaseConfigParser
hidden_size: int = Field(..., gt=0)
vocab_size: int = Field(..., gt=0)
weight_byte_size: int = Field(..., gt=0)
activation_byte_size: int = Field(..., gt=0)
tp_size: int
@classmethod
def component_type(cls) -> str:
return "unembed"
@classmethod
def get_parser(cls) -> ParserChain:
return ParserChain(
BaseConfigParser(),
)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate flops breakdown for unembedding layer."""
D, V = self.hidden_size, self.vocab_size
T = ctx.total_num_tokens()
if per_gpu:
V //= self.tp_size
return {
"unembed": 2 * T * D * V,
}
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate read memory traffic for unembedding layer."""
D, V = self.hidden_size, self.vocab_size
T = ctx.total_num_tokens()
if per_gpu:
V //= self.tp_size
return {
"input": T * D * self.activation_byte_size,
"weight": D * V * self.weight_byte_size,
}
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
"""Calculate write memory traffic for unembedding layer."""
V = self.vocab_size
T = ctx.total_num_tokens()
if per_gpu:
V //= self.tp_size
return {
"output": T * V * self.activation_byte_size,
}
#### ModelMetrics ####
class ModelMetrics:
def __init__(self, vllm_config: VllmConfig) -> None:
"""
Parse vllm_config to instantiate metrics for each component.
is_enabled() will return False if no component metrics could be instantiated.
"""
self.vllm_config = vllm_config
self.metrics: list[ComponentMetrics] = []
for metric_cls in ComponentMetrics.registered_metrics():
try:
metric = metric_cls.from_vllm_config(vllm_config)
self.metrics.append(metric)
logger.info(
"Instantiated ComponentMetrics [%s] with (%s)",
metric.component_type(),
str(metric),
)
except InvalidComponent as e:
logger.debug(
"Failed to instantiate %s from %s",
metric_cls.component_type(),
str(e),
)
def is_enabled(self) -> bool:
return len(self.metrics) > 0
def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(metric.get_num_flops(ctx, per_gpu) for metric in self.metrics)
def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(metric.get_read_bytes(ctx, per_gpu) for metric in self.metrics)
def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
return sum(metric.get_write_bytes(ctx, per_gpu) for metric in self.metrics)
def get_num_flops_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
total = {}
for metric in self.metrics:
breakdown = metric.get_num_flops_breakdown(ctx, per_gpu)
component = metric.component_type()
prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
total.update(prefixed)
return total
def get_read_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
total = {}
for metric in self.metrics:
breakdown = metric.get_read_bytes_breakdown(ctx, per_gpu)
component = metric.component_type()
prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
total.update(prefixed)
return total
def get_write_bytes_breakdown(
self, ctx: ExecutionContext, per_gpu: bool = True
) -> dict[str, int]:
total = {}
for metric in self.metrics:
breakdown = metric.get_write_bytes_breakdown(ctx, per_gpu)
component = metric.component_type()
prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
total.update(prefixed)
return total
def get_step_perf_stats_per_gpu(
self, scheduler_output: SchedulerOutput
) -> PerfStats:
"""
Calculate perf stats for the current step based on scheduled tokens.
"""
t0 = time.monotonic()
# Build a single batch context
ctx = ExecutionContext()
# Process new requests (these are in prefill phase)
for new_req in scheduler_output.scheduled_new_reqs:
req_id = new_req.req_id
num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
if num_tokens == 0:
continue
# For new requests, context_len = num_computed_tokens + num_tokens
# num_computed_tokens represents previously computed tokens in the sequence
context_len = new_req.num_computed_tokens + num_tokens
ctx.add(num_tokens, context_len, is_prefill=True)
# Process cached requests (continuing requests)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
if num_tokens == 0:
continue
# For cached requests, we have the current num_computed_tokens
num_computed_tokens = cached_reqs.num_computed_tokens[i]
context_len = num_computed_tokens + num_tokens
# Cached requests are typically in decode phase (num_tokens == 1)
# unless they're doing chunked prefill (num_tokens > 1)
is_prefill = num_tokens > 1
ctx.add(num_tokens, context_len, is_prefill)
num_flops_breakdown = self.get_num_flops_breakdown(ctx, True)
read_bytes_breakdown = self.get_read_bytes_breakdown(ctx, True)
write_bytes_breakdown = self.get_write_bytes_breakdown(ctx, True)
perf_stats = PerfStats(
sum(num_flops_breakdown.values()),
sum(read_bytes_breakdown.values()),
sum(write_bytes_breakdown.values()),
)
if envs.VLLM_DEBUG_MFU_METRICS:
perf_stats.debug_stats = DebugPerfStats(
time.monotonic() - t0,
ctx.num_prefill_requests,
ctx.num_decode_requests,
asdict(ctx),
num_flops_breakdown,
read_bytes_breakdown,
write_bytes_breakdown,
)
return perf_stats
#### Logging ####
class PerfMetricsDebugLogging:
def __init__(self):
self.reset()
def reset(self):
self.total_calc_duration: float = 0.0
self.total_num_prefill_requests: int = 0
self.total_num_decode_requests: int = 0
self.total_num_batches: int = 0
self.total_context_breakdown: dict[str, int] = {}
self.total_num_flops_per_gpu_breakdown: dict[str, int] = {}
self.total_read_bytes_per_gpu_breakdown: dict[str, int] = {}
self.total_write_bytes_per_gpu_breakdown: dict[str, int] = {}
def observe(self, debug_stats: DebugPerfStats) -> None:
self.total_calc_duration += debug_stats.calc_duration
self.total_num_prefill_requests += debug_stats.num_prefill_requests
self.total_num_decode_requests += debug_stats.num_decode_requests
self.total_num_batches += 1
for dst, src in zip(
[
self.total_context_breakdown,
self.total_num_flops_per_gpu_breakdown,
self.total_read_bytes_per_gpu_breakdown,
self.total_write_bytes_per_gpu_breakdown,
],
[
debug_stats.context_breakdown,
debug_stats.num_flops_per_gpu_breakdown,
debug_stats.num_read_bytes_per_gpu_breakdown,
debug_stats.num_write_bytes_per_gpu_breakdown,
],
):
assert isinstance(src, dict)
for key, val in src.items():
dst[key] = dst.get(key, 0) + val
def log(self, log_fn, log_prefix: str, delta_time: float):
# pretty print breakdowns
total_num_flops_per_gpu_breakdown = {
k: f"{v / 1e12:.1f}TF"
for k, v in self.total_num_flops_per_gpu_breakdown.items()
}
total_read_bytes_per_gpu_breakdown = {
k: f"{v / 1e9:.1f}GB"
for k, v in self.total_read_bytes_per_gpu_breakdown.items()
}
total_write_bytes_per_gpu_breakdown = {
k: f"{v / 1e9:.1f}GB"
for k, v in self.total_write_bytes_per_gpu_breakdown.items()
}
logger.debug(
"%sMFU details: %s",
log_prefix,
json.dumps(
{
"prefill_reqs": self.total_num_prefill_requests,
"decode_reqs": self.total_num_decode_requests,
"num_batches": self.total_num_batches,
"context_breakdown": self.total_context_breakdown,
"flops_breakdown": total_num_flops_per_gpu_breakdown,
"num_read_bytes_breakdown": total_read_bytes_per_gpu_breakdown,
"num_write_bytes_breakdown": (total_write_bytes_per_gpu_breakdown),
"duration": f"{delta_time:.1f}s",
"mfu_calc_overhead": (
f"{self.total_calc_duration / delta_time:.1%}"
),
},
indent=2,
),
)
class PerfMetricsLogging:
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
self.debug_logging: PerfMetricsDebugLogging | None = None
if envs.VLLM_DEBUG_MFU_METRICS:
self.debug_logging = PerfMetricsDebugLogging()
self.reset()
def reset(self):
self.last_log_time = time.monotonic()
self.total_num_flops_per_gpu: int = 0
self.total_read_bytes_per_gpu: int = 0
self.total_write_bytes_per_gpu: int = 0
if self.debug_logging:
self.debug_logging.reset()
def observe(self, perf_stats: PerfStats) -> None:
self.total_num_flops_per_gpu += perf_stats.num_flops_per_gpu
self.total_read_bytes_per_gpu += perf_stats.num_read_bytes_per_gpu
self.total_write_bytes_per_gpu += perf_stats.num_write_bytes_per_gpu
if self.debug_logging:
assert perf_stats.debug_stats is not None
self.debug_logging.observe(perf_stats.debug_stats)
def log(self, log_fn=logger.info, log_prefix: str = "") -> None:
if not (
self.total_num_flops_per_gpu
or self.total_read_bytes_per_gpu
or self.total_write_bytes_per_gpu
):
return
now = time.monotonic()
delta_time = now - self.last_log_time
if delta_time <= 0.0:
avg_tflops_per_gpu = 0.0
avg_gbps_per_gpu = 0.0
else:
avg_tflops_per_gpu = self.total_num_flops_per_gpu / delta_time / 1e12
avg_gbps_per_gpu = (
(self.total_read_bytes_per_gpu + self.total_write_bytes_per_gpu)
/ delta_time
/ 1e9
)
log_fn(
"%sMFU: %.1f TF/s/GPU %.1f GB/s/GPU",
log_prefix,
avg_tflops_per_gpu,
avg_gbps_per_gpu,
)
if self.debug_logging:
self.debug_logging.log(log_fn, log_prefix, delta_time)
self.reset()
## util functions
def get_required(obj: object, attr: str):
"""Get an attr from an object, or throw a InvalidComponentError if it's not set."""
if not hasattr(obj, attr):
raise InvalidComponent(f"Missing required attr {attr} in config")
return getattr(obj, attr)
def getattr_from_list(obj: object, attrs: list[str], default: object = None):
"""Try to get the first attr that exists in the object
from a list of attrs. Otherwise return None."""
for attr in attrs:
if hasattr(obj, attr):
return getattr(obj, attr)
return default
...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any ...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -186,6 +187,8 @@ class SchedulerStats: ...@@ -186,6 +187,8 @@ class SchedulerStats:
cudagraph_stats: CUDAGraphStat | None = None cudagraph_stats: CUDAGraphStat | None = None
perf_stats: PerfStats | None = None
@dataclass @dataclass
class RequestStateStats: class RequestStateStats:
......
...@@ -44,6 +44,32 @@ def _walk_json_for_additional_properties(data: object): ...@@ -44,6 +44,32 @@ def _walk_json_for_additional_properties(data: object):
_walk_json_for_additional_properties(item) _walk_json_for_additional_properties(item)
def has_guidance_unsupported_json_features(schema: dict[str, Any]) -> bool:
"""Check if JSON schema contains features unsupported by guidance/llguidance."""
def check_object(obj: dict[str, Any]) -> bool:
if not isinstance(obj, dict):
return False
# patternProperties is not supported by llguidance
if "patternProperties" in obj:
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def process_for_additional_properties( def process_for_additional_properties(
guide_json: str | dict[str, Any], guide_json: str | dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment