Unverified Commit 3e41992f authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 91401c7a
...@@ -2403,6 +2403,29 @@ def cp_gather_cache( ...@@ -2403,6 +2403,29 @@ def cp_gather_cache(
) )
def cp_gather_and_upconvert_fp8_kv_cache(
src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
seq_lens: torch.Tensor,
workspace_starts: torch.Tensor,
batch_size: int,
) -> None:
"""Gather and upconvert FP8 KV cache to BF16 workspace.
Args:
src_cache: FP8 KV cache [num_blocks, block_size, 656]
dst: BF16 output workspace [total_tokens, 576]
block_table: Block indices [num_reqs, max_blocks]
seq_lens: Sequence lengths [num_reqs]
workspace_starts: Workspace start offsets [num_reqs]
batch_size: Number of requests
"""
torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
)
def indexer_k_quant_and_cache( def indexer_k_quant_and_cache(
k: torch.Tensor, k: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
......
...@@ -239,6 +239,7 @@ if TYPE_CHECKING: ...@@ -239,6 +239,7 @@ if TYPE_CHECKING:
VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = "" VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
...@@ -1537,6 +1538,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1537,6 +1538,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects # top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Debug workspace allocations.
# logging of workspace resize operations.
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# Disables parallel execution of shared_experts via separate cuda stream # Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
......
...@@ -22,12 +22,12 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -22,12 +22,12 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import ( from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled, dbo_enabled,
dbo_maybe_run_recv_hook, dbo_maybe_run_recv_hook,
dbo_register_recv_hook, dbo_register_recv_hook,
dbo_yield, dbo_yield,
) )
from vllm.v1.worker.workspace import current_workspace_manager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -661,25 +661,6 @@ def _slice_scales( ...@@ -661,25 +661,6 @@ def _slice_scales(
return None return None
class SharedResizableBuffer:
def __init__(self):
self.buffer = None
def get(
self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
assert shape != ()
shape_numel = prod(shape)
if (
self.buffer is None
or self.buffer.numel() < shape_numel
or self.buffer.device != device
or self.buffer.dtype != dtype
):
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
return self.buffer[:shape_numel].view(*shape)
@final @final
class FusedMoEModularKernel(torch.nn.Module): class FusedMoEModularKernel(torch.nn.Module):
""" """
...@@ -694,22 +675,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -694,22 +675,6 @@ class FusedMoEModularKernel(torch.nn.Module):
objects. objects.
""" """
class SharedBuffers:
def __init__(self) -> None:
self.fused_out = SharedResizableBuffer()
self.workspace13 = SharedResizableBuffer()
self.workspace2 = SharedResizableBuffer()
# Persistent buffers that are shared across `FusedMoEModularKernel`
# instances (layers), to save memory and allocattions.
#
# We have two sets of buffers to support dual batch overlap (DBO) where each
# microbatch (ubatch) should use its own set of buffers to avoid
# cross-ubatch contimination.
# NOTE that memory is lazily allocated for these buffers, meaning that if
# DBO isn't being used, the second SharedBuffers will be empty.
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
def __init__( def __init__(
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
...@@ -806,10 +771,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -806,10 +771,6 @@ class FusedMoEModularKernel(torch.nn.Module):
assert M_full > 0 and M_chunk > 0 assert M_full > 0 and M_chunk > 0
num_chunks, _ = self._chunk_info(M_full) num_chunks, _ = self._chunk_info(M_full)
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
ubatch_idx = dbo_current_ubatch_id()
buffers = self.shared_buffers[ubatch_idx]
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for # Force worst-case allocation in profiling run for
...@@ -832,14 +793,11 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -832,14 +793,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, expert_tokens_meta,
) )
) )
buffers.workspace13.get(
max_workspace_13, device=device, dtype=workspace_dtype current_workspace_manager().get_simultaneous(
) (max_workspace_13, workspace_dtype),
buffers.workspace2.get( (max_workspace_2, workspace_dtype),
max_workspace_2, device=device, dtype=workspace_dtype (max_fused_out_shape, out_dtype),
)
buffers.fused_out.get(
max_fused_out_shape, device=device, dtype=workspace_dtype
) )
# Get intermediate workspace shapes based off the chunked M size. # Get intermediate workspace shapes based off the chunked M size.
...@@ -866,22 +824,23 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -866,22 +824,23 @@ class FusedMoEModularKernel(torch.nn.Module):
# We can reuse the memory between cache1 and cache3 because by the # We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1. # time we need cache3, we're done with cache1.
workspace13 = buffers.workspace13.get(
workspace13_shape, device=device, dtype=workspace_dtype
)
workspace2 = buffers.workspace2.get(
workspace2_shape, device=device, dtype=workspace_dtype
)
# Construct the entire output that can then be processed in chunks. # Construct the entire output that can then be processed in chunks.
# Reuse workspace13 for the output in the non-chunked case as long # Reuse workspace13 for the output in the non-chunked case as long
# as it is large enough. This will not always be the case for standard # as it is large enough. This will not always be the case for standard
# format experts and with experts that have empty workspaces. # format experts and with experts that have empty workspaces.
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
workspace13, workspace2 = current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
)
fused_out = _resize_cache(workspace13, fused_out_shape) fused_out = _resize_cache(workspace13, fused_out_shape)
else: else:
fused_out = buffers.fused_out.get( workspace13, workspace2, fused_out = (
fused_out_shape, device=device, dtype=out_dtype current_workspace_manager().get_simultaneous(
(workspace13_shape, workspace_dtype),
(workspace2_shape, workspace_dtype),
(fused_out_shape, out_dtype),
)
) )
return workspace13, workspace2, fused_out return workspace13, workspace2, fused_out
......
...@@ -83,6 +83,7 @@ from vllm.v1.attention.backends.mla.indexer import ( ...@@ -83,6 +83,7 @@ from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata, DeepseekV32IndexerMetadata,
) )
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
...@@ -616,8 +617,15 @@ def sparse_attn_indexer( ...@@ -616,8 +617,15 @@ def sparse_attn_indexer(
# careful! this will be None in dummy run # careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype() fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict) # assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict): if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)
return sparse_attn_indexer_fake( return sparse_attn_indexer_fake(
hidden_states, hidden_states,
k_cache_prefix, k_cache_prefix,
...@@ -651,17 +659,17 @@ def sparse_attn_indexer( ...@@ -651,17 +659,17 @@ def sparse_attn_indexer(
topk_indices_buffer[: hidden_states.shape[0]] = -1 topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill: if has_prefill:
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks: for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty( k_fp8 = k_fp8_full[: chunk.total_seq_lens]
[chunk.total_seq_lens, head_dim], k_scale = k_scale_full[: chunk.total_seq_lens]
device=k.device,
dtype=fp8_dtype,
)
k_scale = torch.empty(
[chunk.total_seq_lens, 4],
device=k.device,
dtype=torch.uint8,
)
ops.cp_gather_indexer_k_quant_cache( ops.cp_gather_indexer_k_quant_cache(
kv_cache, kv_cache,
k_fp8, k_fp8,
...@@ -777,15 +785,6 @@ def sparse_attn_indexer_fake( ...@@ -777,15 +785,6 @@ def sparse_attn_indexer_fake(
total_seq_lens: int, total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None, topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
_flattened_kv = torch.empty(
[total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
fp8_dtype = current_platform.fp8_dtype()
_k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer return topk_indices_buffer
......
...@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
split_decodes_and_prefills, split_decodes_and_prefills,
split_prefill_chunks,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -176,40 +177,15 @@ def kv_spans_from_batches( ...@@ -176,40 +177,15 @@ def kv_spans_from_batches(
def get_max_prefill_buffer_size(vllm_config: VllmConfig): def get_max_prefill_buffer_size(vllm_config: VllmConfig):
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size. # NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
# May be tuned later. # Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
return max_model_len * 2 # The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
def split_prefill_chunks( # within the flashmla_sparse workspace.
seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int # For DeepSeek-V3.2, the max_model_len is 163840.
) -> list[tuple[int, int]]: # 40 * 163840 * 132 = 865075200 bytes = 825 MB
""" return max_model_len * 40
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
such that the total sequence length of each chunk is less than the
maximum prefill buffer size.
Args:
seq_lens_cpu: The sequence lengths of the prefill requests.
max_prefill_buffer_size: The maximum prefill buffer size.
reqs_start: The start index of the prefill requests.
Returns:
A list of tuples of (reqs_start, reqs_end).
"""
chunk_seq_ids = []
total_seq_lens = 0
for i in range(reqs_start, len(seq_lens_cpu)):
cur_seq_len = seq_lens_cpu[i].item()
assert cur_seq_len <= max_prefill_buffer_size
total_seq_lens += cur_seq_len
if total_seq_lens > max_prefill_buffer_size:
chunk_seq_ids.append((reqs_start, i))
reqs_start = i
total_seq_lens = cur_seq_len
if total_seq_lens > 0:
chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
return chunk_seq_ids
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
...@@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata = None prefill_metadata = None
if num_prefills > 0: if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks( chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu, common_attn_metadata.seq_lens_cpu[num_decodes:],
self.max_prefill_buffer_size, self.max_prefill_buffer_size,
num_decodes, request_offset=num_decodes,
) )
chunks = [ chunks = [
self.build_one_prefill_chunk( self.build_one_prefill_chunk(
......
...@@ -937,6 +937,33 @@ def split_decodes_and_prefills( ...@@ -937,6 +937,33 @@ def split_decodes_and_prefills(
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def split_prefill_chunks(
seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
"""
Split the prefill requests into chunks such that the total sequence length
of each chunk is less than or equal to the workspace size.
Args:
seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
workspace_size: The maximum workspace size (in tokens) per chunk.
request_offset: The offset to add to the request indices.
Returns:
A list of tuples of (reqs_start, reqs_end) representing chunk boundaries.
"""
chunk_bounds = []
i, n = 0, len(seq_lens_cpu)
assert torch.all(seq_lens_cpu <= workspace_size).item()
while i < n:
start, chunk_total = i, 0
while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
chunk_total += s
i += 1
chunk_bounds.append((start + request_offset, i + request_offset))
return chunk_bounds
def reorder_batch_to_split_decodes_and_prefills( def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch", input_batch: "InputBatch",
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
......
...@@ -162,6 +162,7 @@ from vllm.v1.worker.ubatch_utils import ( ...@@ -162,6 +162,7 @@ from vllm.v1.worker.ubatch_utils import (
maybe_create_ubatch_slices, maybe_create_ubatch_slices,
) )
from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.workspace import lock_workspace
from .utils import ( from .utils import (
AttentionGroup, AttentionGroup,
...@@ -297,6 +298,7 @@ class GPUModelRunner( ...@@ -297,6 +298,7 @@ class GPUModelRunner(
self.device = device self.device = device
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype self.dtype = self.model_config.dtype
self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
cache_config.cache_dtype, self.model_config cache_config.cache_dtype, self.model_config
) )
...@@ -4597,6 +4599,10 @@ class GPUModelRunner( ...@@ -4597,6 +4599,10 @@ class GPUModelRunner(
# after here. # after here.
set_cudagraph_capturing_enabled(False) set_cudagraph_capturing_enabled(False)
# Lock workspace to prevent resizing during execution.
# Max workspace sizes should have been captured during warmup/profiling.
lock_workspace()
end_time = time.perf_counter() end_time = time.perf_counter()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
......
...@@ -54,6 +54,7 @@ from vllm.v1.outputs import ( ...@@ -54,6 +54,7 @@ from vllm.v1.outputs import (
from vllm.v1.utils import report_usage_stats from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -255,6 +256,10 @@ class Worker(WorkerBase): ...@@ -255,6 +256,10 @@ class Worker(WorkerBase):
else: else:
raise RuntimeError(f"Not support device type: {self.device_config.device}") raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Initialize workspace manager
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)
# Construct the model runner # Construct the model runner
if self.use_v2_model_runner: if self.use_v2_model_runner:
from vllm.v1.worker.gpu.model_runner import ( from vllm.v1.worker.gpu.model_runner import (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import os
from itertools import accumulate
from math import prod
from typing import Optional
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.math_utils import round_up
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
return prod(shape) * dtype.itemsize
# Constants
_MB = 1024**2
_GiB = 1024**3
# Global workspace manager instance
_manager: Optional["WorkspaceManager"] = None
class WorkspaceManager:
"""Manager for workspace allocation.
Manages workspace buffers for DBO (Dual Batch Overlap) execution.
Can be locked to prevent further growth during execution.
"""
def __init__(self, device: torch.device, num_ubatches: int | None = None):
self._device = device
# Cache num ubatches at init based on configuration (default to 1)
self._num_ubatches = num_ubatches if num_ubatches is not None else 1
self._current_workspaces: list[torch.Tensor | None] = [None, None]
self._locked: bool = False
@staticmethod
def _workspace_size_bytes(workspace: torch.Tensor | None) -> int:
"""Get size of workspace in bytes."""
if workspace is None:
return 0
return workspace.numel() * workspace.element_size()
def lock(self) -> None:
"""Lock the workspace to prevent further growth.
After locking, any attempt to allocate a larger workspace will raise
an assertion error. This ensures workspace size is fixed during execution.
"""
self._locked = True
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace locked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)
def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked
def get_simultaneous(
self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype]
) -> list[torch.Tensor]:
"""Get multiple workspace tensors simultaneously from a single allocation.
Args:
*shapes_and_dtypes: One or more (shape, dtype) tuples.
Returns:
List of tensor views into the workspace buffer, one per shape/dtype pair.
"""
actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes]
aligned_bytes = [round_up(actual, 256) for actual in actual_bytes]
total_bytes = sum(aligned_bytes)
# Calculate cumulative offsets using itertools.accumulate
offsets = list(accumulate([0] + aligned_bytes[:-1]))
current_workspace = self._ensure_workspace_size(total_bytes)
return [
current_workspace[offsets[i] : offsets[i] + actual_bytes[i]]
.view(shapes_and_dtypes[i][1])
.reshape(shapes_and_dtypes[i][0])
for i in range(len(shapes_and_dtypes))
]
def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor:
"""Ensure workspace is allocated and large enough, return current workspace.
Args:
required_bytes: The number of bytes required.
Returns:
The current workspace tensor.
"""
ubatch_id = dbo_current_ubatch_id()
current_workspace = self._current_workspaces[ubatch_id]
current_size = self._workspace_size_bytes(current_workspace)
if current_size < required_bytes:
def get_caller_info() -> str:
"""Find first frame outside WorkspaceManager."""
curr_frame = inspect.currentframe()
if curr_frame is None:
return "unknown"
# Walk up the stack skipping WorkspaceManager frames
curr_frame = curr_frame.f_back
while curr_frame is not None:
# TODO: This only catches instance methods (self), missing
# classmethods and staticmethods. Once Python 3.11+ is the
# minimum supported version, use co_qualname instead:
# qualname = curr_frame.f_code.co_qualname
# if qualname.startswith("WorkspaceManager."):
if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager):
curr_frame = curr_frame.f_back
continue
filename = os.path.basename(curr_frame.f_code.co_filename)
return (
f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}"
)
return "unknown"
if self._locked:
raise AssertionError(
f"Workspace is locked but allocation from '{get_caller_info()}' "
f"requires {required_bytes / _MB:.2f} MB, current size is "
f"{current_size / _MB:.2f} MB. "
"Workspace growth is not allowed after locking."
)
for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id]
if current_workspace is None:
self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device
)
elif self._workspace_size_bytes(current_workspace) < required_bytes:
current_workspace.resize_(required_bytes)
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> "
"%.2f MB (%d ubatches, total memory %.2f MB)",
get_caller_info(),
current_size / _MB,
required_bytes / _MB,
self._num_ubatches,
required_bytes * self._num_ubatches / _MB,
)
current_workspace = self._current_workspaces[dbo_current_ubatch_id()]
return current_workspace
def is_workspace_manager_initialized() -> bool:
"""Check if workspace manager has been initialized.
Returns:
True if workspace manager is initialized, False otherwise.
"""
return _manager is not None
def current_workspace_manager() -> "WorkspaceManager":
"""Get the current workspace manager instance.
Raises:
AssertionError: If workspace manager has not been initialized.
"""
assert _manager is not None, (
"WorkspaceManager not initialized. Call init_workspace_manager() "
"with a device before using workspace functions."
)
return _manager
def init_workspace_manager(
device: torch.device, num_ubatches: int | None = None
) -> None:
"""Initialize the workspace manager with a device.
Must be called before using any workspace functions. Typically called
from GPUModelRunner.__init__.
Args:
device: The device to allocate workspace on.
num_ubatches: Number of micro-batches. Defaults to 1.
"""
global _manager
if _manager is not None:
logger.warning(
"WorkspaceManager already initialized on device %s, "
"reinitializing on device %s",
_manager._device,
device,
)
_manager = WorkspaceManager(device, num_ubatches)
def lock_workspace() -> None:
"""Lock the workspace to prevent further growth.
After calling this function, any attempt to allocate a workspace larger
than the current size will raise an AssertionError. This ensures that
workspace size is fixed during execution and prevents unexpected memory
allocations in the hot path.
Example:
# During initialization
init_workspace_manager(device)
reserve_workspace(shape1, dtype1)
reserve_workspace(shape2, dtype2)
# Lock after warmup/profiling
lock_workspace()
# Now all get_workspace calls must fit in pre-allocated size
"""
current_workspace_manager().lock()
def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.
This is primarily intended for testing purposes to allow tests
to reinitialize the workspace manager cleanly.
"""
global _manager
_manager = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment