Commit 864c718a authored by zhuwenwen's avatar zhuwenwen
Browse files

update v1 fa layout

parent 693d5ed4
...@@ -75,7 +75,7 @@ class Attention(nn.Module): ...@@ -75,7 +75,7 @@ class Attention(nn.Module):
calculate_kv_scales = cache_config.calculate_kv_scales calculate_kv_scales = cache_config.calculate_kv_scales
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64 block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16
is_attention_free = False is_attention_free = False
calculate_kv_scales = False calculate_kv_scales = False
if num_kv_heads is None: if num_kv_heads is None:
...@@ -303,7 +303,7 @@ class MultiHeadAttention(nn.Module): ...@@ -303,7 +303,7 @@ class MultiHeadAttention(nn.Module):
attn_backend = get_attn_backend(head_size, attn_backend = get_attn_backend(head_size,
dtype, dtype,
kv_cache_dtype=None, kv_cache_dtype=None,
block_size=16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64, block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16,
is_attention_free=False) is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name()) backend = backend_name_to_enum(attn_backend.get_name())
if current_platform.is_rocm(): if current_platform.is_rocm():
......
...@@ -15,8 +15,8 @@ if current_platform.is_cuda(): ...@@ -15,8 +15,8 @@ if current_platform.is_cuda():
get_scheduler_metadata) get_scheduler_metadata)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_cuda = ops.reshape_and_cache_cuda
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func from flash_attn import vllm_flash_attn_varlen_func
elif current_platform.is_xpu(): elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
......
...@@ -1497,7 +1497,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"] ...@@ -1497,7 +1497,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
block_size: BlockSize = 16 if not envs.VLLM_USE_FLASH_ATTN_PA else 64 # type: ignore block_size: BlockSize = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16 # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on """Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128. sizes up to 32 are supported. On HPU devices, block size defaults to 128.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available(): ...@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata, get_scheduler_metadata,
reshape_and_cache_flash) reshape_and_cache_flash)
else: else:
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
vllm_flash_attn_varlen_func, reshape_and_cache_cuda)
reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
...@@ -83,6 +82,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -83,6 +82,7 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder return FlashAttentionMetadataBuilder
if not current_platform.is_rocm():
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
...@@ -106,6 +106,35 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -106,6 +106,35 @@ class FlashAttentionBackend(AttentionBackend):
else: else:
raise ValueError(f"Unknown cache layout format {cache_layout}.") raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order return stride_order
else:
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
)
@staticmethod
def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
elif cache_layout == "HND":
key_stride_order = (0, 2, 1, 3)
value_stride_order = (0, 3, 1, 2)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
@dataclass @dataclass
...@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead. # performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache
if self.kv_sharing_target_layer_name is None: if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
...@@ -522,6 +554,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -522,6 +554,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash # and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of # op uses the slot_mapping's shape to determine the number of
# actual tokens. # actual tokens.
if not current_platform.is_rocm():
reshape_and_cache_flash( reshape_and_cache_flash(
key, key,
value, value,
...@@ -532,6 +565,17 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -532,6 +565,17 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
) )
else:
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn) key_cache = key_cache.view(torch.float8_e4m3fn)
...@@ -618,7 +662,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -618,7 +662,7 @@ class FlashAttentionImpl(AttentionImpl):
# k_descale=layer._k_scale.expand(descale_shape), # k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape), # v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
is_prefix_cache=False, is_prefix_cache=True,
) )
return output return output
......
...@@ -2494,6 +2494,59 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -2494,6 +2494,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec.page_size_bytes) kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True has_attn = True
if envs.VLLM_USE_FLASH_ATTN_PA and not kv_cache_spec.use_mla:
key_cache_shape, value_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
key_stride_order, value_stride_order = self.attn_backends[
i].get_kv_cache_stride_order()
assert len(key_stride_order) == len(
key_cache_shape)
assert len(value_stride_order) == len(
value_cache_shape)
except (AttributeError, NotImplementedError):
key_stride_order = tuple(
range(len(key_cache_shape)))
value_stride_order = tuple(
range(len(value_cache_shape)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
key_cache_shape = tuple(key_cache_shape[i]
for i in key_stride_order)
value_cache_shape = tuple(value_cache_shape[i]
for i in value_stride_order)
# Maintain original KV shape view.
inv_key_order = [
key_stride_order.index(i)
for i in range(len(key_stride_order))
]
inv_value_order = [
value_stride_order.index(i)
for i in range(len(value_stride_order))
]
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
total_elements = raw_tensor.numel()
key_elements = (key_cache_shape[0] * key_cache_shape[1] *
key_cache_shape[2] * key_cache_shape[3])
value_elements = (value_cache_shape[0] * value_cache_shape[1] *
value_cache_shape[2] * value_cache_shape[3])
assert total_elements == key_elements + value_elements
key_cache = raw_tensor[:key_elements].view(key_cache_shape).permute(
*inv_key_order)
value_cache = raw_tensor[key_elements:].view(value_cache_shape).permute(
*inv_value_order)
kv_caches[layer_name] = (key_cache, value_cache)
else:
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment