Commit 981f8370 authored by zhuwenwen
Browse files

Update the V1 FlashAttention (FA) KV-cache layout and make V1 attention use FA

parent c0f0b209
......@@ -15,8 +15,8 @@ if current_platform.is_cuda():
get_scheduler_metadata)
elif current_platform.is_rocm():
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
reshape_and_cache_cuda = ops.reshape_and_cache_cuda
from flash_attn import vllm_flash_attn_varlen_func
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
......
......@@ -289,7 +289,7 @@ class RocmPlatform(Platform):
# logger.info_once("Using Triton backend on V1 engine.")
# return TRITON_ATTN_VLLM_V1
if envs.is_set("VLLM_USE_FLASH_ATTN_PA") and envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return FLASH_ATTN_V1
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple
import numpy as np
import torch
......@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash)
else:
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
vllm_flash_attn_varlen_func,
reshape_and_cache_flash)
from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
reshape_and_cache_cuda)
from vllm.config import VllmConfig, get_layers_from_vllm_config
......@@ -84,6 +83,7 @@ class FlashAttentionBackend(AttentionBackend):
# Return the metadata-builder class associated with this attention backend.
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
if not current_platform.is_rocm():
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -107,6 +107,35 @@ class FlashAttentionBackend(AttentionBackend):
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
else:
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return the key-cache and value-cache shapes as two separate tuples.

    This variant is the one selected on the ROCm branch of the backend.

    Args:
        num_blocks: Number of cache blocks.
        block_size: Block size; must be a multiple of 16.
        num_kv_heads: Number of KV heads.
        head_size: Size of each head.

    Returns:
        A ``(key_shape, value_shape)`` pair, where ``key_shape`` is
        ``(num_blocks, num_kv_heads, block_size, head_size)`` and
        ``value_shape`` has its last two dimensions swapped:
        ``(num_blocks, num_kv_heads, head_size, block_size)``.

    Raises:
        ValueError: If ``block_size`` is not a multiple of 16.
    """
    remainder = block_size % 16
    if remainder:
        raise ValueError("Block size must be a multiple of 16.")

    # Both caches share the leading (block, head) dimensions; only the
    # trailing two dimensions differ in order.
    leading_dims = (num_blocks, num_kv_heads)
    key_shape = leading_dims + (block_size, head_size)
    value_shape = leading_dims + (head_size, block_size)
    return key_shape, value_shape
@staticmethod
def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return the stride-order permutations for the key and value caches.

    Each permutation maps the shape reported by ``get_kv_cache_shape`` to
    the memory layout that is actually wanted. The layout name is read
    from ``get_kv_cache_layout()``.

    Returns:
        A ``(key_stride_order, value_stride_order)`` pair. Both entries
        hold the same permutation: the identity for ``"NHD"`` and
        ``(0, 2, 1, 3)`` for ``"HND"``.

    Raises:
        ValueError: If the layout is neither ``"NHD"`` nor ``"HND"``.
    """
    cache_layout = get_kv_cache_layout()
    if cache_layout not in ("NHD", "HND"):
        raise ValueError(f"Unknown cache layout format {cache_layout}.")

    # Key and value caches use an identical permutation for a given layout.
    stride_order = (0, 1, 2, 3) if cache_layout == "NHD" else (0, 2, 1, 3)
    return stride_order, stride_order
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
......@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata, layer)
# For decoder and cross-attention, use KV cache as before
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
......@@ -522,6 +554,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
......@@ -532,6 +565,17 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
else:
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
......@@ -582,7 +626,7 @@ class FlashAttentionImpl(AttentionImpl):
else:
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func(
......@@ -607,7 +651,7 @@ class FlashAttentionImpl(AttentionImpl):
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
# s_aux=self.sinks,
is_prefix_cache=False,
is_prefix_cache=True,
)
return output
......
......@@ -3095,6 +3095,57 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
if envs.VLLM_USE_FLASH_ATTN_PA and not kv_cache_spec.use_mla:
key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
key_stride_order, value_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(key_stride_order) == len(
key_cache_shape)
assert len(value_stride_order) == len(
value_cache_shape)
except (AttributeError, NotImplementedError):
key_stride_order = tuple(
range(len(key_cache_shape)))
value_stride_order = tuple(
range(len(value_cache_shape)))
# The allocation respects the backend-defined stride order
# so that the semantics remain consistent for each backend.
# We first obtain the generic KV cache shape and then permute
# it according to the stride order, which can result in a
# non-contiguous tensor.
key_cache_shape = tuple(key_cache_shape[i]
for i in key_stride_order)
value_cache_shape = tuple(value_cache_shape[i]
for i in value_stride_order)
# Maintain original KV shape view.
inv_key_order = [
key_stride_order.index(i)
for i in range(len(key_stride_order))
]
inv_value_order = [
value_stride_order.index(i)
for i in range(len(value_stride_order))
]
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
total_elements = raw_tensor.numel()
key_elements = (key_cache_shape[0] * key_cache_shape[1] *
key_cache_shape[2] * key_cache_shape[3])
value_elements = (value_cache_shape[0] * value_cache_shape[1] *
value_cache_shape[2] * value_cache_shape[3])
assert total_elements == key_elements + value_elements
key_cache = raw_tensor[:key_elements].view(key_cache_shape).permute(
*inv_key_order)
value_cache = raw_tensor[key_elements:].view(value_cache_shape).permute(
*inv_value_order)
kv_caches[layer_name] = (key_cache, value_cache)
else:
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment