Unverified Commit 4ff8c3c8 authored by Vadim Gimpelson's avatar Vadim Gimpelson Committed by GitHub
Browse files

[BUGFIX][Mamba][Qwen3.5] Zero freed SSM cache blocks on GPU (#35219)


Signed-off-by: default avatarVadim Gimpelson <vadim.gimpelson@gmail.com>
parent 507ddbe9
...@@ -30,3 +30,8 @@ def round_up(x: int, y: int) -> int: ...@@ -30,3 +30,8 @@ def round_up(x: int, y: int) -> int:
def round_down(x: int, y: int) -> int: def round_down(x: int, y: int) -> int:
"""Round down x to the nearest multiple of y.""" """Round down x to the nearest multiple of y."""
return (x // y) * y return (x // y) * y
def largest_power_of_2_divisor(n: int) -> int:
"""Return the largest power-of-2 that divides *n* (isolate lowest set bit)."""
return n & (-n)
...@@ -86,6 +86,26 @@ class AttentionBackend(ABC): ...@@ -86,6 +86,26 @@ class AttentionBackend(ABC):
) -> tuple[int, ...]: ) -> tuple[int, ...]:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_kv_cache_block_dim(
cls,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> int:
"""Discover which tensor dim is the block index, since different
backends lay out dims differently."""
_S = 1234567
shape = cls.get_kv_cache_shape(
_S,
block_size,
num_kv_heads,
head_size,
cache_dtype_str=cache_dtype_str,
)
return shape.index(_S)
@staticmethod @staticmethod
def get_kv_cache_stride_order( def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False, include_num_layers_dimension: bool = False,
......
...@@ -501,6 +501,13 @@ class KVCacheManager: ...@@ -501,6 +501,13 @@ class KVCacheManager:
# Only create new KVCacheBlocks for non-empty blocks # Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
def take_new_block_ids(self) -> list[int]:
"""Drain and return new attention block IDs for zeroing."""
ids: list[int] = []
for mgr in self.coordinator.single_type_managers:
ids.extend(mgr.take_new_block_ids())
return ids
def new_step_starts(self) -> None: def new_step_starts(self) -> None:
"""Called when a new step is started.""" """Called when a new step is started."""
self.coordinator.new_step_starts() self.coordinator.new_step_starts()
...@@ -233,6 +233,11 @@ class SchedulerOutput: ...@@ -233,6 +233,11 @@ class SchedulerOutput:
# EC Cache Connector metadata # EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None ec_connector_metadata: ECConnectorMetadata | None = None
# Block IDs freshly allocated from the pool during this scheduling step.
# The worker zeros the corresponding GPU memory before the blocks are used,
# preventing stale NaN/data from corrupting attention or SSM computation.
new_block_ids_to_zero: list[int] | None = None
@classmethod @classmethod
def make_empty(cls) -> "SchedulerOutput": def make_empty(cls) -> "SchedulerOutput":
return cls( return cls(
......
...@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import ( ...@@ -48,7 +48,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
...@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface): ...@@ -233,13 +233,8 @@ class Scheduler(SchedulerInterface):
self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: self.has_mamba_layers = kv_cache_config.has_mamba_layers
return any( self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
isinstance(group_spec.kv_cache_spec, MambaSpec)
for group_spec in kv_cache_config.kv_cache_groups
)
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
self.need_mamba_block_aligned_split = ( self.need_mamba_block_aligned_split = (
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align" self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
) )
...@@ -890,6 +885,12 @@ class Scheduler(SchedulerInterface): ...@@ -890,6 +885,12 @@ class Scheduler(SchedulerInterface):
self.prev_step_scheduled_req_ids.clear() self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys()) self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None)
if self.needs_kv_cache_zeroing
else None
)
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data, scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data, scheduled_cached_reqs=cached_reqs_data,
...@@ -905,6 +906,7 @@ class Scheduler(SchedulerInterface): ...@@ -905,6 +906,7 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
) )
# NOTE(Kuntai): this function is designed for multiple purposes: # NOTE(Kuntai): this function is designed for multiple purposes:
......
...@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -55,6 +55,7 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool self.block_pool = block_pool
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated # Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request # for each request, so that we can free the blocks when the request
...@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC): ...@@ -208,6 +209,8 @@ class SingleTypeKVCacheManager(ABC):
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks) cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
) )
req_blocks.extend(allocated_blocks) req_blocks.extend(allocated_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in allocated_blocks)
def allocate_new_blocks( def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_tokens_main_model: int self, request_id: str, num_tokens: int, num_tokens_main_model: int
...@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC): ...@@ -234,8 +237,16 @@ class SingleTypeKVCacheManager(ABC):
else: else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks) new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks) req_blocks.extend(new_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in new_blocks)
return new_blocks return new_blocks
def take_new_block_ids(self) -> list[int]:
"""Drain and return block IDs allocated since the last call."""
ids = self.new_block_ids
self.new_block_ids = []
return ids
def cache_blocks(self, request: Request, num_tokens: int) -> None: def cache_blocks(self, request: Request, num_tokens: int) -> None:
""" """
Cache the blocks for the request. Cache the blocks for the request.
......
...@@ -489,3 +489,11 @@ class KVCacheConfig: ...@@ -489,3 +489,11 @@ class KVCacheConfig:
For models with multiple types of attention, there will be multiple groups, For models with multiple types of attention, there will be multiple groups,
see `_get_kv_cache_config_uniform_page_size` for more details. see `_get_kv_cache_config_uniform_page_size` for more details.
""" """
@property
def has_mamba_layers(self) -> bool:
return any(isinstance(g.kv_cache_spec, MambaSpec) for g in self.kv_cache_groups)
@property
def needs_kv_cache_zeroing(self) -> bool:
return self.has_mamba_layers
...@@ -197,6 +197,7 @@ from vllm.v1.worker.workspace import lock_workspace ...@@ -197,6 +197,7 @@ from vllm.v1.worker.workspace import lock_workspace
from .utils import ( from .utils import (
AttentionGroup, AttentionGroup,
KVBlockZeroer,
add_kv_sharing_layers_to_kv_cache_groups, add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache, bind_kv_cache,
prepare_kernel_block_sizes, prepare_kernel_block_sizes,
...@@ -982,6 +983,26 @@ class GPUModelRunner( ...@@ -982,6 +983,26 @@ class GPUModelRunner(
decode_threshold=self.reorder_batch_threshold, decode_threshold=self.reorder_batch_threshold,
) )
def _init_kv_zero_meta(self) -> None:
"""One-time precomputation for _zero_block_ids.
Delegates to KVBlockZeroer.init_meta with the runner's state.
Called from gpu_worker.py outside the CuMem pool context.
"""
self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory)
self._kv_block_zeroer.init_meta(
attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
kernel_block_sizes=self._kernel_block_sizes,
cache_dtype=self.cache_config.cache_dtype,
runner_only_attn_layers=self.runner_only_attn_layers,
static_forward_context=(self.compilation_config.static_forward_context),
)
def _zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if hasattr(self, "_kv_block_zeroer"):
self._kv_block_zeroer.zero_block_ids(block_ids)
# Note: used for model runner override. # Note: used for model runner override.
def _init_device_properties(self) -> None: def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties""" """Initialize attributes from torch.cuda.get_device_properties"""
...@@ -1018,6 +1039,11 @@ class GPUModelRunner( ...@@ -1018,6 +1039,11 @@ class GPUModelRunner(
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
# Zero GPU memory for freshly allocated cache blocks to prevent
# stale NaN/data from corrupting attention or SSM computation.
if scheduler_output.new_block_ids_to_zero:
self._zero_block_ids(scheduler_output.new_block_ids_to_zero)
# Free the cached encoder outputs. # Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes: for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None) self.encoder_cache.pop(mm_hash, None)
...@@ -6476,6 +6502,7 @@ class GPUModelRunner( ...@@ -6476,6 +6502,7 @@ class GPUModelRunner(
kernel_block_sizes = prepare_kernel_block_sizes( kernel_block_sizes = prepare_kernel_block_sizes(
kv_cache_config, self.attn_groups kv_cache_config, self.attn_groups
) )
self._kernel_block_sizes = kernel_block_sizes
# create metadata builders # create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes) self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
......
...@@ -556,6 +556,14 @@ class Worker(WorkerBase): ...@@ -556,6 +556,14 @@ class Worker(WorkerBase):
else: else:
self.model_runner.initialize_kv_cache(kv_cache_config) self.model_runner.initialize_kv_cache(kv_cache_config)
# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()
@instrument(span_name="Warmup (GPU)") @instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float: def compile_or_warm_up_model(self) -> float:
warmup_sizes: list[int] = [] warmup_sizes: list[int] = []
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import product as iprod
from typing import Any
import torch import torch
...@@ -12,6 +15,8 @@ from vllm.model_executor.layers.attention import Attention ...@@ -12,6 +15,8 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import largest_power_of_2_divisor
from vllm.utils.mem_utils import MemorySnapshot, format_gib from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
...@@ -21,6 +26,7 @@ from vllm.v1.attention.backend import ( ...@@ -21,6 +26,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
EncoderOnlyAttentionSpec, EncoderOnlyAttentionSpec,
FullAttentionSpec,
KVCacheConfig, KVCacheConfig,
KVCacheGroupSpec, KVCacheGroupSpec,
KVCacheSpec, KVCacheSpec,
...@@ -31,6 +37,186 @@ from vllm.v1.kv_cache_interface import ( ...@@ -31,6 +37,186 @@ from vllm.v1.kv_cache_interface import (
logger = init_logger(__name__) logger = init_logger(__name__)
@triton.jit
def _zero_kv_blocks_kernel(
seg_addrs_ptr,
block_ids_ptr,
n_blocks,
N_SEGS: tl.constexpr,
PAGE_SIZE_EL: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Zero KV cache blocks across all segments in a single launch.
Each segment is a contiguous region of one block's data. For backends
where blocks are outermost (block_dim=0) there is one segment per
buffer. For backends where K/V is outermost (block_dim=1) there are
two segments per buffer (one for K, one for V).
seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
allowing segments to live in different CUDA allocations.
Programs are mapped as (block_index, seg_index, chunk_index).
"""
pid = tl.program_id(0)
chunks = PAGE_SIZE_EL // BLOCK_SIZE
work_per_block = N_SEGS * chunks
block_index = pid // work_per_block
if block_index >= n_blocks:
return
remainder = pid % work_per_block
seg_index = remainder // chunks
chunk_index = remainder % chunks
block_id = tl.load(block_ids_ptr + block_index)
seg_addr = tl.load(seg_addrs_ptr + seg_index)
ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
offset = (
block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
)
cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
class KVBlockZeroer:
"""Manages efficient zeroing of KV cache blocks via a Triton kernel.
Call :meth:`init_meta` once after KV caches are allocated to precompute
segment addresses, then call :meth:`zero_block_ids` each step to zero
newly-allocated blocks.
"""
def __init__(self, device: torch.device, pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self._meta: tuple[torch.Tensor, int, int, int] | None = None
self._id_cap: int = 0
self._ids_pinned: torch.Tensor | None = None
self._ids_gpu: torch.Tensor | None = None
def init_meta(
self,
attn_groups_iter: Iterable["AttentionGroup"],
kernel_block_sizes: list[int],
cache_dtype: str,
runner_only_attn_layers: set[str],
static_forward_context: dict[str, Any],
) -> None:
"""One-time precomputation for zero_block_ids.
Builds absolute-address table for the Triton zeroing kernel.
Each entry is the absolute byte address of a segment start on the
GPU, so segments in different CUDA allocations work correctly.
Block IDs from the scheduler reference logical blocks whose size
may differ from the kernel block size (virtual block splitting).
PAGE_SIZE_EL accounts for this ratio so that
``block_id * PAGE_SIZE_EL`` lands at the correct offset.
Only AttentionSpec layers are processed; Mamba layers are skipped.
"""
seen_ptrs: set[int] = set()
seg_addrs: list[int] = []
page_size_el: int | None = None
for group in attn_groups_iter:
spec = group.kv_cache_spec
if type(spec) is not FullAttentionSpec:
continue
if group.kv_cache_group_id >= len(kernel_block_sizes):
continue
kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
ratio = spec.block_size // kernel_bs
block_dim = group.backend.get_kv_cache_block_dim(
kernel_bs,
spec.num_kv_heads,
spec.head_size,
cache_dtype_str=cache_dtype,
)
for layer_name in group.layer_names:
if layer_name in runner_only_attn_layers:
continue
kv = static_forward_context[layer_name].kv_cache[0]
if isinstance(kv, list):
continue
dp = kv.data_ptr()
if dp in seen_ptrs:
continue
seen_ptrs.add(dp)
el = kv.element_size()
cur_bytes = kv.stride(block_dim) * el
assert cur_bytes % 4 == 0
kernel_block_el = cur_bytes // 4
cur_page_el = kernel_block_el * ratio
if page_size_el is None:
page_size_el = cur_page_el
else:
assert page_size_el == cur_page_el, (
f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
)
block_stride_bytes = cur_bytes
outer_dims = [
d
for d in range(block_dim)
if kv.stride(d) * el > block_stride_bytes
]
outer_strides = [kv.stride(d) * el for d in outer_dims]
for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
seg_addrs.append(dp + off_bytes)
if not seg_addrs or page_size_el is None:
self._meta = None
return
blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
self._id_cap = 8192
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
self._meta = (
torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
page_size_el,
blk_size,
len(seg_addrs),
)
def zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if not block_ids or self._meta is None:
return
seg_addrs, page_size_el, blk_size, n_segs = self._meta
n_blocks = len(block_ids)
if n_blocks > self._id_cap:
self._id_cap = n_blocks * 2
self._ids_pinned = torch.empty(
self._id_cap,
dtype=torch.int64,
pin_memory=self.pin_memory,
)
self._ids_gpu = torch.empty(
self._id_cap, dtype=torch.int64, device=self.device
)
assert self._ids_pinned is not None and self._ids_gpu is not None
self._ids_pinned[:n_blocks].numpy()[:] = block_ids
idx = self._ids_gpu[:n_blocks]
idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
grid = (n_blocks * n_segs * (page_size_el // blk_size),)
_zero_kv_blocks_kernel[grid](
seg_addrs,
idx,
n_blocks,
N_SEGS=n_segs,
PAGE_SIZE_EL=page_size_el,
BLOCK_SIZE=blk_size,
)
@dataclass @dataclass
class AttentionGroup: class AttentionGroup:
backend: type[AttentionBackend] backend: type[AttentionBackend]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment