Unverified Commit 2c11a738 authored by Congcong Chen's avatar Congcong Chen Committed by GitHub
Browse files

[Model] New model support for microsoft/Phi-4-mini-flash-reasoning (#20702)


Signed-off-by: default avatarCongcong Chen <congcongchen@microsoft.com>
parent b639327a
...@@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { ...@@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr bool kIsVariableB = true; constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true; constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>; BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
dim3 grid(params.batch, params.dim / kNRows); constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
auto kernel = &selective_scan_fwd_kernel<Ktraits>; dim3 grid(params.batch, params.dim / kNRows);
if (kSmemSize >= 48 * 1024) { auto kernel = &selective_scan_fwd_kernel<Ktraits>;
C10_CUDA_CHECK(cudaFuncSetAttribute( if (kSmemSize >= 48 * 1024) {
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); C10_CUDA_CHECK(cudaFuncSetAttribute(
} kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); }
C10_CUDA_KERNEL_LAUNCH_CHECK(); kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}); });
}); });
} }
...@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
at::Tensor z, out_z; at::Tensor z, out_z;
const bool has_z = z_.has_value(); const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") if (has_z) {
z = z_.value(); z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type); TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda()); TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){ if (varlen){
CHECK_SHAPE(z, dim, seqlen); CHECK_SHAPE(z, dim, seqlen);
} else { } else {
CHECK_SHAPE(z, batch_size, dim, seqlen); CHECK_SHAPE(z, batch_size, dim, seqlen);
}
out_z = z;
} }
out_z = z;
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta; at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type); TORCH_CHECK(ssm_states.scalar_type() == input_type);
...@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ...@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
selective_scan_fwd_cuda<input_t, weight_t>(params, stream); selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
}); });
} }
...@@ -374,6 +374,7 @@ Specified using `--task generate`. ...@@ -374,6 +374,7 @@ Specified using `--task generate`.
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ | | `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | | `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | | `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
......
...@@ -248,6 +248,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -248,6 +248,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True, trust_remote_code=True,
v0_only=True), v0_only=True),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True), trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
......
...@@ -103,6 +103,9 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): ...@@ -103,6 +103,9 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
_initialize_kv_caches_v1), monkeypatch.context() as m): _initialize_kv_caches_v1), monkeypatch.context() as m):
if model_info.v0_only: if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0") m.setenv("VLLM_USE_V1", "0")
if model_arch == "Phi4FlashForCausalLM":
# Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
LLM( LLM(
model_info.default, model_info.default,
tokenizer=model_info.tokenizer, tokenizer=model_info.tokenizer,
......
...@@ -458,6 +458,31 @@ def test_bind_kv_cache(): ...@@ -458,6 +458,31 @@ def test_bind_kv_cache():
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
def test_bind_kv_cache_kv_sharing():
from vllm.attention import Attention
ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
]
shared_kv_cache_layers = {
'layers.2.self_attn': 'layers.1.self_attn',
'layers.3.self_attn': 'layers.0.self_attn'
}
bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_non_attention(): def test_bind_kv_cache_non_attention():
from vllm.attention import Attention from vllm.attention import Attention
......
...@@ -308,7 +308,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -308,7 +308,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_sharing_target_layer_name: Optional[str] = None, kv_sharing_target_layer_name: Optional[str] = None,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"BLOCK_SPARSE_FLASH_ATTN Backend.")
assert blocksparse_params is not None assert blocksparse_params is not None
assert alibi_slopes is None, ValueError( assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.") "Alibi not support for blocksparse flash attention.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""" An implementation of https://arxiv.org/pdf/2410.05258 """
from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
from einops import rearrange
from vllm import _custom_ops as ops
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.flash_attn import FlashAttentionBackend
# yapf: enable
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set,
is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
logger = init_logger(__name__)
class DifferentialFlashAttentionBackend(AttentionBackend):
accept_output_buffer = False
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2"
return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
@staticmethod
def get_name() -> str:
return "DIFFERENTIAL_FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]:
return DifferentialFlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]:
return DifferentialFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]:
return DifferentialFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
class DifferentialFlashAttentionMetadata(AttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional[
"DifferentialFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[
"DifferentialFlashAttentionMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
# Cross-layer shared attention block tables
cross_layer_shared_block_tables: Optional[torch.Tensor] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return is_all_encoder_attn_metadata_set(self)
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return is_all_cross_attn_metadata_set(self)
@property
def prefill_metadata(
self) -> Optional["DifferentialFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
cross_layer_shared_block_tables = (
None if self.cross_layer_shared_block_tables is None else
self.cross_layer_shared_block_tables[:self.num_prefills])
self._cached_prefill_metadata = DifferentialFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@property
def decode_metadata(
self) -> Optional["DifferentialFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
cross_layer_shared_block_tables = (
None if self.cross_layer_shared_block_tables is None else
self.cross_layer_shared_block_tables[self.num_prefills:])
self._cached_decode_metadata = DifferentialFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class DifferentialFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.cross_layer_shared_block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
# TODO: add support for chunked prefill and prefix caching.
assert not chunked_prefill_enabled, \
"chunked prefill is not supported for now"
assert not prefix_cache_hit, "prefix caching is not supported for now"
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
cross_layer_shared_block_table = []
if prefix_cache_hit:
cross_layer_shared_block_table = block_tables[seq_id]
elif block_tables is not None:
if curr_sliding_window_block == 0:
cross_layer_shared_block_table = block_tables[seq_id]
else:
cross_layer_shared_block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.cross_layer_shared_block_tables.append(
cross_layer_shared_block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def _get_graph_runner_block_tables(self, num_seqs: int,
block_tables: List[List[int]],
graph_block_tables) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# max_batch_size, max_blocks = self.runner.graph_block_tables.shape
max_batch_size, max_blocks = graph_block_tables.shape
assert max_batch_size >= num_seqs
# graph_block_tables = self.runner.graph_block_tables[:num_seqs]
graph_block_tables = graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]
return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
self.cross_layer_shared_block_tables.extend([] *
cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables, self.runner.graph_block_tables)
cross_layer_shared_block_tables = \
self._get_graph_runner_block_tables(
num_seqs, self.cross_layer_shared_block_tables,
self.runner.cross_layer_shared_graph_block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
cross_layer_shared_block_tables = make_tensor_with_pad(
self.cross_layer_shared_block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return DifferentialFlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
use_cuda_graph=use_captured_graph,
)
class DifferentialFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
differential_flash_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
if differential_flash_attention_config is None:
differential_flash_attention_config = {}
self.differential_flash_attention_config = \
differential_flash_attention_config
self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
if use_irope:
logger.warning(
"Using irope in V0 is not supported yet, it will fall back "
"to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=self.alibi_slopes is not None)
if is_quantized_kv_cache(self.kv_cache_dtype) and (
not self.kv_cache_dtype.startswith("fp8")
or not flash_attn_supports_fp8()):
raise NotImplementedError(
f"FlashAttention does not support {self.kv_cache_dtype} "
"kv-cache on this device "
f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type
self.lambda_full = None
self.subln = self.differential_flash_attention_config["subln"]
def split_heads(self, x):
# split by num_heads, the stripe pattern is friendly to tensor parallel.
x = rearrange(x, "... (H two) D -> ... H two D", two=2)
x1 = x[..., 0, :]
x2 = x[..., 1, :]
return x1.contiguous(), x2.contiguous()
def split_kv_cache(self, x):
# split by num_heads, the stripe pattern is friendly to tensor parallel.
if x.numel() == 0:
return torch.empty(0), torch.empty(0)
x1, x2 = x[0], x[1]
return x1, x2
def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor,
value: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata):
if kv_cache.numel() > 0 and key is not None and value is not None:
updated_slot_mapping = attn_metadata.slot_mapping
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(),
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def forward_generate_kv_cache(
self, query: torch.Tensor, key: Optional[torch.Tensor],
value: Optional[torch.Tensor], k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor:
head_size = self.head_size
num_heads = self.num_heads // 2
num_kv_heads = self.num_kv_heads // 2
query = query.view(-1, num_heads, head_size)
if key is not None:
assert value is not None
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
else:
assert value is None
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[
0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch"
assert value.shape[
0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch"
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
if key is not None and value is not None:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens, "query shape mismatch"
assert decode_query.shape[
0] == num_decode_tokens, "decode query shape mismatch"
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if k_cache.numel() == 0 \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0:
# normal attention
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
assert prefill_output.shape == output[:
num_prefill_tokens].shape
output[:num_prefill_tokens] = prefill_output
else:
raise Exception("prefix caching not supported")
if decode_meta := attn_metadata.decode_metadata:
block_tables_arg = decode_meta.block_tables
try:
output[num_prefill_tokens:] = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=k_cache,
v_cache=v_cache,
block_table=block_tables_arg,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
except Exception as e:
logger.error("Error in PagedAttention.forward_decode: %s",
str(e))
raise e
# Reshape the output tensor.
return output.view(-1, num_heads, head_size)
def forward_with_kv_cache_only(
self,
query: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata,
):
if not attn_metadata.decode_metadata:
block_tables_arg = attn_metadata.cross_layer_shared_block_tables
else:
block_tables_arg = attn_metadata.block_tables
output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=k_cache,
v_cache=v_cache,
block_table=block_tables_arg,
cache_seqlens=attn_metadata.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
return output
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
NOTE: It in-place updates the output tensor.
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[
"lambda_init"]
lambda_q1 = self.differential_flash_attention_config["lambda_q1"]
lambda_k1 = self.differential_flash_attention_config["lambda_k1"]
lambda_q2 = self.differential_flash_attention_config["lambda_q2"]
lambda_k2 = self.differential_flash_attention_config["lambda_k2"]
lambda_1 = torch.exp(
torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q)
lambda_2 = torch.exp(
torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q)
self.lambda_full = lambda_1 - lambda_2 + self.lambda_init
if not self.used_shared_kv_cache: # need to generate kv-cache
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size)
q1, q2 = self.split_heads(q)
k1, k2 = self.split_heads(k)
v1, v2 = self.split_heads(v)
# kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501
# Split by half along the first dimension.
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous"
if kv_cache1.numel() != 0:
self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata)
self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata)
key_cache1, value_cache1 = self.split_kv_cache(kv_cache1)
key_cache2, value_cache2 = self.split_kv_cache(kv_cache2)
else:
key_cache1, value_cache1 = torch.empty(0), torch.empty(0)
key_cache2, value_cache2 = torch.empty(0), torch.empty(0)
attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1,
value_cache1,
attn_metadata)
attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1,
value_cache2,
attn_metadata)
attn11 = attn11.view(q1.shape)
attn12 = attn12.view(q1.shape)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2,
value_cache1,
attn_metadata)
attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2,
value_cache2,
attn_metadata)
attn21 = attn21.view(q2.shape)
attn22 = attn22.view(q2.shape)
attn2 = torch.cat([attn21, attn22], dim=-1)
attn = attn1 - self.lambda_full * attn2
# attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# reshape back to 2 * num_head
attn_output = rearrange(attn,
"... H (two D) -> ... (H two) D",
two=2)
else: # re-use the kv cache, full attention
q = q.view(-1, self.num_heads, self.head_size)
q1, q2 = self.split_heads(q)
# kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]
attn11 = self.forward_with_kv_cache_only(q1, key_cache1,
value_cache1,
attn_metadata)
attn12 = self.forward_with_kv_cache_only(q1, key_cache1,
value_cache2,
attn_metadata)
attn11 = attn11.view(q1.shape)
attn12 = attn12.view(q1.shape)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = self.forward_with_kv_cache_only(q2, key_cache2,
value_cache1,
attn_metadata)
attn22 = self.forward_with_kv_cache_only(q2, key_cache2,
value_cache2,
attn_metadata)
attn21 = attn21.view(q2.shape)
attn22 = attn22.view(q2.shape)
attn2 = torch.cat([attn21, attn22], dim=-1)
attn = attn1 - self.lambda_full * attn2
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# reshape back to 2 * num_head
attn_output = rearrange(attn,
"... H (two D) -> ... (H two) D",
two=2)
attn_output = attn_output.view(-1, self.num_heads * self.head_size)
return attn_output
...@@ -295,7 +295,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): ...@@ -295,7 +295,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
dual_chunk_attention_config: Optional[Dict[str, Any]] = None, dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"DUAL_CHUNK_FLASH_ATTN backend.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
......
...@@ -622,7 +622,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -622,7 +622,8 @@ class FlashAttentionImpl(AttentionImpl):
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"FLASH_ATTN backend.")
if blocksparse_params is not None: if blocksparse_params is not None:
raise ValueError( raise ValueError(
"FlashAttention does not support block-sparse attention.") "FlashAttention does not support block-sparse attention.")
......
...@@ -1006,7 +1006,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -1006,7 +1006,8 @@ class FlashInferImpl(AttentionImpl):
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"FLASHINFER backend.")
if use_irope: if use_irope:
logger.warning_once( logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall" "Using irope in FlashInfer is not supported yet, it will fall"
......
...@@ -115,7 +115,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): ...@@ -115,7 +115,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
) -> None: ) -> None:
super(AttentionImpl, self).__init__() super(AttentionImpl, self).__init__()
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"HPU_ATTN backend.")
if use_irope: if use_irope:
logger.warning_once( logger.warning_once(
"Using irope in HPU is not supported yet, it will fall back " "Using irope in HPU is not supported yet, it will fall back "
......
...@@ -501,7 +501,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -501,7 +501,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"ROCM_FLASH backend.")
if use_irope: if use_irope:
logger.warning_once( logger.warning_once(
"Using irope in ROCm Flash Attention is not supported yet, it " "Using irope in ROCm Flash Attention is not supported yet, it "
......
...@@ -394,7 +394,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -394,7 +394,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"XFORMERS backend.")
if blocksparse_params is not None: if blocksparse_params is not None:
raise ValueError( raise ValueError(
"XFormers does not support block-sparse attention.") "XFormers does not support block-sparse attention.")
......
...@@ -160,10 +160,6 @@ class Attention(nn.Module): ...@@ -160,10 +160,6 @@ class Attention(nn.Module):
self.attn_type = attn_type self.attn_type = attn_type
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
if not envs.VLLM_USE_V1:
raise NotImplementedError(
"Cross-layer KV sharing is not supported in V0.")
validate_kv_sharing_target( validate_kv_sharing_target(
prefix, prefix,
kv_sharing_target_layer_name, kv_sharing_target_layer_name,
......
...@@ -59,11 +59,12 @@ class LogitsProcessor(nn.Module): ...@@ -59,11 +59,12 @@ class LogitsProcessor(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None, sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
prune_hidden_states: bool = True,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if self.logits_as_input: if self.logits_as_input:
logits = hidden_states logits = hidden_states
else: else:
if sampling_metadata is not None: if sampling_metadata is not None and prune_hidden_states:
hidden_states = _prune_hidden_states(hidden_states, hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata) sampling_metadata)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
logger = init_logger(__name__)
class SwiGLUActivation(nn.Module):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return x1 * nn.functional.silu(x2)
class SambaYMLP(nn.Module):
"""Gated Linear Unit.
Reference:
Language Modeling with Gated Convolutional Networks.
https://arxiv.org/pdf/1612.08083v3.pdf.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.fc1 = nn.Linear(config.hidden_size,
2 * config.intermediate_size,
bias=False)
self.fc2 = nn.Linear(config.intermediate_size,
config.hidden_size,
bias=False)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
y = self.fc1(hidden_states)
gate, y = y.chunk(2, dim=-1)
y = y * self.activation_fn(gate)
return self.fc2(y)
def get_virtual_engine():
forward_context: ForwardContext = get_forward_context()
return forward_context.virtual_engine
class SambaYAttention(nn.Module):
def __init__(self,
config,
layer_idx: Optional[int] = None,
yoco_cross: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
super().__init__()
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing "
"a `layer_idx` is not recommended and will lead to errors "
"during the forward call if caching is used. Please make "
"sure to provide a `layer_idx` when creating this class.")
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.yoco_cross = yoco_cross
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError("hidden_size must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} and "
f"`num_heads`: {self.num_heads}).")
op_size = self.num_heads * self.head_dim + 2 * (
self.num_key_value_heads * self.head_dim)
self.out_proj = nn.Linear(self.num_heads * self.head_dim,
self.hidden_size,
bias=True)
if yoco_cross:
self.Wqkv = nn.Linear(self.hidden_size,
self.num_heads * self.head_dim,
bias=True)
else:
self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
# disable sliding window for the second half of the model
sliding_window = config.interleaved_sliding_window[layer_idx]
if layer_idx >= config.num_hidden_layers // 2:
assert sliding_window is None, \
"sliding_window must be none for the second decoder"
else:
assert sliding_window is not None, \
"sliding_window must be set for the first decoder"
assert self.num_heads % 2 == 0, 'num_heads should be even'
assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
self.lambda_init = self.lambda_init_fn(layer_idx)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.subln = nn.RMSNorm(2 * self.head_dim,
eps=1e-5,
elementwise_affine=True)
params = {
'differential_flash_attention_config': {
'lambda_init': self.lambda_init,
'lambda_q1': self.lambda_q1,
'lambda_k1': self.lambda_k1,
'lambda_q2': self.lambda_q2,
'lambda_k2': self.lambda_k2,
"subln": self.subln,
}
}
if yoco_cross:
kv_shared_layer_index = config.num_hidden_layers // 2 + 1
kv_sharing_target_layer_name = \
f"model.layers.{kv_shared_layer_index}.self_attn.attn"
else:
kv_sharing_target_layer_name = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.head_dim**-0.5,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
**params)
assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
"DIFFERENTIAL_FLASH_ATTN required"
def lambda_init_fn(self, depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)
def forward(
self,
hidden_states: torch.Tensor,
):
if not self.yoco_cross: # need to generate kv-cache
qkv = self.Wqkv(hidden_states)
q, k, v = qkv.split([
self.hidden_size, self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim
],
dim=-1)
attn_output = self.attn(q, k, v)
else: # re-use the kv cache, full attention
q = self.Wqkv(hidden_states)
attn_output = self.attn(q, None, None)
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
return self.out_proj(attn_output)
class Phi4Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random", # difference
dt_scale=1.0, # difference
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
yoco_cross=False,
yoco_kv=False,
):
factory_kwargs = {"params_dtype": dtype} # difference
super().__init__()
self.yoco_cross = yoco_cross
self.yoco_kv = yoco_kv
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model /
16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.swiGluActivation = SwiGLUActivation()
if self.yoco_cross:
self.in_proj = MergedColumnParallelLinear(self.d_model,
[self.d_inner],
bias=bias,
**factory_kwargs)
self.out_proj = RowParallelLinear(self.d_inner,
self.d_model,
bias=bias,
**factory_kwargs)
return
self.conv1d = ColumnParallelLinear(
input_size=d_conv,
output_size=self.d_inner,
bias=conv_bias,
params_dtype=dtype,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
self.d_model,
[self.d_inner] * 2,
bias=bias,
params_dtype=dtype,
)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
self.d_inner,
self.dt_rank + self.d_state * 2,
bias=False,
params_dtype=dtype,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(
self.dt_rank,
self.d_inner,
bias=True,
skip_bias_add=True,
params_dtype=dtype,
)
# # D "skip" parameter
# self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
self.A = nn.Parameter(
torch.empty(
self.d_inner,
self.d_state,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))
self.out_proj = RowParallelLinear(
self.d_inner,
self.d_model,
bias=bias,
input_is_parallel=True,
params_dtype=dtype,
)
self.activation = "silu"
def forward(self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
yoco_key_values=None) -> torch.Tensor:
if self.yoco_cross:
out = self.in_proj(hidden_states)[0]
out = self.swiGluActivation(yoco_key_values, out)
out = self.out_proj(out)
return out[0], yoco_key_values
# 1. Gated MLP's linear projection
# projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
projected_states = self.in_proj(
hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.dt_rank, self.d_state, self.d_state],
dim=-1,
)
# Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
# z,
None if self.yoco_kv else gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
# z
# gate.transpose(0, 1),
None if self.yoco_kv else gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection
if self.yoco_kv:
# gate = gate.transpose(-1,-2).contiguous()
yoco_key_values = scan_outputs.transpose(-2, -1)
scan_outputs = self.swiGluActivation(scan_outputs, gate)
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states, yoco_key_values
class SambaYDecoderLayer(nn.Module):
def __init__(
self,
config,
layer_idx,
cache_config,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.mlp = SambaYMLP(config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.yoco_mb = False
self.yoco_cross = False
if layer_idx >= config.num_hidden_layers // 2:
self.yoco_mb = True
self.yoco_cross = (layer_idx
>= (config.num_hidden_layers // 2 + 2))
self.use_mamba = config.mb_per_layer > 0 and \
layer_idx % config.mb_per_layer == 0
if self.use_mamba:
factory_kwargs = {"dtype": None}
self.attn = Phi4Mamba(config.hidden_size,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
yoco_kv=self.yoco_mb,
**factory_kwargs)
else:
self.attn = SambaYAttention(config,
layer_idx=layer_idx,
yoco_cross=self.yoco_cross,
cache_config=cache_config,
prefix=f"{prefix}.self_attn")
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
ssm_output: Optional[torch.LongTensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.use_mamba:
assert mamba_cache_params is not None
else:
assert mamba_cache_params is None
residual = hidden_states
hidden_states = self.input_layernorm(
hidden_states.to(dtype=self.input_layernorm.weight.dtype))
if self.use_mamba:
attn_outputs, ssm_output = self.attn(hidden_states,
attn_metadata,
mamba_cache_params,
yoco_key_values=ssm_output)
residual = residual.to(torch.float32)
else:
attn_outputs = self.attn(hidden_states, )
hidden_states = residual + attn_outputs
residual = hidden_states
hidden_states = self.post_attention_layernorm(
hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, ssm_output
class SambaYModel(nn.Module):
def __init__(self,
config,
cache_config=None,
quant_config=None,
lora_config=None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
# Pipeline parallel is not supported since the second half of
# the layers share the kv cache.
if get_pp_group().world_size != 1:
raise ValueError("Pipeline Parallel not supported")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: SambaYDecoderLayer(config,
int(prefix.split('.')[-1]),
cache_config,
prefix=prefix),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
mamba_state_idx = 0
ssm_output = None
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if i == self.config.num_hidden_layers // 2 + 2:
# profile run
kv_cache_idx = self.config.num_hidden_layers // 2 + 1
cache_layer = self.layers[kv_cache_idx]
kv_cache = cache_layer.attn.attn.kv_cache
if kv_cache[0].numel() == 0:
break
# Starting from this layer, we do not need to calculate
# the kv cache since we reuse the kv cache from last layer.
# If in prefill phase, we can <s>prune></s> truncate
# the hidden state to save computation cost.
if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
selected_token_indices = torch.cumsum(
attn_metadata.seq_lens_tensor, dim=0) - 1
hidden_states = hidden_states.index_select(
0, selected_token_indices)
ssm_output = ssm_output.index_select(
0, selected_token_indices)
if layer.use_mamba:
if i < self.config.num_hidden_layers // 2 or \
not layer.yoco_cross:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx)
mamba_state_idx += 1
else:
mamba_cache = mamba_cache_params.at_layer_idx(
mamba_state_idx - 1)
hidden_states, ssm_output = layer(hidden_states,
positions,
attn_metadata,
mamba_cache,
ssm_output=ssm_output)
else:
hidden_states, ssm_output = layer(
hidden_states,
positions,
attn_metadata,
None, # mamba_cache_params
ssm_output=ssm_output)
hidden_states = self.final_layernorm(
hidden_states.to(dtype=self.final_layernorm.weight.dtype))
return hidden_states
class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
quant_config = vllm_config.quant_config
scheduler_config = vllm_config.scheduler_config
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
# Prefix caching and chunked prefill is not supported for this model.
assert not cache_config.enable_prefix_caching, \
"Phi4flash currently does not support prefix caching"
assert not scheduler_config.chunked_prefill_enabled, \
"Phi4Flash currently does not support prefix caching"
super().__init__()
self.config = config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
self.model = SambaYModel(config,
cache_config=cache_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size),
quant_config=quant_config,
)
self.embedding_bias = None
# Used to track and store by the Mamba cache between steps.
self.mamba_cache: Optional[MambaCacheManager] = None
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logits_as_input=False)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers \
// 2 // self.config.mb_per_layer + 1
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
attn_metadata = get_forward_context().attn_metadata
# input_ids and hidden_states isn't a one-to-one mapping in prefill
# stage due to YOCO optimization.
hidden_states = self.model(input_ids, positions, attn_metadata,
mamba_cache_params, intermediate_tensors,
inputs_embeds)
return hidden_states
def _get_mamba_cache_shape(
self
) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
mamba_expand = self.config.mamba_expand # 2
mamba_d_conv = self.config.mamba_d_conv # 4
mamba_d_state = self.config.mamba_d_state # 16
conv_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_conv - 1,
)
temporal_state_shape = (
mamba_expand * hidden_size // world_size,
mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# If the shape is the same, it means that we have already
# prune hidden states manually.
prune_hidden_states = hidden_states.size(
0) != sampling_metadata.selected_token_indices.size(0)
processed_logits = self.logits_processor(
self.lm_head,
hidden_states,
sampling_metadata,
self.embedding_bias,
prune_hidden_states=prune_hidden_states)
return processed_logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
self,
weights: Iterable[tuple[str, torch.Tensor]],
):
weights = {name: weight for name, weight in weights}
adjusted_weights = {}
for name, weight in weights.items():
if "A_log" in name:
name = name.replace("A_log", "A")
weight = -torch.exp(weight.float())
if "inner_cross_attn." in name:
name = name.replace("inner_cross_attn.", "")
adjusted_weights[name] = weight
adjusted_weights["lm_head.weight"] = weights[
"model.embed_tokens.weight"]
loaded_params: set[str] = set()
for name, param in self.named_parameters():
weight = adjusted_weights.get(name)
if weight is not None and weight.shape != param.shape:
logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
param.shape)
loaded_params.add(name)
missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
strict=False)
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
return loaded_params
...@@ -110,6 +110,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -110,6 +110,7 @@ _TEXT_GENERATION_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
......
...@@ -316,6 +316,10 @@ class CudaPlatformBase(Platform): ...@@ -316,6 +316,10 @@ class CudaPlatformBase(Platform):
logger.info("Using DualChunkFlashAttention backend.") logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn." return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend") "DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
logger.info("Using DifferentialFlashAttention backend.")
return ("vllm.attention.backends.differential_flash_attn."
"DifferentialFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN: elif selected_backend == _Backend.FLASH_ATTN:
pass pass
elif selected_backend: elif selected_backend:
......
...@@ -60,6 +60,7 @@ class _Backend(enum.Enum): ...@@ -60,6 +60,7 @@ class _Backend(enum.Enum):
IPEX = enum.auto() IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto()
DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto() NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto()
......
...@@ -2888,8 +2888,9 @@ def get_mp_context(): ...@@ -2888,8 +2888,9 @@ def get_mp_context():
def bind_kv_cache( def bind_kv_cache(
ctx: dict[str, Any], ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
shared_kv_cache_layers: Optional[dict[str, str]] = None
) -> None: ) -> None:
# Bind the kv_cache tensor to Attention modules, similar to # Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
...@@ -2901,12 +2902,17 @@ def bind_kv_cache( ...@@ -2901,12 +2902,17 @@ def bind_kv_cache(
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn # attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache # and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor # tensor
# 5. Some models have attention layers that share kv cache with previous
# layers, this is specified through shared_kv_cache_layers
if shared_kv_cache_layers is None:
shared_kv_cache_layers = {}
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
layer_need_kv_cache = [ layer_need_kv_cache = [
layer_name for layer_name in ctx layer_name for layer_name in ctx
if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \
and ctx[layer_name].kv_sharing_target_layer_name is None
] ]
layer_index_sorted = sorted( layer_index_sorted = sorted(
set( set(
...@@ -2919,6 +2925,12 @@ def bind_kv_cache( ...@@ -2919,6 +2925,12 @@ def bind_kv_cache(
assert len(forward_ctx.kv_cache) == len(kv_cache) assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache): for ve, ve_kv_cache in enumerate(kv_cache):
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
if shared_kv_cache_layers is not None:
for layer_name, target_layer_name in shared_kv_cache_layers.items():
assert extract_layer_index(target_layer_name) < \
extract_layer_index(layer_name), \
"v0 doesn't support interleaving kv sharing"
ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment