Commit 864c718a authored by zhuwenwen's avatar zhuwenwen
Browse files

update v1 fa layout

parent 693d5ed4
......@@ -75,7 +75,7 @@ class Attention(nn.Module):
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64
block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
......@@ -303,7 +303,7 @@ class MultiHeadAttention(nn.Module):
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype=None,
block_size=16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64,
block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
if current_platform.is_rocm():
......
......@@ -15,8 +15,8 @@ if current_platform.is_cuda():
get_scheduler_metadata)
elif current_platform.is_rocm():
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
reshape_and_cache_cuda = ops.reshape_and_cache_cuda
from flash_attn import vllm_flash_attn_varlen_func
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
......
......@@ -1497,7 +1497,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
class CacheConfig:
"""Configuration for the KV cache."""
block_size: BlockSize = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64 # type: ignore
block_size: BlockSize = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16 # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple
import numpy as np
import torch
......@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash)
else:
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
vllm_flash_attn_varlen_func,
reshape_and_cache_flash)
from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
reshape_and_cache_cuda)
from vllm.config import VllmConfig, get_layers_from_vllm_config
......@@ -83,6 +82,7 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
if not current_platform.is_rocm():
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -106,6 +106,35 @@ class FlashAttentionBackend(AttentionBackend):
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
else:
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
)
@staticmethod
def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
elif cache_layout == "HND":
key_stride_order = (0, 2, 1, 3)
value_stride_order = (0, 3, 1, 2)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
@dataclass
......@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
......@@ -522,6 +554,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
......@@ -532,6 +565,17 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
else:
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
......@@ -618,7 +662,7 @@ class FlashAttentionImpl(AttentionImpl):
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=False,
is_prefix_cache=True,
)
return output
......
......@@ -2494,6 +2494,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
if envs.VLLM_USE_FLASH_ATTN_PA and not kv_cache_spec.use_mla:
key_cache_shape, value_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
key_stride_order, value_stride_order = self.attn_backends[
i].get_kv_cache_stride_order()
assert len(key_stride_order) == len(
key_cache_shape)
assert len(value_stride_order) == len(
value_cache_shape)
except (AttributeError, NotImplementedError):
key_stride_order = tuple(
range(len(key_cache_shape)))
value_stride_order = tuple(
range(len(value_cache_shape)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
key_cache_shape = tuple(key_cache_shape[i]
for i in key_stride_order)
value_cache_shape = tuple(value_cache_shape[i]
for i in value_stride_order)
# Maintain original KV shape view.
inv_key_order = [
key_stride_order.index(i)
for i in range(len(key_stride_order))
]
inv_value_order = [
value_stride_order.index(i)
for i in range(len(value_stride_order))
]
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
total_elements = raw_tensor.numel()
key_elements = (key_cache_shape[0] * key_cache_shape[1] *
key_cache_shape[2] * key_cache_shape[3])
value_elements = (value_cache_shape[0] * value_cache_shape[1] *
value_cache_shape[2] * value_cache_shape[3])
assert total_elements == key_elements + value_elements
key_cache = raw_tensor[:key_elements].view(key_cache_shape).permute(
*inv_key_order)
value_cache = raw_tensor[key_elements:].view(value_cache_shape).permute(
*inv_value_order)
kv_caches[layer_name] = (key_cache, value_cache)
else:
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment