"vscode:/vscode.git/clone" did not exist on "87ef4618428fe2c8f756a80c271857fa6ae2623a"
Unverified Commit 34916ae3 authored by Asaf Joseph Gardin's avatar Asaf Joseph Gardin Committed by GitHub
Browse files

[Mamba] - Consolidate Mambas Attention Logic (#28133)

parent 0736f901
...@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp): ...@@ -118,6 +118,7 @@ class ShortConv(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
query_start_loc_p = attn_metadata.query_start_loc_p
BCx, _ = self.in_proj(hidden_states) BCx, _ = self.in_proj(hidden_states)
...@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp): ...@@ -165,11 +166,6 @@ class ShortConv(MambaBase, CustomOp):
[num_decodes, num_prefills], [num_decodes, num_prefills],
dim=0, dim=0,
) )
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
if has_prefill
else None
)
conv_output_list = [] conv_output_list = []
......
...@@ -3,17 +3,11 @@ ...@@ -3,17 +3,11 @@
from dataclasses import dataclass from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.mamba_attn import (
from vllm.config import VllmConfig BaseMambaAttentionMetadata,
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder BaseMambaAttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class Mamba1AttentionBackend(AttentionBackend): class Mamba1AttentionBackend(AttentionBackend):
...@@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend): ...@@ -23,137 +17,12 @@ class Mamba1AttentionBackend(AttentionBackend):
@dataclass @dataclass
class Mamba1AttentionMetadata: class Mamba1AttentionMetadata(BaseMambaAttentionMetadata):
query_start_loc_p: torch.Tensor pass
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
class Mamba1AttentionMetadataBuilder( class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata] BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
): ):
def __init__( metadata_cls = Mamba1AttentionMetadata
self, supports_update_block_table: bool = False
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba1AttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
query_start_loc_p = None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
# We should consolidate this code
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return Mamba1AttentionMetadata(
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass, replace
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend): ...@@ -94,48 +94,26 @@ class Mamba2AttentionBackend(AttentionBackend):
@dataclass @dataclass
class Mamba2AttentionMetadata: class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
num_prefills: int prep_initial_states: bool = False
num_prefill_tokens: int chunk_size: int = 0
num_decodes: int
num_decode_tokens: int
query_start_loc_p: torch.Tensor
seq_lens: torch.Tensor
prep_initial_states: bool
chunk_size: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p: torch.Tensor | None
seq_idx_p: torch.Tensor | None
# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offests into the varlen sequence dimension. It is defined # each chunk, its offests into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1]. # cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the # last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch. # index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None last_chunk_indices_p: torch.Tensor | None = None
state_indices_tensor: torch.Tensor # shape: [batch,]
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
num_computed_tokens_p: torch.Tensor # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder( class Mamba2AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
): ):
supports_update_block_table: bool = True metadata_cls = Mamba2AttentionMetadata
def __init__( def __init__(
self, self,
...@@ -150,109 +128,31 @@ class Mamba2AttentionMetadataBuilder( ...@@ -150,109 +128,31 @@ class Mamba2AttentionMetadataBuilder(
"chunk_size needs to be set in the model config for Mamba2 models" "chunk_size needs to be set in the model config for Mamba2 models"
) )
def build( def _compute_chunk_metadata(
self, self,
common_prefix_len: int, num_prefills: int,
common_attn_metadata: CommonAttentionMetadata, num_computed_tokens_p_cpu: torch.Tensor,
fast_build: bool = False, query_start_loc_p_cpu: torch.Tensor,
) -> Mamba2AttentionMetadata: ) -> tuple[list[int], list[int], list[int]]:
num_reqs = common_attn_metadata.num_reqs """
seq_lens = common_attn_metadata.seq_lens Compute chunk-specific metadata for Mamba2.
query_start_loc_p = None The code below carefully constructs the chunks such that:
seq_idx_p = None 1. Chunks contain tokens from a *single* sequence only.
cu_chunk_seqlen_p = None 2. For every sequence, we are guaranteed that we can
last_chunk_indices_p = None retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
# Need flags to indicate if there are initial states Constraint (2) dramatically simplifies the implementation
has_initial_states_p = None of prefix caching for mamba2 (wip). We need to take care
prep_initial_states = False of the interaction with chunked prefill in order to
satisfy constraint (2).
# for causal_conv1d """
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# Additional cache-related varaiables:
block_idx_last_scheduled_token = None
block_idx_last_computed_token = None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
# Compute seq_idx for prefill only
if num_prefills > 0:
# [batch,]
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
# The code below carefully constructs the chunks such that:
# 1. Chunks contain tokens from a *single* sequence only.
# 2. For every sequence, we are guaranteed that we can
# retrieve the mamba state *every* chunk_size tokens.
# Constraint (1) dramatically simplifies the mamba2 kernels.
# Constraint (2) dramatically simplifies the implementation
# of prefix caching for mamba2 (wip). We need to take care
# of the interaction with chunked prefill in order to
# satisfy constraint (2).
# TODO (tdoublep): This code could probably be optimized. # TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = [] cu_chunk_seqlen = []
seq_idx = [] seq_idx = []
last_chunk_indices = [] last_chunk_indices = []
seqlen_pos = 0 seqlen_pos = 0
for req_idx in range(num_prefills): for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item() this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = ( this_new_tokens = (
...@@ -288,88 +188,68 @@ class Mamba2AttentionMetadataBuilder( ...@@ -288,88 +188,68 @@ class Mamba2AttentionMetadataBuilder(
cu_chunk_seqlen.append(seqlen_pos) cu_chunk_seqlen.append(seqlen_pos)
seq_idx_p = torch.as_tensor( return cu_chunk_seqlen, seq_idx, last_chunk_indices
seq_idx, device=query_start_loc_p.device, dtype=torch.int32
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
)
nums_dict, batch_ptr, token_chunk_offset_ptr = ( def build(
compute_causal_conv1d_metadata(query_start_loc_p) self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba2AttentionMetadata:
common = self._compute_common_metadata(common_attn_metadata)
seq_idx_p = None
cu_chunk_seqlen_p = None
last_chunk_indices_p = None
prep_initial_states = False
# Compute seq_idx for prefill only
if common.num_prefills > 0:
prep_initial_states = (
torch.any(common.has_initial_states_p).item()
if common.has_initial_states_p is not None
else False
) )
elif ( num_reqs = common.num_reqs
num_decodes <= self.decode_cudagraph_max_bs num_prefills = common.num_prefills
and self.compilation_config.cudagraph_mode.has_full_cudagraphs() num_decode_tokens = common.num_decode_tokens
):
self.state_indices_tensor[:num_decodes].copy_( num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
state_indices_tensor, non_blocking=True num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
) )
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
if self.vllm_config.cache_config.enable_prefix_caching: cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
self.block_idx_last_scheduled_token[:num_decodes].copy_( num_prefills,
block_idx_last_scheduled_token, non_blocking=True num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
) )
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_( seq_idx_p = torch.as_tensor(
block_idx_last_computed_token, non_blocking=True seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
) )
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
attn_metadata = Mamba2AttentionMetadata( return replace(
num_prefills=num_prefills, common,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
seq_lens=seq_lens,
prep_initial_states=prep_initial_states, prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
has_initial_states_p=has_initial_states_p,
seq_idx_p=seq_idx_p, seq_idx_p=seq_idx_p,
state_indices_tensor=state_indices_tensor,
cu_chunk_seqlen_p=cu_chunk_seqlen_p, cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p, last_chunk_indices_p=last_chunk_indices_p,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
) )
return attn_metadata
def update_block_table(
self,
metadata: Mamba2AttentionMetadata,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> Mamba2AttentionMetadata:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc import abc
import copy
from dataclasses import dataclass
from typing import ClassVar, TypeVar from typing import ClassVar, TypeVar
import torch import torch
...@@ -9,20 +11,52 @@ import torch ...@@ -9,20 +11,52 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
M = TypeVar("M") M = TypeVar("M", bound="BaseMambaAttentionMetadata")
@dataclass
class BaseMambaAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_reqs: int
# The following tensors only contain prefill requests and will be None if
# the batch has no prefill request.
has_initial_states_p: torch.Tensor | None
query_start_loc_p: torch.Tensor | None
num_computed_tokens_p: torch.Tensor | None
state_indices_tensor: torch.Tensor
# The following tensors are only used for prefix caching and are None if disabled
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
metadata_cls: type[M]
reorder_batch_threshold: int = 1 reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = ( _cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
) )
supports_update_block_table: bool = True
def __init__( def __init__(
self, self,
...@@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -87,6 +121,18 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
return self.build(0, m) return self.build(0, m)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> M:
"""
Default build implementation for Mamba-like attention backends.
Subclasses (e.g., Mamba2) can override to add additional metadata.
"""
return self._compute_common_metadata(common_attn_metadata)
def _compute_prefix_caching_block_indices( def _compute_prefix_caching_block_indices(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
...@@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): ...@@ -115,3 +161,147 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_first_scheduled_token, block_idx_first_scheduled_token,
block_idx_last_scheduled_token, block_idx_last_scheduled_token,
) )
def _compute_common_metadata(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> M:
"""
Compute metadata common to both Mamba1 and Mamba2.
"""
num_reqs = common_attn_metadata.num_reqs
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
# Need flags to indicate if there are initial states
has_initial_states_p = None
query_start_loc_p = None
num_computed_tokens = None
num_computed_tokens_p = None
# for prefix caching
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
block_idx_last_computed_token = None
block_idx_last_scheduled_token = None
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.enable_prefix_caching:
# Return a tensor of shape (#requests, #max blocks)
state_indices_tensor = common_attn_metadata.block_table_tensor
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
(
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
) = self._compute_prefix_caching_block_indices(
common_attn_metadata, mamba_block_size
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
if num_prefills > 0:
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(
common_attn_metadata.query_start_loc.device
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
if self.vllm_config.cache_config.enable_prefix_caching:
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
]
assert block_idx_first_scheduled_token is not None
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
num_reqs - num_prefills : num_reqs
]
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_decode_tokens
]
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_decode_tokens
]
return self.metadata_cls(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
state_indices_tensor=state_indices_tensor,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
num_reqs=num_reqs,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
def update_block_table(
self,
metadata: M,
blk_table: torch.Tensor,
slot_mapping: torch.Tensor,
) -> M:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer
if (
metadata.num_prefills == 0
and num_reqs <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
persistent_state_indices_t = self.state_indices_tensor[:num_reqs]
persistent_state_indices_t.copy_(state_indices_t, non_blocking=True)
state_indices_t = persistent_state_indices_t
new_metadata.state_indices_tensor = state_indices_t
return new_metadata
...@@ -2,15 +2,10 @@ ...@@ -2,15 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.mamba_attn import (
from vllm.v1.attention.backends.utils import ( BaseMambaAttentionMetadata,
PAD_SLOT_ID, BaseMambaAttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
) )
...@@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend): ...@@ -21,84 +16,11 @@ class ShortConvAttentionBackend(AttentionBackend):
@dataclass @dataclass
class ShortConvAttentionMetadata: class ShortConvAttentionMetadata(BaseMambaAttentionMetadata):
num_prefills: int pass
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states_p: torch.Tensor | None
# For causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
token_chunk_offset_ptr: torch.Tensor | None = None
class ShortConvAttentionMetadataBuilder( class ShortConvAttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata] BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
): ):
def build( metadata_cls = ShortConvAttentionMetadata
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> ShortConvAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
has_initial_states_p = None
if num_prefills > 0:
has_initial_states_cpu = (
common_attn_metadata.num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
> 0
)
has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device)
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
- num_decode_tokens
)
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(query_start_loc_p)
)
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
attn_metadata = ShortConvAttentionMetadata(
query_start_loc=query_start_loc,
state_indices_tensor=state_indices_tensor,
has_initial_states_p=has_initial_states_p,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment