# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from collections.abc import Callable
from typing import Any

import torch

from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch


# Batched byte copy: program `pid` copies `sizes[pid]` bytes from the address
# stored in `src_ptrs[pid]` to the address stored in `dst_ptrs[pid]`, moving
# `BLOCK_SIZE` bytes per loop iteration.
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)

    # Per-program copy descriptor.
    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)

    offsets = tl.arange(0, BLOCK_SIZE)
    for i in range(0, size, BLOCK_SIZE):
        # Mask off the tail when `size` is not a multiple of BLOCK_SIZE.
        mask = (i + offsets) < size

        # The loaded addresses are plain integers; reinterpret them as byte
        # pointers so the copy is done at uint8 granularity.
        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))

        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)


def batch_memcpy(src_ptrs, dst_ptrs, sizes) -> None:
    """Launch `batch_memcpy_kernel` with one program per copy descriptor.

    Args:
        src_ptrs: Source addresses, one per copy.
        dst_ptrs: Destination addresses, one per copy.
        sizes: Number of bytes to copy, one per copy.

    All three arguments must have the same length along dimension 0.
    """
    batch = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == batch, "dst_ptrs and src_ptrs length mismatch"
    assert sizes.shape[0] == batch, "sizes and src_ptrs length mismatch"

    grid = (batch,)
    BLOCK_SIZE = 1024
    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)


def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
    """Return the indices of all mamba KV-cache groups and their shared spec.

    Args:
        kv_cache_config: Configuration whose `kv_cache_groups` are scanned for
            groups whose `kv_cache_spec` is a `MambaSpec`.

    Returns:
        A tuple `(mamba_group_ids, mamba_spec)` where `mamba_group_ids` are the
        positions of the mamba groups within `kv_cache_groups` and `mamba_spec`
        is the spec of the first of them.

    Raises:
        AssertionError: If there is no mamba group, or if the mamba groups do
            not all compare equal to the first spec.
    """
    mamba_group_ids: list[int] = []
    mamba_specs: list[MambaSpec] = []
    for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
        kv_cache_spec = group.kv_cache_spec
        if isinstance(kv_cache_spec, MambaSpec):
            mamba_group_ids.append(group_id)
            mamba_specs.append(kv_cache_spec)
    assert mamba_group_ids, "no mamba layers in the model"
    assert all(mamba_specs[0] == spec for spec in mamba_specs), (
        "all mamba groups must share the same MambaSpec"
    )
    return mamba_group_ids, mamba_specs[0]


@dataclasses.dataclass
class MambaCopyBuffers:
    """Staging buffers holding the descriptors of pending mamba state copies.

    Entries `[0, offset)` of the three buffers describe the copies collected
    so far; entry `k` means "copy `sizes[k]` bytes from `src_ptrs[k]` to
    `dst_ptrs[k]`".
    """

    # Source addresses of the pending copies.
    src_ptrs: CpuGpuBuffer
    # Destination addresses of the pending copies.
    dst_ptrs: CpuGpuBuffer
    # Byte counts of the pending copies.
    sizes: CpuGpuBuffer
    # Number of valid entries currently stored in the buffers.
    offset: int = 0

    @classmethod
    def create(
        cls,
        max_num_reqs: int,
        kv_cache_config: KVCacheConfig,
        copy_funcs: tuple[MambaStateCopyFunc, ...],
        make_buffer: Callable[..., CpuGpuBuffer],
    ) -> "MambaCopyBuffers":
        """Allocate buffers sized for one copy pass over `max_num_reqs` requests.

        The capacity is `max_num_reqs * (#mamba layers across all mamba groups)
        * len(copy_funcs)`.

        Args:
            max_num_reqs: Maximum number of requests handled in one pass.
            kv_cache_config: Used to count the layers of every mamba group.
            copy_funcs: One copy function per mamba state tensor of a layer.
            make_buffer: Factory called as `make_buffer(n, dtype=...)` to
                allocate each `CpuGpuBuffer`.
        """
        mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
        entries_per_req = sum(
            len(kv_cache_config.kv_cache_groups[gid].layer_names)
            for gid in mamba_group_ids
        ) * len(copy_funcs)
        n = max_num_reqs * entries_per_req
        return cls(
            src_ptrs=make_buffer(n, dtype=torch.int64),
            dst_ptrs=make_buffer(n, dtype=torch.int64),
            sizes=make_buffer(n, dtype=torch.int32),
        )


def collect_mamba_copy_meta(
    copy_bufs: MambaCopyBuffers,
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
    src_block_idx: int,
    dest_block_idx: int,
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
) -> None:
    """Append the copy descriptors for one request to `copy_bufs`.

    For every mamba group, every layer of that group and every state tensor of
    that layer, one `(src address, dst address, byte count)` entry is written
    at `copy_bufs.offset`, which is advanced accordingly. Nothing is copied
    here; the copies are executed later by `do_mamba_copy_block`.

    Args:
        copy_bufs: Buffers receiving the descriptors.
        kv_cache_config: Provides the layer names of each mamba group.
        mamba_state_copy_funcs: Per-state functions producing a copy spec
            (`start_addr`, `num_elements`) for the source.
        mamba_group_ids: Indices of the mamba KV-cache groups.
        src_block_idx: Index (into the request's block list) of the source
            block.
        dest_block_idx: Index (into the request's block list) of the
            destination block.
        accept_token_bias: Offset passed (plus one) to the state copy
            functions.
        req_state: Request whose `block_ids` are used.
        forward_context: Maps layer name to the layer object holding
            `kv_cache`.
    """
    # Same block and no bias: the state is already where it has to be.
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return

    src_ptrs_np = copy_bufs.src_ptrs.np
    dst_ptrs_np = copy_bufs.dst_ptrs.np
    sizes_np = copy_bufs.sizes.np
    offset = copy_bufs.offset

    for mamba_group_id in mamba_group_ids:
        block_ids = req_state.block_ids[mamba_group_id]
        dest_block_id = block_ids[dest_block_idx]
        layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
        for layer_name in layer_names:
            attention = forward_context[layer_name]
            # NOTE(review): presumably kv_cache[0] is the list of state tensors
            # of this layer, ordered like mamba_state_copy_funcs -- confirm.
            kv_caches: list[torch.Tensor] = attention.kv_cache[0]
            for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
                copy_spec = state_copy_func(
                    state, block_ids, src_block_idx, accept_token_bias + 1
                )

                src_ptrs_np[offset] = copy_spec.start_addr
                dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
                # The kernel copies bytes, so convert elements to bytes.
                sizes_np[offset] = copy_spec.num_elements * state.element_size()
                offset += 1

    copy_bufs.offset = offset


def do_mamba_copy_block(copy_bufs: MambaCopyBuffers) -> None:
    """Execute all copies collected in `copy_bufs`; no-op when there are none."""
    n = copy_bufs.offset
    if n == 0:
        return
    batch_memcpy(
        copy_bufs.src_ptrs.copy_to_gpu(n),
        copy_bufs.dst_ptrs.copy_to_gpu(n),
        copy_bufs.sizes.copy_to_gpu(n),
    )


def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
) -> None:
    """
    Copy the mamba state of previous step to the last
    (1 + num_speculative_blocks) block.

    Side effects:
      * Entries of `mamba_state_idx` are dropped for finished, preempted and
        resumed requests, then set to the current running-state block index
        for every request in `input_batch`.
      * `copy_bufs` is reset and refilled, and the collected copies are
        launched.
      * `input_batch.num_accepted_tokens_cpu[i]` is reset to 1 for every
        request whose state copy was collected.

    Raises:
        AssertionError: If prefix caching is not enabled in `cache_config`.
    """
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    num_speculative_blocks = mamba_spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = mamba_spec.block_size
    finished_req_ids = scheduler_output.finished_req_ids
    preempted_req_ids = scheduler_output.preempted_req_ids or set()
    # We need to clear mamba_state_idx for resumed requests. When requests are
    # force-preempted (e.g., during reset_prefix_cache / KV cache flush),
    # they appear in resumed_req_ids without a corresponding entry in
    # preempted_req_ids, leaving stale mamba_state_idx entries that can
    # point to block indices beyond the new (smaller) block allocation.
    resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
    for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
        mamba_state_idx.pop(req_id, None)

    copy_bufs.offset = 0
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        prev_state_idx = mamba_state_idx.get(req_id)
        if prev_state_idx is None:
            # new / resumed request, no previous state
            # if num_computed_tokens is 0, prev_state_idx will be -1
            prev_state_idx = (req_state.num_computed_tokens - 1) // block_size

        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
        num_blocks: int = (
            cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
            + num_speculative_blocks
        )

        # We always save the current running state at the last
        # (1 + num_speculative_blocks) block.
        # A corner case worth mentioning here: assume we have block_size = 4
        # and num_speculative_tokens = 2. The request is [A, B, C] and contains
        # 2 draft tokens [draft 1, draft 2]. Then we will have:
        # Block 0: [A, B, C, draft 1]
        # Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        # Block 2: speculative block
        # Block 3: speculative block
        # And use block 1 to save the running state.
        curr_state_idx = num_blocks - 1 - num_speculative_blocks
        mamba_state_idx[req_id] = curr_state_idx
        # -1 means there is no previous state; an unchanged index means the
        # running state already lives in the right block.
        if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
            collect_mamba_copy_meta(
                copy_bufs,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                prev_state_idx,
                curr_state_idx,
                input_batch.num_accepted_tokens_cpu[i] - 1,
                req_state,
                forward_context,
            )
            input_batch.num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(copy_bufs)


def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    copy_bufs: MambaCopyBuffers,
) -> None:
    """
    If a block is converted from partial block to full block in this step,
    copy the state from the block for running state to the new full block.

    Side effects:
      * `copy_bufs` is reset and refilled, and the collected copies are
        launched.
      * `input_batch.num_accepted_tokens_cpu[i]` is reset to 1 when the source
        and destination block indices coincide.
    """
    num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
    scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
    num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
    # NOTE: can be optimized as this function always returns the same result
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    block_size = mamba_spec.block_size

    copy_bufs.offset = 0
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_computed_tokens = req_state.num_computed_tokens
        num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
        num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
        num_accepted_tokens = num_accepted_tokens_cpu[i]
        # Tokens covered by the running state: everything scheduled so far
        # except the draft tokens.
        num_tokens_running_state = (
            num_computed_tokens + num_scheduled_tokens - num_draft_tokens
        )
        new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
        # Round down to a multiple of the block size.
        aligned_new_computed_tokens = new_num_computed_tokens // block_size * block_size
        # TODO: how to ensure all blocks that cache_blocks called are cached here?
        if aligned_new_computed_tokens >= num_tokens_running_state:
            accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
            src_block_idx = mamba_state_idx[req_id]
            dest_block_idx = aligned_new_computed_tokens // block_size - 1
            collect_mamba_copy_meta(
                copy_bufs,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                src_block_idx,
                dest_block_idx,
                accept_token_bias,
                req_state,
                forward_context,
            )
            if src_block_idx == dest_block_idx:
                num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(copy_bufs)


def update_accepted_tokens_for_prefill_as_decode(
    input_batch: GPUInputBatch,
    prefill_as_decode_num_tokens: CpuGpuBuffer,
    num_accepted_tokens_gpu: torch.Tensor,
    scheduler_output: SchedulerOutput,
    decode_qlen_threshold: int | None,
    num_reqs: int,
) -> None:
    """
    Adjusts num_accepted_tokens for prefill chunks processed via the decode
    path. This ensures subsequent iterations read from the correct sequential
    state slot instead of the default prefill slot 0. Not used by GDN
    attention, which manually separates short prefills and short decodes when
    building the attention metadata.
    """
    any_is_prefill = False
    for i in range(num_reqs):
        num_computed = input_batch.num_computed_tokens_cpu[i]
        num_prompt = input_batch.num_prompt_tokens[i]
        is_prefill = num_computed < num_prompt
        req_id = input_batch.req_ids[i]
        query_len = scheduler_output.num_scheduled_tokens[req_id]

        if is_prefill:
            classified_as_decode = (
                decode_qlen_threshold is not None and query_len <= decode_qlen_threshold
            )
            num_tokens = query_len if classified_as_decode else 1
            any_is_prefill = True
        else:
            # -1 is the "leave the current value untouched" marker consumed by
            # the torch.where below.
            num_tokens = -1
        prefill_as_decode_num_tokens.np[i] = num_tokens

    # We can skip the GPU transfer if there aren't any values to update
    if any_is_prefill:
        prefill_as_decode_num_tokens.copy_to_gpu(num_reqs)
        num_accepted_tokens_gpu[:num_reqs] = torch.where(
            prefill_as_decode_num_tokens.gpu[:num_reqs] != -1,
            prefill_as_decode_num_tokens.gpu[:num_reqs],
            num_accepted_tokens_gpu[:num_reqs],
        )