Unverified commit 3e41992f, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 91401c7a
......@@ -2403,6 +2403,29 @@ def cp_gather_cache(
)
def cp_gather_and_upconvert_fp8_kv_cache(
    src_cache: torch.Tensor,
    dst: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    workspace_starts: torch.Tensor,
    batch_size: int,
) -> None:
    """Gather FP8 KV-cache entries and write them, upconverted to BF16, into `dst`.

    This is a thin Python entry point; all of the work is done by the
    `_C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache` custom op, which
    writes into `dst` in place.

    Args:
        src_cache: FP8 KV cache [num_blocks, block_size, 656]
        dst: BF16 output workspace [total_tokens, 576]
        block_table: Block indices [num_reqs, max_blocks]
        seq_lens: Sequence lengths [num_reqs]
        workspace_starts: Workspace start offsets [num_reqs]
        batch_size: Number of requests
    """
    gather_and_upconvert = (
        torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache
    )
    gather_and_upconvert(
        src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
    )
def indexer_k_quant_and_cache(
k: torch.Tensor,
kv_cache: torch.Tensor,
......
......@@ -239,6 +239,7 @@ if TYPE_CHECKING:
VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
......@@ -1537,6 +1538,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Debug workspace allocations.
# Enables logging of workspace resize operations.
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
......
......@@ -22,12 +22,12 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled,
dbo_maybe_run_recv_hook,
dbo_register_recv_hook,
dbo_yield,
)
from vllm.v1.worker.workspace import current_workspace_manager
logger = init_logger(__name__)
......@@ -661,25 +661,6 @@ def _slice_scales(
return None
class SharedResizableBuffer:
    """A lazily allocated flat tensor that is handed out as reshaped views.

    The backing tensor is only replaced when a request cannot be served from
    it: nothing has been allocated yet, it holds too few elements, or it
    lives on a different device / has a different dtype than requested.
    """

    def __init__(self):
        # Flat backing storage; stays None until the first `get` call.
        self.buffer = None

    def get(
        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return a view of `shape` over the first `prod(shape)` buffer elements.

        Views returned by earlier calls share the same storage as long as the
        backing tensor was not replaced, so callers must not rely on their
        contents surviving a later `get`.
        """
        assert shape != ()
        required = prod(shape)

        backing = self.buffer
        reusable = (
            backing is not None
            and backing.numel() >= required
            and backing.device == device
            and backing.dtype == dtype
        )
        if not reusable:
            # Allocate exactly what this request needs; the previous backing
            # tensor (if any) is dropped.
            backing = torch.empty(required, device=device, dtype=dtype)
            self.buffer = backing

        return backing[:required].view(*shape)
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
......@@ -694,22 +675,6 @@ class FusedMoEModularKernel(torch.nn.Module):
objects.
"""
class SharedBuffers:
def __init__(self) -> None:
self.fused_out = SharedResizableBuffer()
self.workspace13 = SharedResizableBuffer()
self.workspace2 = SharedResizableBuffer()
# Persistent buffers that are shared across `FusedMoEModularKernel`
# instances (layers), to save memory and allocations.
#
# We have two sets of buffers to support dual batch overlap (DBO) where each
# microbatch (ubatch) should use its own set of buffers to avoid
# cross-ubatch contamination.
# NOTE that memory is lazily allocated for these buffers, meaning that if
# DBO isn't being used, the second SharedBuffers will be empty.
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
def __init__(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
......@@ -806,10 +771,6 @@ class FusedMoEModularKernel(torch.nn.Module):
assert M_full > 0 and M_chunk > 0
num_chunks, _ = self._chunk_info(M_full)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx = dbo_current_ubatch_id()
buffers = self.shared_buffers[ubatch_idx]
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for
......@@ -832,14 +793,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta,
)
)
buffers.workspace13.get(
max_workspace_13, device=device, dtype=workspace_dtype
)
buffers.workspace2.get(
max_workspace_2, device=device, dtype=workspace_dtype
)
buffers.fused_out.get(
max_fused_out_shape, device=device, dtype=workspace_dtype
current_workspace_manager().get_simultaneous(
(max_workspace_13, workspace_dtype),
(max_workspace_2, workspace_dtype),
(max_fused_out_shape, out_dtype),
)
# Get intermediate workspace shapes based off the chunked M size.
......@@ -866,22 +824,23 @@ class FusedMoEModularKernel(torch.nn.Module):
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = buffers.workspace13.get(
workspace13_shape, device=device, dtype=workspace_dtype
)
workspace2 = buffers.workspace2.get(
workspace2_shape, device=device, dtype=workspace_dtype
)
# Construct the entire output that can then be processed in chunks.
# Reuse workspace13 for the output in the non-chunked case as long
# as it is large enough. This will not always be the case for standard
# format experts and with experts that have empty workspaces.
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
workspace13, workspace2 = current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
)
fused_out = _resize_cache(workspace13, fused_out_shape)
else:
fused_out = buffers.fused_out.get(
fused_out_shape, device=device, dtype=out_dtype
workspace13, workspace2, fused_out = (
current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
(fused_out_shape, out_dtype),
)
)
return workspace13, workspace2, fused_out
......
......@@ -83,6 +83,7 @@ from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
......@@ -616,8 +617,15 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
......@@ -651,17 +659,17 @@ def sparse_attn_indexer(
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty(
[chunk.total_seq_lens, head_dim],
device=k.device,
dtype=fp8_dtype,
)
k_scale = torch.empty(
[chunk.total_seq_lens, 4],
device=k.device,
dtype=torch.uint8,
)
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
......@@ -777,15 +785,6 @@ def sparse_attn_indexer_fake(
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
_flattened_kv = torch.empty(
[total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
fp8_dtype = current_platform.fp8_dtype()
_k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
......
......@@ -18,7 +18,7 @@ from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -30,13 +30,31 @@ from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
if TYPE_CHECKING:
from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)
# For FP8 sparse attention we have two implementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode; this is
# done by treating all tokens as a single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
# (upconverting the FP8 cache to BF16, then calling the prefill kernel) and use
# the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. TP) since the
# BF16 prefill kernel requires padding the number of heads to 128 while the decode
# kernel does not, so when the per-rank head count is below
# MIN_HEADS_FOR_BF16_PREFILL we use the mixed batch mode (#1).
MIN_HEADS_FOR_BF16_PREFILL = 32
"""
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
......@@ -127,19 +145,72 @@ class FlashMLASparseMetadata:
dummy_block_table: torch.Tensor
cache_lens: torch.Tensor
fp8_extra_metadata: FP8KernelMetadata | None = None
@dataclass
class FP8SeperatePrefillDecode:
    """FP8 metadata for the mode that handles prefill and decode separately.

    `decode` / `prefill` are left as None when the batch has no requests of
    that kind; the counters default to 0.
    """

    @dataclass
    class Decode:
        """Metadata for the decode requests of the batch."""

        # Scheduler metadata / splits consumed by the FP8 decode kernel.
        kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
        decode_query_len: int  # needed for reshape in spec decode

    @dataclass
    class Prefill:
        """Metadata for the prefill requests of the batch."""

        # Sequence lengths (context + query) for prefill requests
        # Shape: [num_prefill_reqs]
        seq_lens: torch.Tensor

        # Request ID for each token: -1 for decode tokens, request index
        # (0, 1, 2, ...) for prefill tokens.
        # Shape: [num_actual_tokens]
        request_ids: torch.Tensor

        # Workspace start offsets for all prefill requests
        # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
        # 0-indexed within each chunk. Used to map prefill tokens to workspace
        # offsets in convert_logical_index_to_physical_index
        workspace_starts: torch.Tensor

        @dataclass
        class Chunk:
            """Metadata for a chunk of prefill requests.
            Prefill requests may be chunked to fit within the fixed workspace size.
            """

            # Sequence lengths of the prefill requests in this chunk.
            seq_lens: torch.Tensor
            # Range of query tokens (in the full batch) covered by this chunk.
            tokens_slice: slice
            # Block-table rows of the requests in this chunk.
            block_table: torch.Tensor
            # Index (among prefill requests) of the first request in the chunk.
            req_start_idx: int
            # Per-request workspace start offsets for this chunk.
            workspace_starts: torch.Tensor
            # Sum of `seq_lens` over the chunk; used to size the workspace view.
            chunk_tot_seqlen: int

        # Chunks in request order; each one fits in the prefill workspace.
        chunks: list[Chunk]

    num_prefills: int = 0
    num_decodes: int = 0
    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0

    decode: Decode | None = None
    prefill: Prefill | None = None
fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None
fp8_use_mixed_batch: bool = False
# Kernel with prefill workspace support
@triton.jit
def _convert_req_index_to_global_index_kernel(
req_id_ptr, # int32 [num_tokens]
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill
workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr
# shapes (compile-time where possible)
max_num_blocks_per_req: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr, # tile width along columns
HAS_PREFILL: tl.constexpr,
# strides (in elements)
bt_stride0,
bt_stride1,
......@@ -165,7 +236,10 @@ def _convert_req_index_to_global_index_kernel(
# Only token == -1 should propagate as -1
is_invalid_tok = tok < 0
is_prefill = False
if HAS_PREFILL:
prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
is_prefill = prefill_req_id >= 0
# Compute block id and in-block offset
block_id = tok // BLOCK_SIZE
inblock_off = tok % BLOCK_SIZE
......@@ -173,12 +247,18 @@ def _convert_req_index_to_global_index_kernel(
# Guard block_table access
valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
base = tl.load(bt_ptr, mask=valid_block, other=0)
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
out_val = tl.where(
is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off
)
is_invalid_tok |= ~valid_block
base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
out_val = base * BLOCK_SIZE + inblock_off
# Override with prefill output if prefill is enabled
if HAS_PREFILL:
workspace_start = tl.load(
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
)
prefill_out = workspace_start + tok
out_val = tl.where(is_prefill, prefill_out, out_val)
out_val = tl.where(is_invalid_tok, -1, out_val)
# Store results
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
......@@ -192,6 +272,9 @@ def triton_convert_req_index_to_global_index(
BLOCK_SIZE: int = 64,
NUM_TOPK_TOKENS: int = 2048,
BLOCK_N: int = 128, # tile width along columns
HAS_PREFILL_WORKSPACE: bool = False,
prefill_workspace_request_ids: torch.Tensor | None = None,
prefill_workspace_starts: torch.Tensor | None = None,
):
"""
out[token_id, indice_id] =
......@@ -202,17 +285,32 @@ def triton_convert_req_index_to_global_index(
Only when token_indices[token_id, indice_id] == -1 do we output -1.
For safety, we also output -1 if the derived block_id would be
out-of-bounds.
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
instead of global cache slots. prefill_workspace_request_ids and
prefill_workspace_starts must be provided.
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
prefill request index (maps to prefill_workspace_starts)
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
starts for each prefill request
"""
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
assert token_indices.shape[1] == NUM_TOPK_TOKENS
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})"
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
)
if HAS_PREFILL_WORKSPACE:
assert prefill_workspace_request_ids is not None
assert prefill_workspace_starts is not None
assert prefill_workspace_request_ids.dtype == torch.int32
assert prefill_workspace_starts.dtype == torch.int32
num_tokens = req_id.shape[0]
num_requests, max_num_blocks_per_req = block_table.shape
max_num_blocks_per_req = block_table.shape[1]
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
# Ensure contiguous tensors on the same device
......@@ -226,6 +324,13 @@ def triton_convert_req_index_to_global_index(
ti_stride0, ti_stride1 = token_indices_c.stride()
out_stride0, out_stride1 = out.stride()
# Prepare prefill pointers
if HAS_PREFILL_WORKSPACE:
assert prefill_workspace_request_ids is not None # for mypy
assert prefill_workspace_starts is not None # for mypy
assert prefill_workspace_request_ids.is_contiguous()
assert prefill_workspace_starts.is_contiguous()
# Exact 2D grid: tokens × column tiles
grid = (num_tokens, tiles_per_row)
......@@ -234,10 +339,13 @@ def triton_convert_req_index_to_global_index(
block_table_c,
token_indices_c,
out,
prefill_workspace_request_ids,
prefill_workspace_starts,
# shapes / constexprs
max_num_blocks_per_req,
BLOCK_SIZE,
BLOCK_N,
HAS_PREFILL_WORKSPACE,
# strides
bt_stride0,
bt_stride1,
......@@ -249,7 +357,16 @@ def triton_convert_req_index_to_global_index(
return out
@dataclass
def get_prefill_workspace_size(max_model_len: int):
    """Return the number of token rows to reserve for the prefill workspace.

    The result is a fixed multiple of `max_model_len`.
    """
    # NOTE(Lucas): the multiplier is a magic number controlling the prefill
    # buffer size. May be tuned later.
    # Memory usage: multiplier * max_model_len * 576 * 2 bytes
    # Example: DeepSeek-V3.2 with max_model_len=163840 ->
    #   5 * 163840 * 576 * 2 = ~900 MB
    # This fits nicely below the typical MoE workspace size of >2GB so this
    # is "free".
    workspace_multiplier = 5
    workspace_rows = max_model_len * workspace_multiplier
    return workspace_rows
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
......@@ -259,29 +376,42 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
) -> None:
self.vllm_config = vllm_config
self.layer_names = layer_names
cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
# Treat requests with query length <= 1 as decodes to match the
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
props = torch.cuda.get_device_properties(device)
sm_count = props.multi_processor_count
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
self.topk_tokens_tensor = torch.tensor(
[self.topk_tokens], device=device, dtype=torch.int32
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
# Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
self.topk_tokens_tensor = torch.full(
(max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
)
self.max_model_len_tensor = torch.tensor(
[self.model_config.max_model_len], device=device, dtype=torch.int32
# Shape: [max_num_seqs], all elements = max_model_len
self.max_model_len_tensor = torch.full(
(max_num_seqs,),
self.model_config.max_model_len,
device=device,
dtype=torch.int32,
)
# this is ignored by `flash_mla_with_kvcache` if indices not None
self.dummy_block_table = torch.empty(
(1, 1), dtype=torch.int32, device=self.device
(max_num_seqs, 1), dtype=torch.int32, device=self.device
)
# Equation taken from FlashMLA/csrc/pybind.cpp
......@@ -299,10 +429,9 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
dtype=torch.int32,
device=device,
)
# Sized for per-request batching (num_decodes + 1)
self.num_splits_buffer = torch.empty(
# We pack all the tokens into one batch for sparse attention.
# Otherwise, we can exceed the sm of `get_mla_metadata`.
(2,),
(max_num_seqs + 1,),
dtype=torch.int32,
device=device,
)
......@@ -312,30 +441,171 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
device=device,
)
def build(
def _build_fp8_mixed_decode_prefill(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashMLASparseMetadata:
) -> "FlashMLASparseMetadata.FP8KernelMetadata":
"""Build FP8 metadata treating all tokens as one mixed batch.
This matches main branch's approach and avoids the BF16 prefill kernel
which has head padding overhead when num_heads is small (high TP case).
"""
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
# Build metadata for all tokens as a single batch
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
num_q_tokens_per_head_k=num_tokens * self.num_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
num_sm_parts = tile_scheduler_metadata.size(0)
tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
num_splits_view = self.num_splits_buffer[:2]
num_splits_view.copy_(num_splits)
fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=num_splits_view,
cache_lens=self.max_model_len_tensor[:1],
dummy_block_table=self.dummy_block_table[:1],
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata = None
if self.use_fp8_kv_cache:
return fp8_metadata
def _build_fp8_separate_prefill_decode(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode":
num_tokens = common_attn_metadata.num_actual_tokens
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode
fp8_metadata = FP8Meta(
num_decodes=num_decodes,
num_prefills=num_prefills,
num_decode_tokens=num_decode_tokens,
num_prefill_tokens=num_prefill_tokens,
)
# Extract prefill sequence lengths (context + query, not just query)
# Decode requests come first in the batch, prefill requests follow
prefill_seq_lens = None
prefill_request_id = None
prefill_workspace_starts = None
prefill_chunks = None
# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
prefill_seq_lens = seq_lens[num_decodes:]
# Build prefill_request_id: -1 for decode, request index for
# prefill. This enables a single
# convert_logical_index_to_physical_index call for all tokens
prefill_request_id = torch.full(
(num_tokens,), -1, dtype=torch.int32, device=self.device
)
# Map prefill tokens to their request IDs (0, 1, 2, ...)
for req_idx in range(num_prefills):
# Get query token range for this prefill request
global_req_idx = num_decodes + req_idx
req_query_start = query_start_loc_cpu[global_req_idx]
req_query_end = query_start_loc_cpu[global_req_idx + 1]
prefill_request_id[req_query_start:req_query_end] = req_idx
# will be adjusted by chunk loop
prefill_workspace_starts_cpu = torch.zeros(
num_prefills, dtype=torch.int32, pin_memory=True
)
prefill_workspace_starts_cpu[1:] = torch.cumsum(
prefill_seq_lens_cpu[:-1], dim=0
)
# populated by non-blocking copy after prefill_workspace_starts_cpu is
# updated by each chunk
prefill_workspace_starts = torch.empty(
num_prefills, dtype=torch.int32, device=self.device
)
# Chunk prefill requests to fit within workspace size
max_prefill_buffer_size = get_prefill_workspace_size(
self.vllm_config.model_config.max_model_len
)
chunk_bounds = split_prefill_chunks(
prefill_seq_lens_cpu, max_prefill_buffer_size
)
prefill_chunks = []
for chunk_start, chunk_end in chunk_bounds:
# Adjust workspace_starts in-place per chunk to be
# 0-indexed within each chunk
# Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
# Initial: workspace_starts=[0,10,25,45]
# After: workspace_starts=[0,10,0,20]
# (chunk 0 starts at 0, chunk 1 starts at 0)
offset = prefill_workspace_starts_cpu[chunk_start].item()
prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset
chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
tokens_slice = slice(token_start, token_end)
# Create chunk view of gpu tensor
chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
chunk_block_table = common_attn_metadata.block_table_tensor[
num_decodes + chunk_start : num_decodes + chunk_end
]
prefill_chunks.append(
FP8Meta.Prefill.Chunk(
seq_lens=chunk_seq_lens,
tokens_slice=tokens_slice,
block_table=chunk_block_table,
req_start_idx=chunk_start,
workspace_starts=chunk_workspace_starts,
chunk_tot_seqlen=chunk_tot_seqlen,
)
)
prefill_workspace_starts.copy_(
prefill_workspace_starts_cpu, non_blocking=True
)
fp8_metadata.prefill = FP8Meta.Prefill(
seq_lens=prefill_seq_lens,
request_ids=prefill_request_id,
workspace_starts=prefill_workspace_starts,
chunks=prefill_chunks,
)
if num_decodes > 0:
# Compute decode_query_len for spec decode (uniform due to require_uniform)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor,
num_q_tokens_per_head_k=num_tokens * self.num_heads,
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * self.num_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_k=1,
......@@ -348,33 +618,70 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
self.num_splits_buffer.copy_(num_splits)
# num_splits has size [num_decodes + 1]
num_splits_view = self.num_splits_buffer[: num_decodes + 1]
num_splits_view.copy_(num_splits)
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=self.num_splits_buffer,
# cache_lens and block_table are basically unused in sparse case
# but the decode kernel will treat -1 and indices >= cache_lens
# as invalid so we make sure cache_lens is large enough to not
# accidentally mark indices invalid, we will use -1 exclusively
# to mark invalid indices
cache_lens=self.max_model_len_tensor,
dummy_block_table=self.dummy_block_table,
num_splits=num_splits_view,
dummy_block_table=self.dummy_block_table[:num_decodes],
cache_lens=self.max_model_len_tensor[:num_decodes],
)
fp8_metadata.decode = FP8Meta.Decode(
kernel_metadata=kernel_meta,
decode_query_len=decode_query_len,
)
return fp8_metadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashMLASparseMetadata:
cm = common_attn_metadata
num_tokens = cm.num_actual_tokens
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
fp8_extra_metadata: (
FlashMLASparseMetadata.FP8SeperatePrefillDecode
| FlashMLASparseMetadata.FP8KernelMetadata
| None
) = None
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
if self.use_fp8_kv_cache:
if fp8_use_mixed_batch:
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
else:
fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)
metadata = FlashMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
num_reqs=cm.num_reqs,
max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens,
query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
fp8_use_mixed_batch=fp8_use_mixed_batch,
)
return metadata
......@@ -414,12 +721,204 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
self.topk_indices_buffer = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability(100) else 64
if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
vllm_config = get_current_vllm_config()
assert vllm_config is not None and vllm_config.model_config is not None
prefill_workspace_size = get_prefill_workspace_size(
vllm_config.model_config.max_model_len
)
self.prefill_workspace_shape = (prefill_workspace_size, head_size)
(self.prefill_bf16_workspace,) = (
current_workspace_manager().get_simultaneous(
(self.prefill_workspace_shape, torch.bfloat16)
)
)
def _forward_bf16_kv(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """Run sparse attention directly against the KV cache (BF16 cache path).

    The per-request top-k indices are first translated with
    `triton_convert_req_index_to_global_index`; no prefill-workspace
    arguments are passed on this path. The translated indices are then
    handed to `_bf16_flash_mla_kernel` together with `q` and the cache.
    """
    num_topk = topk_indices.shape[1]

    # Translate per-request token indices through the block table.
    global_topk_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=num_topk,
    )

    return self._bf16_flash_mla_kernel(
        q, kv_c_and_k_pe_cache, global_topk_indices
    )
def _forward_fp8_kv_separate_prefill_decode(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """Sparse attention on an FP8 KV cache with separate prefill/decode paths.

    Decode tokens (which come first in the batch) are run through
    `_fp8_flash_mla_kernel`. Prefill tokens are processed chunk by chunk:
    each chunk's cache entries are gathered and upconverted into the BF16
    prefill workspace, then `_bf16_flash_mla_kernel` runs on that workspace.

    Args:
        q: Query tensor, indexed by token along dim 0.
        kv_c_and_k_pe_cache: The FP8 KV cache.
        topk_indices: Per-request top-k token indices, one row per token.
        attn_metadata: Metadata whose `fp8_extra_metadata` must be a
            `FlashMLASparseMetadata.FP8SeperatePrefillDecode`.

    Returns:
        The attention output. For pure-decode batches this is the decode
        kernel output; otherwise a freshly allocated tensor of shape
        (num_actual_tokens, num_heads, kv_lora_rank) filled per segment.
    """
    fp8_metadata = attn_metadata.fp8_extra_metadata
    assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
    num_decodes = fp8_metadata.num_decodes

    prefill_request_ids = None
    prefill_workspace_starts = None
    has_prefill_workspace = False
    if fp8_metadata.prefill is not None:
        prefill_request_ids = fp8_metadata.prefill.request_ids
        prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
        has_prefill_workspace = True

    # Convert per-request indices to global slots (decode) or workspace
    # offsets (prefill).
    # For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
    # For BF16 cache: always use global cache slots (no workspace)
    # prefill_workspace_starts has been adjusted in-place per chunk so
    # prefill indices automatically come out chunk-local
    topk_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=topk_indices.shape[1],
        HAS_PREFILL_WORKSPACE=has_prefill_workspace,
        prefill_workspace_request_ids=prefill_request_ids,
        prefill_workspace_starts=prefill_workspace_starts,
    )

    def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
        # Reshape q: (num_decode_tokens, num_heads, head_dim)
        #         -> (num_decodes, seq_len, num_heads, head_dim)
        q = reshape_query_for_spec_decode(q, num_decodes)
        seq_len = q.shape[1]
        # Reshape topk_indices: (num_decode_tokens, topk)
        #                    -> (num_decodes, seq_len, topk)
        topk_indices = topk_indices.view(num_decodes, seq_len, -1)
        assert fp8_metadata.decode is not None
        attn_out, _ = self._fp8_flash_mla_kernel(
            q=q,
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=topk_indices,
            kernel_metadata=fp8_metadata.decode.kernel_metadata,
        )
        # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
        #              -> (num_decode_tokens, num_heads, head_dim_v)
        return reshape_attn_output_for_spec_decode(attn_out)

    num_decode_tokens = fp8_metadata.num_decode_tokens
    num_prefill_tokens = fp8_metadata.num_prefill_tokens

    # Pure decode: direct call without allocation
    if num_decode_tokens > 0 and num_prefill_tokens == 0:
        assert fp8_metadata.decode is not None
        attn_out = _fp8_decode(q, topk_indices)
    else:
        # Mixed or pure prefill: allocate output tensor
        attn_out = q.new_empty(
            (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )

        # Decode tokens occupy the leading rows of the batch.
        if num_decode_tokens > 0:
            attn_out[:num_decode_tokens] = _fp8_decode(
                q[:num_decode_tokens], topk_indices[:num_decode_tokens]
            )

        assert fp8_metadata.prefill is not None
        for chunk in fp8_metadata.prefill.chunks:
            # The workspace is reused for every chunk; only the leading
            # rows needed by this chunk are exposed to the kernels.
            chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
            ops.cp_gather_and_upconvert_fp8_kv_cache(
                kv_c_and_k_pe_cache,
                chunk_workspace,
                chunk.block_table,
                chunk.seq_lens,
                chunk.workspace_starts,
                len(chunk.block_table),
            )

            chunk_q = q[chunk.tokens_slice]
            chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

            attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                chunk_q,
                chunk_workspace,
                chunk_topk_indices_workspace,
            )

    return attn_out
def _forward_fp8_kv_mixed_batch(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """FP8 forward path that processes every token as a single batch.

    This matches the main branch's approach and sidesteps the BF16 prefill
    kernel, which pays a head-padding overhead when num_heads is small.
    Taken when use_mixed_batch is True.
    """
    # Map per-request indices to global cache slots; no prefill workspace
    # is involved on this path.
    global_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=topk_indices.shape[1],
    )

    kernel_metadata = attn_metadata.fp8_extra_metadata
    assert kernel_metadata is not None
    assert isinstance(kernel_metadata, FlashMLASparseMetadata.FP8KernelMetadata)

    # The kernel expects a leading batch dimension:
    #   q:       (T, H, D)  -> (1, T, H, D)
    #   indices: (T, topk)  -> (1, T, topk)
    batched_q = q.unsqueeze(0)
    batched_indices = global_indices.unsqueeze(0)
    batched_out, _ = self._fp8_flash_mla_kernel(
        q=batched_q,
        kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
        topk_indices=batched_indices,
        kernel_metadata=kernel_metadata,
    )
    # Drop the batch dimension again: (1, T, H, D_v) -> (T, H, D_v)
    return batched_out.squeeze(0)
def _fp8_flash_mla_kernel(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Invoke ``flash_mla_with_kvcache`` on the FP8 KV cache.

    Args:
        q: Query tensor, already carrying the leading batch dimension.
        kv_c_and_k_pe_cache: The FP8 cache; it is reinterpreted as uint8
            and given an extra second-to-last dimension before the call.
        topk_indices: Sparse indices forwarded as ``indices``.
        kernel_metadata: Supplies the dummy block table, cache lengths,
            tile scheduler metadata and split count for the kernel.

    Returns:
        The two-element result of ``flash_mla_with_kvcache``. Callers use
        the first element as the attention output and discard the second
        (presumably the softmax log-sum-exp; confirm against the kernel
        documentation).
    """
    # Reinterpret the cache as raw bytes and add a singleton dimension
    # before the last one, as the kernel call expects.
    k_cache = kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2)
    return flash_mla_with_kvcache(
        q=q,
        k_cache=k_cache,
        block_table=kernel_metadata.dummy_block_table,
        head_dim_v=512,
        cache_seqlens=kernel_metadata.cache_lens,
        tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
        num_splits=kernel_metadata.num_splits,
        is_fp8_kvcache=True,
        indices=topk_indices,
        softmax_scale=self.softmax_scale,
    )
def _bf16_flash_mla_kernel(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
......@@ -445,31 +944,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
output = output[:, : self.num_heads, :]
return output
def _forward_fp8_kv(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """Run the FP8 FlashMLA kernel over all tokens as one batch.

    ``topk_indices`` is forwarded to the kernel as-is (apart from the added
    batch dimension), so it must already be in the form the kernel expects.
    The kernel output is returned with the added batch dimension still in
    place.
    """
    fp8_meta = attn_metadata.fp8_extra_metadata
    assert fp8_meta is not None

    # Add a leading batch dimension to both the queries and the indices.
    batched_q = q.unsqueeze(0)
    batched_indices = topk_indices.unsqueeze(0)
    # View the cache as raw bytes with an extra second-to-last dimension.
    raw_k_cache = kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2)

    kernel_result = flash_mla_with_kvcache(
        q=batched_q,
        k_cache=raw_k_cache,
        block_table=fp8_meta.dummy_block_table,
        head_dim_v=512,
        cache_seqlens=fp8_meta.cache_lens,
        tile_scheduler_metadata=fp8_meta.scheduler_metadata,
        num_splits=fp8_meta.num_splits,
        is_fp8_kvcache=True,
        indices=batched_indices,
        softmax_scale=self.softmax_scale,
    )
    # Only the first of the two returned values is used.
    batched_out, _ = kernel_result
    return batched_out
def forward(
self,
layer: AttentionLayer,
......@@ -477,7 +951,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
attn_metadata: FlashMLASparseMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
......@@ -493,6 +967,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
)
if attn_metadata is None:
# Dummy run - no need to allocate buffers
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
......@@ -505,6 +980,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
topk_indices = self.topk_indices_buffer[:num_actual_toks]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
......@@ -514,16 +990,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
topk_indices = self.topk_indices_buffer[:num_actual_toks]
# TODO: handle index / kv_cache correctly
topk_indices_global = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
q = torch.cat([ql_nope, q_pe], dim=-1)
......@@ -538,13 +1005,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
scale=layer._k_scale,
)
if self.kv_cache_dtype != "fp8_ds_mla":
attn_out = self._forward_bf16_kv(
q, kv_cache, topk_indices_global, attn_metadata
if not use_fp8_cache:
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
elif attn_metadata.fp8_use_mixed_batch:
attn_out = self._forward_fp8_kv_mixed_batch(
q, kv_cache, topk_indices, attn_metadata
)
else:
attn_out = self._forward_fp8_kv(
q, kv_cache, topk_indices_global, attn_metadata
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_cache, topk_indices, attn_metadata
)
self._v_up_proj(attn_out, out=output[:num_actual_toks])
......
......@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
split_prefill_chunks,
)
logger = init_logger(__name__)
......@@ -176,40 +177,15 @@ def kv_spans_from_batches(
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
    """Return the prefill buffer size derived from the model's max length."""
    # NOTE(Chen): the factor of 2 is a magic number controlling the prefill
    # buffer size. May be tuned later.
    return vllm_config.model_config.max_model_len * 2
def split_prefill_chunks(
    seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int
) -> list[tuple[int, int]]:
    """
    Group consecutive prefill requests into chunks whose total sequence
    length does not exceed the maximum prefill buffer size.

    Args:
        seq_lens_cpu: The sequence lengths of all requests.
        max_prefill_buffer_size: The maximum prefill buffer size.
        reqs_start: Index of the first prefill request; earlier entries of
            ``seq_lens_cpu`` are ignored.

    Returns:
        A list of (reqs_start, reqs_end) index pairs, end exclusive. No
        trailing chunk is emitted when its accumulated length is zero.
    """
    num_reqs = len(seq_lens_cpu)
    bounds: list[tuple[int, int]] = []
    chunk_begin = reqs_start
    running_total = 0
    for req_idx in range(reqs_start, num_reqs):
        req_len = seq_lens_cpu[req_idx].item()
        # A single request must fit in the buffer on its own.
        assert req_len <= max_prefill_buffer_size
        if running_total + req_len > max_prefill_buffer_size:
            # Close the current chunk and start a new one at this request.
            bounds.append((chunk_begin, req_idx))
            chunk_begin = req_idx
            running_total = req_len
        else:
            running_total += req_len
    if running_total > 0:
        bounds.append((chunk_begin, num_reqs))
    return bounds
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
# Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
# The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
# within the flashmla_sparse workspace.
# For DeepSeek-V3.2, the max_model_len is 163840.
# 40 * 163840 * 132 = 865075200 bytes = 825 MB
return max_model_len * 40
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
......@@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.seq_lens_cpu[num_decodes:],
self.max_prefill_buffer_size,
num_decodes,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
......
......@@ -937,6 +937,33 @@ def split_decodes_and_prefills(
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def split_prefill_chunks(
    seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
    """
    Split the prefill requests into chunks such that the total sequence length
    of each chunk is less than or equal to the workspace size.

    Args:
        seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
        workspace_size: The maximum workspace size (in tokens) per chunk.
        request_offset: The offset to add to the request indices.

    Returns:
        A list of tuples of (reqs_start, reqs_end) representing chunk boundaries.

    Raises:
        AssertionError: If any single sequence is longer than the workspace.
    """
    # Convert once instead of calling .item() per element inside the loop.
    seq_lens = seq_lens_cpu.tolist()

    # A request that cannot fit on its own would make the loop below spin
    # forever, so reject it explicitly. Raising (rather than using a bare
    # ``assert``) keeps the guard in place under ``python -O``.
    if any(seq_len > workspace_size for seq_len in seq_lens):
        raise AssertionError(
            f"Every sequence length must be <= workspace size {workspace_size}"
        )

    chunk_bounds = []
    i, n = 0, len(seq_lens)
    while i < n:
        start, chunk_total = i, 0
        # Greedily extend the chunk while the next request still fits.
        while i < n and chunk_total + seq_lens[i] <= workspace_size:
            chunk_total += seq_lens[i]
            i += 1
        chunk_bounds.append((start + request_offset, i + request_offset))
    return chunk_bounds
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
......
......@@ -162,6 +162,7 @@ from vllm.v1.worker.ubatch_utils import (
maybe_create_ubatch_slices,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.workspace import lock_workspace
from .utils import (
AttentionGroup,
......@@ -297,6 +298,7 @@ class GPUModelRunner(
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
cache_config.cache_dtype, self.model_config
)
......@@ -4597,6 +4599,10 @@ class GPUModelRunner(
# after here.
set_cudagraph_capturing_enabled(False)
# Lock workspace to prevent resizing during execution.
# Max workspace sizes should have been captured during warmup/profiling.
lock_workspace()
end_time = time.perf_counter()
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
......
......@@ -54,6 +54,7 @@ from vllm.v1.outputs import (
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
logger = init_logger(__name__)
......@@ -255,6 +256,10 @@ class Worker(WorkerBase):
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize workspace manager
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)
# Construct the model runner
if self.use_v2_model_runner:
from vllm.v1.worker.gpu.model_runner import (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import os
from itertools import accumulate
from math import prod
from typing import Optional
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.math_utils import round_up
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
return prod(shape) * dtype.itemsize
# Constants
# Byte-count divisors. Both are binary units: _MB is 2**20 bytes (one
# mebibyte) and _GiB is 2**30 bytes (one gibibyte). _MB is what the workspace
# log and error messages divide by when they print sizes labelled "MB".
_MB = 1024**2
_GiB = 1024**3
# Global workspace manager instance
# None until init_workspace_manager() assigns it; reset_workspace_manager()
# sets it back to None.
_manager: Optional["WorkspaceManager"] = None
class WorkspaceManager:
"""Manager for workspace allocation.
Manages workspace buffers for DBO (Dual Batch Overlap) execution.
Can be locked to prevent further growth during execution.
"""
def __init__(self, device: torch.device, num_ubatches: int | None = None):
self._device = device
# Cache num ubatches at init based on configuration (default to 1)
self._num_ubatches = num_ubatches if num_ubatches is not None else 1
self._current_workspaces: list[torch.Tensor | None] = [None, None]
self._locked: bool = False
@staticmethod
def _workspace_size_bytes(workspace: torch.Tensor | None) -> int:
"""Get size of workspace in bytes."""
if workspace is None:
return 0
return workspace.numel() * workspace.element_size()
def lock(self) -> None:
"""Lock the workspace to prevent further growth.
After locking, any attempt to allocate a larger workspace will raise
an assertion error. This ensures workspace size is fixed during execution.
"""
self._locked = True
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace locked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
def get_simultaneous(
self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype]
) -> list[torch.Tensor]:
"""Get multiple workspace tensors simultaneously from a single allocation.
Args:
*shapes_and_dtypes: One or more (shape, dtype) tuples.
Returns:
List of tensor views into the workspace buffer, one per shape/dtype pair.
"""
actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes]
aligned_bytes = [round_up(actual, 256) for actual in actual_bytes]
total_bytes = sum(aligned_bytes)
# Calculate cumulative offsets using itertools.accumulate
offsets = list(accumulate([0] + aligned_bytes[:-1]))
current_workspace = self._ensure_workspace_size(total_bytes)
return [
current_workspace[offsets[i] : offsets[i] + actual_bytes[i]]
.view(shapes_and_dtypes[i][1])
.reshape(shapes_and_dtypes[i][0])
for i in range(len(shapes_and_dtypes))
]
def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor:
"""Ensure workspace is allocated and large enough, return current workspace.
Args:
required_bytes: The number of bytes required.
Returns:
The current workspace tensor.
"""
ubatch_id = dbo_current_ubatch_id()
current_workspace = self._current_workspaces[ubatch_id]
current_size = self._workspace_size_bytes(current_workspace)
if current_size < required_bytes:
def get_caller_info() -> str:
"""Find first frame outside WorkspaceManager."""
curr_frame = inspect.currentframe()
if curr_frame is None:
return "unknown"
# Walk up the stack skipping WorkspaceManager frames
curr_frame = curr_frame.f_back
while curr_frame is not None:
# TODO: This only catches instance methods (self), missing
# classmethods and staticmethods. Once Python 3.11+ is the
# minimum supported version, use co_qualname instead:
# qualname = curr_frame.f_code.co_qualname
# if qualname.startswith("WorkspaceManager."):
if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager):
curr_frame = curr_frame.f_back
continue
filename = os.path.basename(curr_frame.f_code.co_filename)
return (
f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}"
)
return "unknown"
if self._locked:
raise AssertionError(
f"Workspace is locked but allocation from '{get_caller_info()}' "
f"requires {required_bytes / _MB:.2f} MB, current size is "
f"{current_size / _MB:.2f} MB. "
"Workspace growth is not allowed after locking."
)
for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id]
if current_workspace is None:
self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device
)
elif self._workspace_size_bytes(current_workspace) < required_bytes:
current_workspace.resize_(required_bytes)
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> "
"%.2f MB (%d ubatches, total memory %.2f MB)",
get_caller_info(),
current_size / _MB,
required_bytes / _MB,
self._num_ubatches,
required_bytes * self._num_ubatches / _MB,
)
current_workspace = self._current_workspaces[dbo_current_ubatch_id()]
return current_workspace
def is_workspace_manager_initialized() -> bool:
    """Report whether the module-level workspace manager exists.

    Returns:
        True once a manager has been installed, False while it is unset.
    """
    initialized = _manager is not None
    return initialized
def current_workspace_manager() -> "WorkspaceManager":
    """Get the current workspace manager instance.

    Returns:
        The module-level WorkspaceManager.

    Raises:
        AssertionError: If workspace manager has not been initialized.
    """
    # Raise explicitly instead of using ``assert`` so the documented error is
    # still produced when Python runs with assertions disabled (-O).
    if _manager is None:
        raise AssertionError(
            "WorkspaceManager not initialized. Call init_workspace_manager() "
            "with a device before using workspace functions."
        )
    return _manager
def init_workspace_manager(
    device: torch.device, num_ubatches: int | None = None
) -> None:
    """Create the module-level workspace manager for ``device``.

    Must be called before any other workspace function is used. If a
    manager already exists, a warning is logged and it is replaced by a
    fresh one.

    Args:
        device: The device to allocate workspace on.
        num_ubatches: Number of micro-batches. Defaults to 1.
    """
    global _manager
    previous = _manager
    if previous is not None:
        logger.warning(
            "WorkspaceManager already initialized on device %s, "
            "reinitializing on device %s",
            previous._device,
            device,
        )
    _manager = WorkspaceManager(device, num_ubatches)
def lock_workspace() -> None:
    """Lock the workspace to prevent further growth.

    Once locked, a request that needs more bytes than the workspace
    currently holds raises an AssertionError instead of allocating. This
    keeps the workspace size fixed during execution and avoids unexpected
    allocations in the hot path.

    Example:
        # During initialization
        init_workspace_manager(device)

        # Warmup/profiling grows the workspace to its maximum size
        current_workspace_manager().get_simultaneous((shape, dtype))

        # Lock after warmup/profiling
        lock_workspace()

        # From here on every request must fit in the existing allocation

    Raises:
        AssertionError: If the workspace manager has not been initialized.
    """
    manager = current_workspace_manager()
    manager.lock()
def reset_workspace_manager() -> None:
    """Drop the module-level workspace manager.

    Afterwards the module is back in its uninitialized state. Mainly meant
    for tests, so each test can set up its own manager from scratch.
    """
    global _manager
    _manager = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment