Unverified Commit 5206e5e2 authored by Harry Huang's avatar Harry Huang Committed by GitHub
Browse files

[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)


Signed-off-by: default avatarhuanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: default avatarChen Zhang <zhangch99@outlook.com>
Co-authored-by: default avatarChen Zhang <zhangch99@outlook.com>
parent fec9da0a
...@@ -152,7 +152,11 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer ...@@ -152,7 +152,11 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker import mamba_utils
from vllm.v1.worker.cp_utils import (
check_attention_cp_compatibility,
get_total_cp_world_size,
)
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -688,6 +692,7 @@ class GPUModelRunner( ...@@ -688,6 +692,7 @@ class GPUModelRunner(
# Ephemeral state transferred between execute_model() and sample_tokens(). # Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None self.execute_model_state: ExecuteModelState | None = None
self.kv_connector_output: KVConnectorOutput | None = None self.kv_connector_output: KVConnectorOutput | None = None
self.mamba_state_idx: dict[str, int] = {}
self.layerwise_nvtx_hooks_registered = False self.layerwise_nvtx_hooks_registered = False
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
...@@ -1075,7 +1080,7 @@ class GPUModelRunner( ...@@ -1075,7 +1080,7 @@ class GPUModelRunner(
self.input_batch.refresh_metadata() self.input_batch.refresh_metadata()
def _update_states_after_model_execute( def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None: ) -> None:
"""Update the cached states after model execution. """Update the cached states after model execution.
...@@ -1111,6 +1116,16 @@ class GPUModelRunner( ...@@ -1111,6 +1116,16 @@ class GPUModelRunner(
) )
for i, num_tokens in enumerate(num_accepted_tokens): for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
if self.cache_config.mamba_cache_mode == "align":
mamba_utils.postprocess_mamba(
scheduler_output,
self.kv_cache_config,
self.input_batch,
self.requests,
self.mamba_state_idx,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
)
def _init_mrope_positions(self, req_state: CachedRequestState): def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model() model = self.get_model()
...@@ -2751,7 +2766,6 @@ class GPUModelRunner( ...@@ -2751,7 +2766,6 @@ class GPUModelRunner(
logits, logits,
sampling_metadata, sampling_metadata,
) )
self._update_states_after_model_execute(sampler_output.sampled_token_ids)
return sampler_output return sampler_output
def _bookkeeping_sync( def _bookkeeping_sync(
...@@ -3237,6 +3251,18 @@ class GPUModelRunner( ...@@ -3237,6 +3251,18 @@ class GPUModelRunner(
pad_attn = cudagraph_mode == CUDAGraphMode.FULL pad_attn = cudagraph_mode == CUDAGraphMode.FULL
if self.cache_config.mamba_cache_mode == "align":
mamba_utils.preprocess_mamba(
scheduler_output,
self.kv_cache_config,
self.cache_config,
self.mamba_state_idx,
self.input_batch,
self.requests,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
)
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
...@@ -3423,6 +3449,10 @@ class GPUModelRunner( ...@@ -3423,6 +3449,10 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"): with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata) sampler_output = self._sample(logits, spec_decode_metadata)
self._update_states_after_model_execute(
sampler_output.sampled_token_ids, scheduler_output
)
self._draft_token_ids = None self._draft_token_ids = None
self._draft_token_req_ids = None self._draft_token_req_ids = None
self.input_batch.prev_sampled_token_ids = None self.input_batch.prev_sampled_token_ids = None
...@@ -5322,6 +5352,24 @@ class GPUModelRunner( ...@@ -5322,6 +5352,24 @@ class GPUModelRunner(
for kv_cache_group in kv_cache_config.kv_cache_groups for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
] ]
max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
max_num_blocks_per_req = cdiv(
max_model_len, block_sizes[i] * get_total_cp_world_size()
)
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
mamba_blocks_per_req = (
max_num_blocks_per_req
if self.cache_config.enable_prefix_caching
else 1
) + kv_cache_group.kv_cache_spec.num_speculative_blocks
max_num_blocks_per_req = max(
max_num_blocks_per_req, mamba_blocks_per_req
)
max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size self.cache_config.block_size
...@@ -5333,18 +5381,18 @@ class GPUModelRunner( ...@@ -5333,18 +5381,18 @@ class GPUModelRunner(
) )
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=max(self.max_model_len, self.max_encoder_len), max_model_len=max_model_len,
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes, block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes, kernel_block_sizes=kernel_block_sizes,
max_num_blocks_per_req=max_num_blocks,
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs, logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
num_speculative_tokens=self.num_spec_tokens,
) )
def _allocate_kv_cache_tensors( def _allocate_kv_cache_tensors(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Any
import torch
import triton
import triton.language as tl
from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
src_ptr = tl.load(src_ptrs + pid)
dst_ptr = tl.load(dst_ptrs + pid)
size = tl.load(sizes + pid)
offsets = tl.arange(0, BLOCK_SIZE)
for i in range(0, size, BLOCK_SIZE):
mask = (i + offsets) < size
curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
data = tl.load(curr_src_ptr, mask=mask)
tl.store(curr_dst_ptr, data, mask=mask)
def batch_memcpy(src_ptrs, dst_ptrs, sizes):
batch = src_ptrs.shape[0]
assert dst_ptrs.shape[0] == batch
assert sizes.shape[0] == batch
grid = (batch,)
BLOCK_SIZE = 1024
batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
mamba_group_ids: list[int] = []
mamba_specs: list[MambaSpec] = []
for i in range(len(kv_cache_config.kv_cache_groups)):
kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec
if isinstance(kv_cache_spec, MambaSpec):
mamba_group_ids.append(i)
mamba_specs.append(kv_cache_spec)
assert len(mamba_group_ids) > 0, "no mamba layers in the model"
assert all(mamba_specs[0] == spec for spec in mamba_specs)
return mamba_group_ids, mamba_specs[0]
def collect_mamba_copy_meta(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int],
src_block_idx: int,
dest_block_idx: int,
accept_token_bias: int,
req_state: CachedRequestState,
forward_context: dict[str, Any],
):
if src_block_idx == dest_block_idx and accept_token_bias == 0:
return
for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx]
layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
for layer_name in layer_names:
attention = forward_context[layer_name]
kv_caches: list[torch.Tensor] = attention.kv_cache[0]
for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
copy_spec = state_copy_func(
state, block_ids, src_block_idx, accept_token_bias + 1
)
src_state_list.append(copy_spec.start_addr)
dest_state_list.append(state[dest_block_id].data_ptr())
num_elements_list.append(copy_spec.num_elements * state.element_size())
def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
def preprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
cache_config: CacheConfig,
mamba_state_idx: dict[str, int],
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
"""
Copy the mamba state of previous step to the last
(1 + num_speculative_blocks) block.
"""
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
num_speculative_blocks = mamba_spec.num_speculative_blocks
# TODO(Chen): we need to optimize this function a lot
assert cache_config.enable_prefix_caching
block_size = mamba_spec.block_size
finished_req_ids = scheduler_output.finished_req_ids
preempted_req_ids = scheduler_output.preempted_req_ids or set()
for req_id in itertools.chain(finished_req_ids, preempted_req_ids):
mamba_state_idx.pop(req_id, None)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id)
if prev_state_idx is None:
# new / resumed request, no previous state
# if num_computed_tokens is 0, prev_state_idx will be -1
prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
num_blocks = len(req_state.block_ids[mamba_group_ids[0]])
# We always save the current running state at the last
# (1 + num_speculative_blocks) block.
# A corner case worth mention here: assume we have block_size = 4 and
# num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
# tokens [draft 1, draft 2]. Then we will have:
# Block 0: [A, B, C, draft 1]
# Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
# Block 2: speculative block
# Block 3: speculative block
# And use block 1 to save the running state.
curr_state_idx = num_blocks - 1 - num_speculative_blocks
mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
prev_state_idx,
curr_state_idx,
input_batch.num_accepted_tokens_cpu[i] - 1,
req_state,
forward_context,
)
input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
def postprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
"""
If a blocks is converted from partial block to full block in this step, copy the
state from the block for running state to the new full block.
"""
num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
num_accepted_tokens = num_accepted_tokens_cpu[i]
num_tokens_running_state = (
num_computed_tokens + num_scheduled_tokens - num_draft_tokens
)
new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
aligned_new_computed_tokens = (
new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size
)
# TODO: how to ensure all blocks that cache_blocks called are cached here?
if aligned_new_computed_tokens >= num_tokens_running_state:
accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
src_block_idx,
dest_block_idx,
accept_token_bias,
req_state,
forward_context,
)
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment