Commit 981f8370 authored by zhuwenwen's avatar zhuwenwen
Browse files

update v1 fa layout and set v1 attention use fa

parent c0f0b209
...@@ -15,8 +15,8 @@ if current_platform.is_cuda(): ...@@ -15,8 +15,8 @@ if current_platform.is_cuda():
get_scheduler_metadata) get_scheduler_metadata)
elif current_platform.is_rocm(): elif current_platform.is_rocm():
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_cuda = ops.reshape_and_cache_cuda
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func from flash_attn import vllm_flash_attn_varlen_func
elif current_platform.is_xpu(): elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
......
...@@ -289,7 +289,7 @@ class RocmPlatform(Platform): ...@@ -289,7 +289,7 @@ class RocmPlatform(Platform):
# logger.info_once("Using Triton backend on V1 engine.") # logger.info_once("Using Triton backend on V1 engine.")
# return TRITON_ATTN_VLLM_V1 # return TRITON_ATTN_VLLM_V1
if envs.is_set("VLLM_USE_FLASH_ATTN_PA") and envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64: if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)") logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return FLASH_ATTN_V1 return FLASH_ATTN_V1
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available(): ...@@ -25,9 +25,8 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata, get_scheduler_metadata,
reshape_and_cache_flash) reshape_and_cache_flash)
else: else:
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
vllm_flash_attn_varlen_func, reshape_and_cache_cuda)
reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
...@@ -83,30 +82,60 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -83,30 +82,60 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder return FlashAttentionMetadataBuilder
if not current_platform.is_rocm():
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
else:
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
)
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
num_blocks: int, # `stride_order` indicates the permutation that gets
block_size: int, # us from `get_kv_cache_shape` to the actual memory layout we want.
num_kv_heads: int, cache_layout = get_kv_cache_layout()
head_size: int, if cache_layout == "NHD":
) -> tuple[int, ...]: key_stride_order = (0, 1, 2, 3)
if block_size % 16 != 0: value_stride_order = (0, 1, 2, 3)
raise ValueError("Block size must be a multiple of 16.") elif cache_layout == "HND":
return (2, num_blocks, block_size, num_kv_heads, head_size) key_stride_order = (0, 2, 1, 3)
value_stride_order = (0, 2, 1, 3)
@staticmethod else:
def get_kv_cache_stride_order() -> tuple[int, ...]: raise ValueError(f"Unknown cache layout format {cache_layout}.")
# `stride_order` indicates the permutation that gets return key_stride_order, value_stride_order
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
@staticmethod @staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
...@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -512,7 +541,10 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata, layer) attn_metadata, layer)
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0) if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache
if self.kv_sharing_target_layer_name is None: if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
...@@ -522,16 +554,28 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -522,16 +554,28 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash # and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of # op uses the slot_mapping's shape to determine the number of
# actual tokens. # actual tokens.
reshape_and_cache_flash( if not current_platform.is_rocm():
key, reshape_and_cache_flash(
value, key,
key_cache, value,
value_cache, key_cache,
attn_metadata.slot_mapping, value_cache,
self.kv_cache_dtype, attn_metadata.slot_mapping,
layer._k_scale, self.kv_cache_dtype,
layer._v_scale, layer._k_scale,
) layer._v_scale,
)
else:
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
...@@ -582,7 +626,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -582,7 +626,7 @@ class FlashAttentionImpl(AttentionImpl):
else: else:
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:") print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}") print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}") print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func( vllm_flash_attn_varlen_func(
...@@ -607,7 +651,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -607,7 +651,7 @@ class FlashAttentionImpl(AttentionImpl):
# v_descale=layer._v_scale.expand(descale_shape), # v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
# s_aux=self.sinks, # s_aux=self.sinks,
is_prefix_cache=False, is_prefix_cache=True,
) )
return output return output
......
...@@ -3095,33 +3095,84 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3095,33 +3095,84 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_spec.page_size_bytes) kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec): if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True has_attn = True
kv_cache_shape = attn_backend.get_kv_cache_shape( if envs.VLLM_USE_FLASH_ATTN_PA and not kv_cache_spec.use_mla:
num_blocks, kv_cache_spec.block_size, key_cache_shape, value_cache_shape = attn_backend.get_kv_cache_shape(
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) num_blocks, kv_cache_spec.block_size,
dtype = kv_cache_spec.dtype kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
try: dtype = kv_cache_spec.dtype
kv_cache_stride_order = \ try:
attn_backend.get_kv_cache_stride_order() key_stride_order, value_stride_order = attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len( assert len(key_stride_order) == len(
kv_cache_shape) key_cache_shape)
except (AttributeError, NotImplementedError): assert len(value_stride_order) == len(
kv_cache_stride_order = tuple( value_cache_shape)
range(len(kv_cache_shape))) except (AttributeError, NotImplementedError):
# The allocation respects the backend-defined stride order key_stride_order = tuple(
# to ensure the semantic remains consistent for each range(len(key_cache_shape)))
# backend. We first obtain the generic kv cache shape and value_stride_order = tuple(
# then permute it according to the stride order which could range(len(value_cache_shape)))
# result in a non-contiguous tensor. # The allocation respects the backend-defined stride order
kv_cache_shape = tuple(kv_cache_shape[i] # to ensure the semantic remains consistent for each
for i in kv_cache_stride_order) # backend. We first obtain the generic kv cache shape and
# Maintain original KV shape view. # then permute it according to the stride order which could
inv_order = [ # result in a non-contiguous tensor.
kv_cache_stride_order.index(i) key_cache_shape = tuple(key_cache_shape[i]
for i in range(len(kv_cache_stride_order)) for i in key_stride_order)
] value_cache_shape = tuple(value_cache_shape[i]
kv_caches[layer_name] = kv_cache_raw_tensors[ for i in value_stride_order)
layer_name].view(dtype).view(kv_cache_shape).permute( # Maintain original KV shape view.
*inv_order) inv_key_order = [
key_stride_order.index(i)
for i in range(len(key_stride_order))
]
inv_value_order = [
value_stride_order.index(i)
for i in range(len(value_stride_order))
]
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
total_elements = raw_tensor.numel()
key_elements = (key_cache_shape[0] * key_cache_shape[1] *
key_cache_shape[2] * key_cache_shape[3])
value_elements = (value_cache_shape[0] * value_cache_shape[1] *
value_cache_shape[2] * value_cache_shape[3])
assert total_elements == key_elements + value_elements
key_cache = raw_tensor[:key_elements].view(key_cache_shape).permute(
*inv_key_order)
value_cache = raw_tensor[key_elements:].view(value_cache_shape).permute(
*inv_value_order)
kv_caches[layer_name] = (key_cache, value_cache)
else:
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = \
attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(
range(len(kv_cache_shape)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
elif isinstance(kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name] raw_tensor = kv_cache_raw_tensors[layer_name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment