Unverified Commit 3f52fa5a authored by yt0428's avatar yt0428 Committed by GitHub
Browse files

[Model] Add support for openPangu moe model (#28775)


Signed-off-by: yuantao <2422264527@qq.com>
Signed-off-by: yt0428 <51468697+yt0428@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 71575961
...@@ -433,6 +433,7 @@ th { ...@@ -433,6 +433,7 @@ th {
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
| `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | | | `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
| `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ | | `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ |
| `PanguProMoEV2ForCausalLM` |openpangu-pro-moe-v2 | | ✅︎ | ✅︎ |
| `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ | | `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
......
...@@ -402,6 +402,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -402,6 +402,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PanguEmbeddedForCausalLM": _HfExamplesInfo( "PanguEmbeddedForCausalLM": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True "FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True
), ),
"PanguProMoEV2ForCausalLM": _HfExamplesInfo(
"",
trust_remote_code=True,
is_available_online=False,
),
"PanguUltraMoEForCausalLM": _HfExamplesInfo( "PanguUltraMoEForCausalLM": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1", "FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
trust_remote_code=True, trust_remote_code=True,
......
...@@ -42,6 +42,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): ...@@ -42,6 +42,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
""" """
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
FLASH_ATTN_DIFFKV = (
"vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
)
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
......
...@@ -170,6 +170,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -170,6 +170,7 @@ class Attention(nn.Module, AttentionLayerBase):
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None, kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None, attn_backend: type[AttentionBackend] | None = None,
head_size_v: int | None = None,
**extra_impl_args, **extra_impl_args,
) -> None: ) -> None:
""" """
...@@ -217,6 +218,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -217,6 +218,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.head_size_v = self.head_size if head_size_v is None else head_size_v
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None self.has_sink = extra_impl_args.get("sinks") is not None
...@@ -274,8 +276,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -274,8 +276,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name, kv_sharing_target_layer_name,
**extra_impl_args, **extra_impl_args,
) )
backend_name = self.attn_backend.get_name() self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.backend = AttentionBackendEnum.__members__.get(backend_name)
self.dtype = dtype self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
...@@ -355,6 +356,10 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -355,6 +356,10 @@ class Attention(nn.Module, AttentionLayerBase):
query, _ = self.query_quant(query, self._q_scale) query, _ = self.query_quant(query, self._q_scale)
if self.use_output: if self.use_output:
if output_shape is None:
output_shape = torch.Size(
(*query.shape[:-1], self.num_heads * self.head_size_v)
)
output_shape = output_shape if output_shape is not None else query.shape output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device) output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1] hidden_size = output_shape[-1]
...@@ -362,11 +367,11 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -362,11 +367,11 @@ class Attention(nn.Module, AttentionLayerBase):
# NOTE(woosuk): We do this outside the custom op to minimize the # NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions. # CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size) output = output.view(-1, self.num_heads, self.head_size_v)
if key is not None: if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None: if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size_v)
if self.use_direct_call: if self.use_direct_call:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
...@@ -452,6 +457,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -452,6 +457,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size=block_size, block_size=block_size,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
head_size_v=self.head_size_v,
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
) )
...@@ -794,6 +800,7 @@ def unified_attention_with_output( ...@@ -794,6 +800,7 @@ def unified_attention_with_output(
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> None: ) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name) attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward( self.impl.forward(
self, self,
query, query,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
SinkFullAttentionSpec,
)
logger = init_logger(__name__)
@functools.lru_cache
def create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]:
    """Derive a backend whose metadata builder accounts for static sink tokens.

    The returned backend is a subclass of ``underlying_attn_backend`` (created
    via ``subclass_attention_backend`` with the ``"StaticSink_"`` name prefix)
    whose builder rewrites the common attention metadata before delegating to
    the underlying builder. Results are memoized per
    ``(underlying_attn_backend, sink_len)`` by ``functools.lru_cache``.

    Args:
        underlying_attn_backend: Backend class to wrap.
        sink_len: Number of sink tokens placed in front of every sequence.

    Returns:
        The derived attention backend class.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class StaticSinkAttentionBuilder(base_builder_cls):  # type: ignore
        """Metadata builder that prepends the sink blocks to each block table."""

        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            cache_block_size = vllm_config.cache_config.block_size

            self.sink_len = sink_len
            self.block_size = cache_block_size
            # Floor division: only whole blocks are counted as sink blocks.
            self.num_sink_blocks = sink_len // cache_block_size
            self.max_num_blocks = cdiv(
                vllm_config.model_config.max_model_len, cache_block_size
            )

            # Persistent block table: one row per schedulable sequence, wide
            # enough for the sink blocks plus a full-length sequence.
            table = torch.zeros(
                (
                    vllm_config.scheduler_config.max_num_seqs,
                    self.max_num_blocks + self.num_sink_blocks,
                ),
                device=device,
                dtype=torch.int32,
            )
            # The leading columns of every row hold block ids
            # 1..num_sink_blocks.
            table[:, : self.num_sink_blocks] = torch.arange(
                1,
                self.num_sink_blocks + 1,
                device=device,
                dtype=torch.int32,
            )
            self.block_table_with_sink = table

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            """Shift the metadata by the sink length, then build as usual.

            NOTE: ``common_attn_metadata`` is modified in place (``seq_lens``,
            ``max_seq_len`` and ``block_table_tensor``).
            """
            meta = common_attn_metadata

            # Every sequence grows by the sink tokens ...
            seq_lens = meta.seq_lens
            seq_lens[:] = seq_lens + self.sink_len
            # ... except entries that were zero before the shift, which are
            # reset to zero.
            seq_lens[seq_lens == self.sink_len] = 0
            meta.max_seq_len = meta.max_seq_len + self.sink_len

            # Copy the per-request block ids in after the sink columns.
            used_blocks = cdiv(meta.max_seq_len, self.block_size)
            num_reqs = meta.num_reqs
            first_col = self.num_sink_blocks
            self.block_table_with_sink[
                :num_reqs, first_col : first_col + used_blocks
            ] = meta.block_table_tensor[:, :used_blocks]
            meta.block_table_tensor = self.block_table_with_sink[:num_reqs]

            return super().build(common_prefix_len, meta, fast_build)

    return subclass_attention_backend(
        name_prefix="StaticSink_",
        attention_backend_cls=underlying_attn_backend,
        builder_cls=StaticSinkAttentionBuilder,
    )
@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
    """Attention layer with a fixed set of sink key/value tokens.

    The sink key/value tensors are supplied through ``update_sink_kv`` and are
    written into this layer's KV cache once, on the first forward pass that
    sees a non-empty cache. The attention backend is wrapped with
    ``create_static_sink_attention_backend`` so that the metadata builder
    accounts for the ``sink_len`` extra tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ):
        """Build the layer on top of a sink-aware attention backend.

        Args:
            num_heads: Number of query heads handled by this layer.
            head_size: Per-head query/key dimension.
            scale: Softmax scaling factor forwarded to ``Attention``.
            sink_len: Number of sink tokens.
            attn_backend: Backend to wrap; when ``None`` one is selected with
                ``get_attn_backend``.
            cache_config: Source of the KV cache dtype and block size; when
                ``None`` the defaults ``"auto"`` and ``16`` are used.
            **kwargs: Forwarded unchanged to ``Attention.__init__``.
        """
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        # Always wrap, including a caller-supplied backend, so the builder
        # prepends the sink blocks.
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,
            sink_len=sink_len,
        )
        Attention.__init__(
            self=self,
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )
        CustomOp.__init__(self)

        self.sink_len = sink_len
        self.block_size = block_size
        # Flipped to True by populate_sink_kv after the one-time cache write.
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        """Store the sink key/value tensors to be written into the KV cache."""
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Populate the sink slots if still pending, then run attention.

        Raises:
            AssertionError: If ``update_sink_kv`` has not provided both sink
                tensors yet.
        """
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # Routed through a registered custom op rather than called
            # directly; the op itself skips empty caches.
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

        return super().forward(query, key, value, output_shape)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Same computation as ``forward_native``."""
        return self.forward_native(query, key, value, output_shape)

    def forward(self, *args, **kwargs):
        """Dispatch to ``self._forward_method``."""
        return self._forward_method(*args, **kwargs)

    def populate_sink_kv(self, self_kv_cache):
        """Write the sink key/value tensors into ``self_kv_cache`` once.

        The sink tokens go to slots ``block_size .. block_size + sink_len - 1``.
        """
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            # Build the mapping on the cache's own device instead of assuming
            # the current CUDA device: the kernel needs both tensors on the
            # same device, and this also works on non-CUDA platforms.
            device=self_kv_cache.device,
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the ``SinkFullAttentionSpec`` describing this layer's cache."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )
def maybe_populate_sink(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Write a layer's sink KV into its cache unless already done.

    The layer is looked up by ``layer_name`` in the current forward context's
    ``no_compile_layers``. Nothing happens when the layer reports
    ``sink_populated`` or when ``self_kv_cache`` has no elements.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    needs_population = not layer.sink_populated and self_kv_cache.numel() != 0
    if needs_population:
        layer.populate_sink_kv(self_kv_cache)
def maybe_populate_sink_fake(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op stand-in for ``maybe_populate_sink`` with the same signature."""
    return None
# Register `maybe_populate_sink` as a custom op;
# StaticSinkAttention.forward_native invokes it as
# `torch.ops.vllm.maybe_populate_sink`. The op is declared as mutating its
# `self_kv_cache` argument, and `maybe_populate_sink_fake` (a no-op) is
# supplied as its fake implementation.
direct_register_custom_op(
    op_name="maybe_populate_sink",
    op_func=maybe_populate_sink,
    mutates_args=["self_kv_cache"],
    fake_impl=maybe_populate_sink_fake,
)
...@@ -182,3 +182,174 @@ def triton_reshape_and_cache_flash( ...@@ -182,3 +182,174 @@ def triton_reshape_and_cache_flash(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
# Triton kernel: scatter new key/value vectors into a combined KV cache in
# which each head stores its key (head_size_k elements) immediately followed
# by its value (head_size_v elements).
#
# One program instance handles one (token, head) pair: program_id(0) selects
# the token and program_id(1) is used as the head index. Tokens whose slot is
# negative are skipped.
#
# The source offsets use `tile_i * head_size_k` / `tile_i * head_size_v`, and
# the target offset uses `tile_i * (head_size_k + head_size_v)`, so the kernel
# relies on heads being densely packed along the last two dimensions of
# key/value/kv_cache.
# NOTE(review): `num_heads` is accepted but not read inside the kernel body.
@triton.jit
def reshape_and_cache_kernel_flash_diffkv(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size_v]
    kv_cache_ptr,  # [num_blocks, block_size, num_heads, head_size + head_size_v]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size_k: tl.constexpr,
    head_size_v: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):
    token_idx = tl.program_id(axis=0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    # Second grid axis: used below as the head index.
    tile_i = tl.program_id(axis=1)
    tile_offs = tl.arange(0, TILE_SIZE)

    # Split the flat slot index into (cache block, position inside the block).
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size

    src_key_idx = token_idx * key_stride + tile_i * head_size_k
    src_value_idx = token_idx * value_stride + tile_i * head_size_v

    # Start of this head's [key | value] region in the cache.
    tgt_idx = (
        block_idx * block_stride
        + block_offset * page_stride
        + tile_i * (head_size_k + head_size_v)
    )

    # [TILE_SIZE]; lanes at or beyond head_size_k are masked off.
    key_load = tl.load(key_ptr + src_key_idx + tile_offs, mask=tile_offs < head_size_k)
    if FP8_KV_CACHE:
        # tl.store will do the correct implicit cast to fp8,
        # based on the kv_cache_ptr.dtype.element_ty
        key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
    else:
        key_tile = key_load

    # [TILE_SIZE]; lanes at or beyond head_size_v are masked off.
    value_load = tl.load(
        value_ptr + src_value_idx + tile_offs, mask=tile_offs < head_size_v
    )
    if FP8_KV_CACHE:
        if value_load.dtype.is_fp8():
            value_tile = value_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the kv_cache_ptr.dtype.element_ty
            value_tile = value_load / tl.load(v_scale)
    else:
        value_tile = value_load

    # Key occupies the first head_size_k elements of the head's region ...
    tl.store(
        kv_cache_ptr + tgt_idx + tile_offs,
        key_tile,
        mask=tile_offs < head_size_k,
    )
    # ... and the value follows directly after it.
    tl.store(
        kv_cache_ptr + tgt_idx + head_size_k + tile_offs,
        value_tile,
        mask=tile_offs < head_size_v,
    )
    return
def triton_reshape_and_cache_flash_diffkv(
    key: torch.Tensor,  # [num_tokens, num_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_heads, head_size_v]
    # [num_blocks, block_size, num_heads, head_size + head_size_v]
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,  # [num_tokens]
    kv_cache_dtype: str,  # "auto", "fp8"
    k_scale: torch.Tensor,  # float32
    v_scale: torch.Tensor,  # float32
):
    """Scatter keys and values into a combined KV cache via Triton.

    Launches ``reshape_and_cache_kernel_flash_diffkv`` on a
    ``(num_tokens, num_heads)`` grid. Keys and values may have different head
    sizes; they are stored side by side in the last dimension of ``kv_cache``.

    Args:
        key: New key vectors.
        value: New value vectors.
        kv_cache: Destination cache, updated in place.
        slot_mapping: Destination slot per token; the kernel skips negative
            slots.
        kv_cache_dtype: ``"auto"`` or a string starting with ``"fp8"``.
        k_scale: Key scale, applied by the kernel only for fp8 caches.
        v_scale: Value scale, applied by the kernel only for fp8 caches.

    Raises:
        AssertionError: On an unsupported ``kv_cache_dtype`` or when the
            cache tensor would be written as ``uint8``.
    """
    num_heads = key.shape[1]
    head_size_k = key.shape[2]
    head_size_v = value.shape[2]
    block_size = kv_cache.shape[1]

    k_stride = key.stride()[0]
    v_stride = value.stride()[0]
    block_stride = kv_cache.stride()[0]
    page_stride = kv_cache.stride()[1]

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    )
    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    kv_cache_torch_dtype = (
        current_platform.fp8_dtype() if FP8_KV_CACHE else kv_cache.dtype
    )

    if kv_cache.dtype != kv_cache_torch_dtype and FP8_KV_CACHE:
        # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
        kv_cache = kv_cache.view(kv_cache_torch_dtype)
    # Check the tensor's dtype. The previous check compared the *string*
    # kv_cache_dtype with torch.uint8 and therefore could never fail.
    assert kv_cache.dtype != torch.uint8, (
        "explicit fp8 cast and store to "
        "uint8 is not supported by triton reshape_and_cache_flash_diffkv"
    )

    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.uint8,
        torch.float8_e4m3fnuz,
    ], (
        "unsupported dtype of KV cache tensor, got "
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
    )

    # heuristics instead of autotuning
    # One tile must cover the larger of the two head sizes.
    TILE_SIZE = max(head_size_k, head_size_v)
    TILE_SIZE = triton.next_power_of_2(TILE_SIZE)
    if current_platform.is_rocm() or current_platform.is_xpu():
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (
        slot_mapping.shape[0],
        num_heads,
    )

    reshape_and_cache_kernel_flash_diffkv[grid](
        key_ptr=key,
        value_ptr=value,
        kv_cache_ptr=kv_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=k_stride,
        value_stride=v_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size_k=head_size_k,
        head_size_v=head_size_v,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )
...@@ -29,13 +29,14 @@ import torch ...@@ -29,13 +29,14 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention, AttentionType
from vllm.attention.layer import Attention from vllm.attention.layers.static_sink_attention import StaticSinkAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_ep_group, get_ep_group,
get_pp_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tp_group, get_tp_group,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
...@@ -77,8 +78,11 @@ from vllm.model_executor.models.utils import ( ...@@ -77,8 +78,11 @@ from vllm.model_executor.models.utils import (
maybe_prefix, maybe_prefix,
sequence_parallel_chunk, sequence_parallel_chunk,
) )
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import set_default_rope_theta from vllm.transformers_utils.config import set_default_rope_theta
from vllm.v1.attention.backends.flash_attn_diffkv import FlashAttentionDiffKVBackend
def check_ffn_act_fn(act_fn: str): def check_ffn_act_fn(act_fn: str):
...@@ -155,6 +159,14 @@ class OpenPanguMoE(nn.Module): ...@@ -155,6 +159,14 @@ class OpenPanguMoE(nn.Module):
quant_config=None, quant_config=None,
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
if (
hasattr(config, "router_enable_expert_bias")
and config.router_enable_expert_bias
):
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(self.n_routed_experts, dtype=torch.float32)
)
else:
self.gate.e_score_correction_bias = None self.gate.e_score_correction_bias = None
# Load balancing settings. # Load balancing settings.
...@@ -530,6 +542,264 @@ class OpenPanguEmbeddedAttention(nn.Module): ...@@ -530,6 +542,264 @@ class OpenPanguEmbeddedAttention(nn.Module):
) )
class OpenPanguSinkAttention(nn.Module):
    """openPangu attention block with learned ("param") sink tokens.

    Query/key heads have size ``qk_rope_dim + qk_nope_dim`` while value heads
    have size ``v_channels``, so the layer uses ``StaticSinkAttention`` on top
    of ``FlashAttentionDiffKVBackend`` with a separate ``head_size_v``. The
    sink keys (and optionally values) are parameters of this module; they are
    handed to the attention layer by ``post_weight_load``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Create the projections, rotary embedding, attention and sinks.

        Args:
            config: Model config; read for ``qk_nope_dim``, ``qk_rope_dim``,
                ``v_channels``, ``rms_norm_eps``, ``torch_dtype``, the
                ``param_sink_*`` options and ``interleaved_sliding_window``.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of KV heads across all TP ranks.
            rope_parameters: Forwarded to ``_init_rotary_emb``.
            max_position_embeddings: Maximum position for the rotary embedding.
            quant_config: Quantization config for the linear/attention layers.
            bias: Whether the fused QKV projection has a bias.
            bias_o_proj: Whether the output projection has a bias.
            cache_config: KV cache config forwarded to the attention layer.
            prefix: Parameter-name prefix; the layer index is extracted from it.
            attn_type: Attention type forwarded to the attention layer.

        Raises:
            ValueError: If the head counts cannot be split over the TP group,
                if required config fields are missing, or if
                ``interleaved_sliding_window`` has an unsupported type.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.total_num_heads = num_heads
        if self.total_num_heads % self.tp_size != 0:
            raise ValueError(
                f"total_num_heads {self.total_num_heads} "
                f"is not divisible by tp_size {self.tp_size}."
            )
        self.num_heads = self.total_num_heads // self.tp_size

        self.total_num_kv_heads = num_kv_heads
        if (
            self.total_num_kv_heads > self.tp_size
            and self.total_num_kv_heads % self.tp_size != 0
        ):
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel ranks.
            raise ValueError(
                "Number of KV heads is greater than TP size, "
                f"but total_num_kv_heads {self.total_num_kv_heads} "
                f"is not divisible by tp_size {self.tp_size}."
            )
        elif self.total_num_kv_heads < self.tp_size:
            # TODO: Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel ranks.
            raise ValueError(
                f"Number of KV heads {self.total_num_kv_heads} is less than "
                f"TP size {self.tp_size}, KV heads replication is not support yet."
            )
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)

        self.qk_nope_dim = getattr(config, "qk_nope_dim", None)
        self.qk_rope_dim = getattr(config, "qk_rope_dim", None)
        self.v_channels = getattr(config, "v_channels", None)
        # Fail with a clear message instead of a TypeError on `None + None`.
        missing_fields = [
            field
            for field in ("qk_nope_dim", "qk_rope_dim", "v_channels")
            if getattr(config, field, None) is None
        ]
        if missing_fields:
            raise ValueError(
                "config is missing required field(s) for sink attention: "
                f"{', '.join(missing_fields)}."
            )

        self.head_dim = self.qk_rope_dim + self.qk_nope_dim
        self.q_size = self.num_heads * self.head_dim
        self.k_size = self.num_kv_heads * self.head_dim
        self.v_size = self.num_kv_heads * self.v_channels
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.param_sink_number = getattr(config, "param_sink_number", 0)
        self.param_sink_with_value = getattr(config, "param_sink_with_value", False)
        self.param_sink_scalar = getattr(config, "param_sink_scalar", None)
        # NOTE(review): the attribute is named `..._of_head_num` but reads the
        # config key `..._of_head_dim` — confirm which spelling is intended.
        self.param_sink_of_head_num = getattr(config, "param_sink_of_head_dim", False)

        # Output sizes are the full (all-rank) sizes of the fused projection.
        self.qkv_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[
                self.q_size * self.tp_size,
                self.k_size * self.tp_size,
                self.v_size * self.tp_size,
            ],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.v_channels,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.k_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self._init_rotary_emb(
            config, rope_parameters=rope_parameters, quant_config=quant_config
        )

        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                # A list is cycled over the layers.
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} "
                    "for interleaved_sliding_window is not supported."
                )
        else:
            sliding_window = None

        # This sets the value head size on the backend *class* before the
        # attention layer is built, so it is shared by every user of the class.
        FlashAttentionDiffKVBackend.set_head_size_v(self.v_channels)
        self.attn = StaticSinkAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            sink_len=self.param_sink_number,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
            attn_backend=FlashAttentionDiffKVBackend,
            head_size_v=self.v_channels,
        )

        if self.param_sink_number > 0:
            self.param_sink_key = torch.nn.Parameter(
                torch.empty(
                    (
                        self.param_sink_number,
                        self.num_kv_heads,
                        self.head_dim,
                    ),
                    device=current_platform.current_device(),
                    dtype=config.torch_dtype,
                )
            )
            # output_dim=1 is the KV-head dimension, which weight_loader
            # shards across TP ranks.
            set_weight_attrs(
                self.param_sink_key,
                {
                    "output_dim": 1,
                    "weight_loader": self.weight_loader,
                },
            )

            if self.param_sink_with_value:
                self.param_sink_value = torch.nn.Parameter(
                    torch.empty(
                        (
                            self.param_sink_number,
                            self.num_kv_heads,
                            self.v_channels,
                        ),
                        device=current_platform.current_device(),
                        dtype=config.torch_dtype,
                    )
                )
                set_weight_attrs(
                    self.param_sink_value,
                    {
                        "output_dim": 1,
                        "weight_loader": self.weight_loader,
                    },
                )
            else:
                # No learned sink values: use an all-zero tensor (a plain
                # attribute, not a parameter).
                self.param_sink_value = torch.zeros(
                    (
                        self.param_sink_number,
                        self.num_kv_heads,
                        self.v_channels,
                    ),
                    device=current_platform.current_device(),
                    dtype=config.torch_dtype,
                )

            # To enable a dummy run without loaded weights
            self.post_weight_load()

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``, sharding along ``output_dim``.

        When ``param`` carries an ``output_dim`` attribute and the weight is
        not already sharded, only this TP rank's slice of ``loaded_weight`` is
        copied.

        Raises:
            AssertionError: If the (sliced) weight shape differs from the
                parameter shape.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, nn.UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # Take this rank's contiguous slice along the sharded dimension.
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply QKV projection, key norm, RoPE, sink attention and o_proj."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
        # The key norm is applied per head, before the rotary embedding.
        k = self.k_layernorm(k.view(-1, self.num_kv_heads, self.head_dim))
        q, k = self.rotary_emb(positions, q, k)

        q = q.view(-1, self.q_size)
        k = k.view(-1, self.k_size)

        attn_output = self.attn(
            q,
            k,
            v,
            # Output heads have the value head size, not the query head size:
            # (num_tokens, num_heads * v_channels).
            output_shape=torch.Size(
                [q.shape[0], q.shape[1] // self.head_dim * self.v_channels]
            ),
        )
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(
        self,
        config: PretrainedConfig,
        rope_parameters: dict[str, Any] | None,
        quant_config: QuantizationConfig | None,
    ) -> None:
        """Create ``self.rotary_emb`` covering ``qk_rope_dim`` of each head.

        NOTE(review): the ``rope_parameters`` argument is discarded and
        replaced by a dict containing only ``partial_rotary_factor``, so any
        rope theta/scaling passed by the caller is not forwarded to
        ``get_rope`` — confirm this is intended.
        """
        is_neox_style = False
        rope_parameters = {"partial_rotary_factor": self.qk_rope_dim / self.head_dim}

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=is_neox_style,
        )

    def post_weight_load(self) -> None:
        """Hand the (normalized) sink key and sink value to the attention layer.

        Only ``k_layernorm`` is applied to the sink key here; no rotary
        embedding is applied. Does nothing when the layer has no sink tokens,
        because the sink tensors are then never created.
        """
        if self.param_sink_number <= 0:
            return
        if getattr(self, "k_layernorm", None) is not None:
            param_sink_key = self.k_layernorm(self.param_sink_key)
        else:
            param_sink_key = self.param_sink_key

        self.attn.update_sink_kv(param_sink_key, self.param_sink_value)
class OpenPanguDecoderLayer(nn.Module): class OpenPanguDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
...@@ -557,6 +827,9 @@ class OpenPanguDecoderLayer(nn.Module): ...@@ -557,6 +827,9 @@ class OpenPanguDecoderLayer(nn.Module):
and hasattr(config, "v_head_dim") and hasattr(config, "v_head_dim")
and hasattr(config, "kv_lora_rank") and hasattr(config, "kv_lora_rank")
) )
self.use_sink_attention = (
hasattr(config, "param_sink_number") and config.param_sink_number > 0
)
if self.use_mla: if self.use_mla:
self.self_attn = OpenPanguMLAAttention( self.self_attn = OpenPanguMLAAttention(
config=config, config=config,
...@@ -574,6 +847,42 @@ class OpenPanguDecoderLayer(nn.Module): ...@@ -574,6 +847,42 @@ class OpenPanguDecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
elif self.use_sink_attention:
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
)
bias_o_proj = attention_bias
if hasattr(config, "qkv_bias"):
attention_bias = config.qkv_bias
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
raise ValueError(
f"is_causal={config.is_causal} is not support "
"for attention with sink"
)
rope_parameters = getattr(config, "rope_scaling", None)
if rope_parameters is None:
rope_parameters = {
"rope_type": "default",
"rope_theta": config.rope_theta,
}
self.self_attn = OpenPanguSinkAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
rope_parameters=rope_parameters,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
else: else:
attention_bias = getattr(config, "attention_bias", False) or getattr( attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False config, "bias", False
...@@ -903,6 +1212,10 @@ class OpenPanguModel(nn.Module): ...@@ -903,6 +1212,10 @@ class OpenPanguModel(nn.Module):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
name = maybe_remap_kv_scale_name(name, params_dict) name = maybe_remap_kv_scale_name(name, params_dict)
if name.endswith("e_score_correction_bias"):
name = name.replace(
"e_score_correction_bias", "gate.e_score_correction_bias"
)
if name is None: if name is None:
continue continue
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
...@@ -912,8 +1225,17 @@ class OpenPanguModel(nn.Module): ...@@ -912,8 +1225,17 @@ class OpenPanguModel(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
self.post_weight_load()
return loaded_params return loaded_params
def post_weight_load(self) -> None:
    """Invoke the ``post_weight_load`` hook on every submodule that has one.

    The module itself is excluded from the walk so that this method never
    calls itself.
    """
    for _, submodule in self.named_modules():
        if submodule is not self and hasattr(submodule, "post_weight_load"):
            submodule.post_weight_load()
class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA): class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = { packed_modules_mapping = {
...@@ -1047,3 +1369,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel): ...@@ -1047,3 +1369,7 @@ class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel):
class PanguUltraMoEForCausalLM(OpenPanguMoEModel):
    """Architecture alias that adds nothing on top of ``OpenPanguMoEModel``."""
class PanguProMoEV2ForCausalLM(OpenPanguMoEModel):
    """Architecture alias that adds nothing on top of ``OpenPanguMoEModel``."""
...@@ -164,6 +164,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -164,6 +164,7 @@ _TEXT_GENERATION_MODELS = {
"OrionForCausalLM": ("orion", "OrionForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"),
"OuroForCausalLM": ("ouro", "OuroForCausalLM"), "OuroForCausalLM": ("ouro", "OuroForCausalLM"),
"PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"), "PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
"PanguProMoEV2ForCausalLM": ("openpangu", "PanguProMoEV2ForCausalLM"),
"PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"), "PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from .flash_attn import (
FlashAttentionBackend,
FlashAttentionImpl,
FlashAttentionMetadata,
cascade_attention,
)
logger = init_logger(__name__)
class FlashAttentionDiffKVBackend(FlashAttentionBackend):
    """FlashAttention backend whose KV cache packs K and V of different sizes.

    The last dimension of the cache shape is ``head_size + head_size_v``.
    """

    # Value head size used when sizing the cache; class-wide, 128 unless
    # overridden through `set_head_size_v`.
    head_size_v: int = 128

    @classmethod
    def set_head_size_v(cls, head_size_v: int) -> None:
        """Override the class-level value head size."""
        cls.head_size_v = head_size_v

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "FLASH_ATTN_DIFFKV"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class paired with this backend."""
        return FlashAttentionDiffKVImpl

    # The signature of get_kv_cache_shape is intentionally left untouched;
    # head_size_v is read from the class attribute instead of a parameter.
    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the cache shape with K and V concatenated on the last axis.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        packed_head_size = head_size + FlashAttentionDiffKVBackend.head_size_v
        return (num_blocks, block_size, num_kv_heads, packed_head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the permutation from `get_kv_cache_shape` to memory layout.

        Raises:
            ValueError: if the configured cache layout is neither "NHD"
                nor "HND".
        """
        layout = get_kv_cache_layout()
        if layout == "NHD":
            if include_num_layers_dimension:
                # (num_blocks, num_layers, block_size,
                # num_kv_heads, head_size + head_size_v)
                return (1, 0, 2, 3, 4)
            return (0, 1, 2, 3)
        if layout == "HND":
            if include_num_layers_dimension:
                # (num_blocks, num_kv_heads, num_layers,
                # block_size, head_size + head_size_v)
                return (1, 3, 0, 2, 4)
            return (0, 2, 1, 3)
        raise ValueError(f"Unknown cache layout format {layout}.")
class FlashAttentionDiffKVImpl(FlashAttentionImpl):
    """FlashAttention implementation for K and V with different head sizes.

    Keys and values share one KV-cache tensor and are concatenated along its
    last dimension: the first ``self.head_size`` entries are the key and the
    remaining entries are the value.
    """

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: Attention layer; its ``_q_scale``, ``_k_scale`` and
                ``_v_scale`` attributes are read here.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size_v]
            kv_cache: shape =
                [num_blocks, block_size, num_kv_heads, head_size + head_size_v]
            attn_metadata: Metadata for attention. ``None`` means a
                profiling run, in which case ``output`` is zero-filled.
            output: Pre-allocated result tensor, written in place. Required.
            output_scale: Must be None (fused output quantization is not
                supported).
            output_block_scale: Must be None (same reason).
        Returns:
            shape = [num_tokens, num_heads * head_size_v]
        Raises:
            NotImplementedError: if ``output_scale`` or
                ``output_block_scale`` is given.
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        # NOTE(review): the message below names FlashAttentionImpl rather
        # than this subclass.
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # For decoder and cross-attention, use KV cache as before
        # Different head_size for K and V
        # Both slices index into the same kv_cache tensor; splitting on the
        # last dimension separates the key part from the value part.
        key_cache = kv_cache[..., : self.head_size]
        value_cache = kv_cache[..., self.head_size :]

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.

            # kv_cache update for different head_size K and V
            # The whole packed kv_cache is passed rather than the two slices.
            triton_reshape_and_cache_flash_diffkv(
                key,
                value,
                kv_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            # Reinterpret the cache slices with the FP8 dtype flash-attn uses.
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            # (num_sequences, num_kv_heads): cu_seqlens_q has one more entry
            # than there are sequences.
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            if self.dcp_world_size > 1:
                # Decode-context-parallel path, handled by the parent's helper.
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                )
                return output
            else:
                # Regular path: paged attention over the cache, result written
                # into the leading num_actual_tokens rows of `output`.
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
                return output

        # Cascade attention (rare case).
        # Unlike the branches above, the descale tensors are passed unexpanded.
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            max_num_splits=attn_metadata.max_num_splits,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        return output
...@@ -15,6 +15,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -15,6 +15,7 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec, KVCacheSpec,
MambaSpec, MambaSpec,
MLAAttentionSpec, MLAAttentionSpec,
SinkFullAttentionSpec,
SlidingWindowSpec, SlidingWindowSpec,
) )
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -882,6 +883,30 @@ class CrossAttentionManager(SingleTypeKVCacheManager): ...@@ -882,6 +883,30 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
raise NotImplementedError("CrossAttentionManager does not support caching") raise NotImplementedError("CrossAttentionManager does not support caching")
class SinkFullAttentionManager(FullAttentionManager):
    """Full-attention manager that sets blocks aside for the attention sink."""

    def __init__(
        self,
        kv_cache_spec: SinkFullAttentionSpec,
        block_pool: BlockPool,
        enable_caching: bool,
        kv_cache_group_id: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ):
        """Initialise the parent manager, then take the sink blocks.

        ``kv_cache_spec.sink_len`` must be a positive multiple of the block
        size; ``sink_len // block_size`` blocks are removed from the front of
        the block pool's free queue and kept in ``self.sink_blocks``.
        """
        super().__init__(
            kv_cache_spec,
            block_pool,
            enable_caching,
            kv_cache_group_id,
            dcp_world_size,
            pcp_world_size,
        )
        sink_len = kv_cache_spec.sink_len
        # Checked one condition at a time, in the original order.
        assert sink_len is not None
        assert sink_len > 0
        assert sink_len % self.block_size == 0
        sink_block_count = sink_len // self.block_size
        free_queue = self.block_pool.free_block_queue
        self.sink_blocks = free_queue.popleft_n(sink_block_count)
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager, FullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager, MLAAttentionSpec: FullAttentionManager,
...@@ -889,6 +914,7 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { ...@@ -889,6 +914,7 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager, MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager, CrossAttentionSpec: CrossAttentionManager,
SinkFullAttentionSpec: SinkFullAttentionManager,
} }
......
...@@ -89,12 +89,18 @@ class FullAttentionSpec(AttentionSpec): ...@@ -89,12 +89,18 @@ class FullAttentionSpec(AttentionSpec):
In this case, we use FullAttentionSpec and record the sliding window size. In this case, we use FullAttentionSpec and record the sliding window size.
""" """
head_size_v: int | None = None
sliding_window: int | None = None sliding_window: int | None = None
""" """
Default to None for not using sliding window attention. Default to None for not using sliding window attention.
""" """
attention_chunk_size: int | None = None attention_chunk_size: int | None = None
def __post_init__(self):
    """Fall back to ``head_size`` when no separate ``head_size_v`` is given."""
    if self.head_size_v is not None:
        return
    # object.__setattr__ sidesteps the instance's own attribute-assignment
    # hook; subclasses of this spec are declared frozen, so plain assignment
    # would presumably be rejected.
    object.__setattr__(self, "head_size_v", self.head_size)
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
...@@ -142,6 +148,7 @@ class FullAttentionSpec(AttentionSpec): ...@@ -142,6 +148,7 @@ class FullAttentionSpec(AttentionSpec):
block_size=specs[0].block_size, block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads, num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size, head_size=specs[0].head_size,
head_size_v=specs[0].head_size_v,
dtype=specs[0].dtype, dtype=specs[0].dtype,
sliding_window=cls.merge_window_sizes(sliding_window), sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
...@@ -160,6 +167,15 @@ class FullAttentionSpec(AttentionSpec): ...@@ -160,6 +167,15 @@ class FullAttentionSpec(AttentionSpec):
) )
return merged_spec return merged_spec
@property
def page_size_bytes(self) -> int:
    """Size in bytes of one page, with K and V head sizes counted separately.

    Computed as ``block_size * num_kv_heads * (head_size + head_size_v)``
    multiplied by the byte size of ``dtype``.
    """
    combined_head_size = self.head_size + self.head_size_v
    element_count = self.block_size * self.num_kv_heads * combined_head_size
    return element_count * get_dtype_size(self.dtype)
@dataclass(frozen=True) @dataclass(frozen=True)
class MLAAttentionSpec(FullAttentionSpec): class MLAAttentionSpec(FullAttentionSpec):
...@@ -287,6 +303,56 @@ class CrossAttentionSpec(AttentionSpec): ...@@ -287,6 +303,56 @@ class CrossAttentionSpec(AttentionSpec):
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
@dataclass(frozen=True)
class SinkFullAttentionSpec(FullAttentionSpec):
    """Full-attention KV cache spec that also carries an attention-sink length."""

    # Length of the attention sink. NOTE(review): presumably counted in
    # tokens; SinkFullAttentionManager requires a positive multiple of
    # block_size. Confirm the unit against the model code.
    sink_len: int | None = None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of SinkFullAttentionSpec objects into a single
        SinkFullAttentionSpec object.

        Block size, KV head count, head sizes, dtype and sink length are taken
        from the first spec; sliding-window and chunk sizes are merged through
        `merge_window_sizes`.

        Raises:
            AssertionError: if the specs are not all FullAttentionSpec, if any
                is an MLAAttentionSpec, if they disagree on an AttentionSpec
                field, on `head_size_v` or on `sink_len`, or if both a sliding
                window and an attention chunk size are present.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be FullAttentionSpec."
        )

        sliding_window = {
            spec.sliding_window for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_size = {
            spec.attention_chunk_size
            for spec in specs
            if spec.attention_chunk_size is not None
        }
        assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
            "MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
        )
        merged_spec = cls(
            block_size=specs[0].block_size,
            num_kv_heads=specs[0].num_kv_heads,
            head_size=specs[0].head_size,
            head_size_v=specs[0].head_size_v,
            sink_len=specs[0].sink_len,
            dtype=specs[0].dtype,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec."
                )
            # head_size_v and sink_len are not AttentionSpec fields, so the
            # loop above does not cover them. Without these checks the values
            # of every spec but the first would be dropped silently.
            assert spec.head_size_v == merged_spec.head_size_v, (
                "All attention layers in the same KV cache group must have "
                "the same head_size_v."
            )
            # getattr keeps a spec without sink_len on the AssertionError
            # path instead of raising AttributeError.
            assert getattr(spec, "sink_len", None) == merged_spec.sink_len, (
                "All attention layers in the same KV cache group must have "
                "the same sink_len."
            )
        assert (merged_spec.sliding_window is not None) + (
            merged_spec.attention_chunk_size is not None
        ) <= 1, (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported."
        )
        return merged_spec
@dataclass(frozen=True) @dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec): class UniformTypeKVCacheSpecs(KVCacheSpec):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment