Commit bdd33b3f authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa interface and kvcache

add prepare_so_files to prepare so
parent 63053820
...@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop) ...@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop)
``` ```
若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1 若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1
3.跳过编译(适用于未改变csrc目录kernel并多次编译情况)
将编译后的so文件拷贝至csrc目录,并设置环境变量: export SKIP_VLLM_BUILD=1
#### 运行基础环境准备 #### 运行基础环境准备
1、使用上面基于光源pytorch2.9.0基础镜像环境 1、使用上面基于光源pytorch2.9.0基础镜像环境
......
...@@ -13,6 +13,8 @@ import sys ...@@ -13,6 +13,8 @@ import sys
import sysconfig import sysconfig
from pathlib import Path from pathlib import Path
from shutil import which from shutil import which
import tarfile
import shutil
import torch import torch
from packaging.version import Version, parse from packaging.version import Version, parse
...@@ -36,6 +38,37 @@ skip_vllm_build = False ...@@ -36,6 +38,37 @@ skip_vllm_build = False
if int(os.environ.get('SKIP_VLLM_BUILD', '0')) == 1: if int(os.environ.get('SKIP_VLLM_BUILD', '0')) == 1:
skip_vllm_build = True skip_vllm_build = True
def prepare_so_files():
    """Copy prebuilt extension libraries from ``csrc/so.tar.gz`` into ``vllm/``.

    Called when the native build is skipped so that previously compiled
    ``.so`` files can be reused. Only members whose base name is one of the
    three known extension libraries are written; every other archive member
    is ignored. If the archive does not exist a warning is printed and the
    function returns without doing anything.

    The wanted members are streamed directly into the target directory.
    Nothing is unpacked with ``extractall``, so member paths such as
    ``../x`` or absolute paths can never cause a write outside ``vllm/``,
    and no scratch directory is created (or removed) in the working
    directory.
    """
    source_dir = "csrc/so.tar.gz"
    target_dir = "vllm"
    wanted = frozenset(
        ("_C.abi3.so", "_moe_C.abi3.so", "cumem_allocator.abi3.so")
    )

    if not os.path.exists(source_dir):
        print(f"Warning: {source_dir} not found, skipping extraction")
        return

    print(f"Preparing C extension files from {source_dir}...")
    with tarfile.open(source_dir, "r:*") as tar:
        for member in tar:
            # Match on the base name only: the directory layout inside the
            # archive is irrelevant, files always land flat in target_dir.
            file = os.path.basename(member.name)
            if file not in wanted:
                continue

            # extractfile() yields a readable stream for regular files and
            # for links to them, and None for directories/devices/etc.
            payload = tar.extractfile(member)
            if payload is None:
                continue

            dst_path = os.path.join(target_dir, file)
            os.makedirs(target_dir, exist_ok=True)
            with payload, open(dst_path, "wb") as out:
                shutil.copyfileobj(payload, out)

            # Keep the permission bits and mtime recorded in the archive,
            # as the previous extract-then-copy2 flow did. Link members carry
            # the link's own metadata, so only regular files are restored.
            if member.isfile():
                os.chmod(dst_path, member.mode & 0o777)
                os.utime(dst_path, (member.mtime, member.mtime))

            print(f"Copied {file} to {dst_path}")
def load_module_from_path(module_name, path): def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path) spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
...@@ -1109,6 +1142,7 @@ if _build_custom_ops(): ...@@ -1109,6 +1142,7 @@ if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C")) ext_modules.append(CMakeExtension(name="vllm._C"))
if skip_vllm_build: if skip_vllm_build:
prepare_so_files()
package_data = { package_data = {
"vllm": [ "vllm": [
"py.typed", "py.typed",
......
...@@ -848,6 +848,9 @@ def unified_kv_cache_update( ...@@ -848,6 +848,9 @@ def unified_kv_cache_update(
layer_slot_mapping, layer_slot_mapping,
) )
if current_platform.is_rocm():
return torch.empty(0, device=key.device, dtype=key.dtype)
else:
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
......
...@@ -27,18 +27,18 @@ from vllm.v1.attention.ops.merge_attn_states import merge_attn_states ...@@ -27,18 +27,18 @@ from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.platforms import current_platform from vllm.platforms import current_platform
if is_flash_attn_varlen_func_available(): if is_flash_attn_varlen_func_available():
if not current_platform.is_rocm(): if current_platform.is_rocm():
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks, flash_attn_supports_sinks,
flash_attn_varlen_func, vllm_flash_attn_varlen_func,
get_scheduler_metadata, reshape_and_cache_cuda,
reshape_and_cache_flash,
) )
else: else:
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks, flash_attn_supports_sinks,
vllm_flash_attn_varlen_func, flash_attn_varlen_func,
reshape_and_cache_cuda, get_scheduler_metadata,
reshape_and_cache_flash,
) )
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
...@@ -113,7 +113,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -113,7 +113,7 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]: def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder return FlashAttentionMetadataBuilder
if not current_platform.is_rocm(): if current_platform.is_rocm():
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
...@@ -121,31 +121,36 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -121,31 +121,36 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto", cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size) return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
)
@staticmethod @staticmethod
def get_kv_cache_stride_order( def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False, include_num_layers_dimension: bool = False,
) -> tuple[int, ...]: ) -> tuple[tuple[int, ...], tuple[int, ...]]:
# `stride_order` indicates the permutation that gets # `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want. # us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout() cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension: if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) # (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (2, 0, 1, 3, 4, 5) return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
elif cache_layout == "NHD": elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4) key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension: elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) # (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (2, 4, 0, 1, 3, 5) return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND": elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4) key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else: else:
raise ValueError(f"Unknown cache layout format {cache_layout}.") raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order return key_stride_order, value_stride_order
else: else:
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
...@@ -154,36 +159,32 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -154,36 +159,32 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto", cache_dtype_str: str = "auto",
) -> tuple[tuple[int, ...], tuple[int, ...]]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
return ( return (2, num_blocks, block_size, num_kv_heads, head_size)
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
)
@staticmethod @staticmethod
def get_kv_cache_stride_order( def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False, include_num_layers_dimension: bool = False,
) -> tuple[tuple[int, ...], tuple[int, ...]]: ) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets # `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want. # us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout() cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension: if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size) # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3) return (2, 0, 1, 3, 4, 5)
elif cache_layout == "NHD": elif cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3) stride_order = (0, 1, 2, 3, 4)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension: elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size) # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3) return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND": elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3) stride_order = (0, 1, 3, 2, 4)
value_stride_order = (0, 1, 3, 2)
else: else:
raise ValueError(f"Unknown cache layout format {cache_layout}.") raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order return stride_order
@staticmethod @staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
...@@ -724,10 +725,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -724,10 +725,10 @@ class FlashAttentionImpl(AttentionImpl):
) )
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
if not current_platform.is_rocm(): if current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
else:
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer # queries are quantized in the attention layer
...@@ -745,7 +746,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -745,7 +746,11 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata scheduler_metadata = attn_metadata.scheduler_metadata
if not current_platform.is_rocm(): if current_platform.is_rocm():
q_descale = None
k_descale = layer._k_scale
v_descale = layer._v_scale
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape) q_descale = layer._q_scale.expand(descale_shape)
...@@ -772,8 +777,13 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -772,8 +777,13 @@ class FlashAttentionImpl(AttentionImpl):
if self.sliding_window is not None if self.sliding_window is not None
else None else None
) )
if not current_platform.is_rocm(): if current_platform.is_rocm():
flash_attn_varlen_func( if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -793,16 +803,12 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -793,16 +803,12 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=q_descale, q_descale=q_descale,
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks, s_aux=self.sinks,
is_prefix_cache=True,
) )
else: else:
if envs.VLLM_USE_PA_PRINT_PARAM: flash_attn_varlen_func(
print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -818,21 +824,16 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -818,21 +824,16 @@ class FlashAttentionImpl(AttentionImpl):
block_table=block_table, block_table=block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata, scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale, q_descale=q_descale,
# k_descale=k_descale, k_descale=k_descale,
# v_descale=v_descale, v_descale=v_descale,
q_descale=None, num_splits=attn_metadata.max_num_splits,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks, s_aux=self.sinks,
is_prefix_cache=True,
) )
return output return output
# Cascade attention (rare case). # Cascade attention (rare case).
if not current_platform.is_rocm():
cascade_attention( cascade_attention(
output[:num_actual_tokens], output[:num_actual_tokens],
query[:num_actual_tokens], query[:num_actual_tokens],
...@@ -854,36 +855,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -854,36 +855,7 @@ class FlashAttentionImpl(AttentionImpl):
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata, suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale, q_descale=None if current_platform.is_rocm() else layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
s_aux=self.sinks,
)
else:
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=2, #self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
q_descale=None,
k_descale=layer._k_scale, k_descale=layer._k_scale,
v_descale=layer._v_scale, v_descale=layer._v_scale,
s_aux=self.sinks, s_aux=self.sinks,
...@@ -913,10 +885,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -913,10 +885,10 @@ class FlashAttentionImpl(AttentionImpl):
): ):
return return
if not current_platform.is_rocm(): if current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
key_cache, value_cache = kv_cache key_cache, value_cache = kv_cache
else:
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer. # Skip this if sharing KV cache with an earlier attention layer.
...@@ -925,8 +897,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -925,8 +897,10 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash # and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of # op uses the slot_mapping's shape to determine the number of
# actual tokens. # actual tokens.
if not current_platform.is_rocm(): if current_platform.is_rocm():
reshape_and_cache_flash( if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key, key,
value, value,
key_cache, key_cache,
...@@ -934,11 +908,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -934,11 +908,10 @@ class FlashAttentionImpl(AttentionImpl):
slot_mapping, slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale
) )
else: else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16: from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda( reshape_and_cache_cuda(
key, key,
value, value,
...@@ -947,11 +920,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -947,11 +920,10 @@ class FlashAttentionImpl(AttentionImpl):
slot_mapping, slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale,
layer._v_scale layer._v_scale,
) )
else: else:
from vllm.v1.attention.backends.fa_utils import reshape_and_cache_cuda reshape_and_cache_flash(
reshape_and_cache_cuda(
key, key,
value, value,
key_cache, key_cache,
...@@ -989,6 +961,31 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -989,6 +961,31 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size = ( sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None list(self.sliding_window) if self.sliding_window is not None else None
) )
if current_platform.is_rocm():
context_attn_out, context_lse = vllm_flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=attn_metadata.scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
is_prefix_cache=True,
)
else:
context_attn_out, context_lse = flash_attn_varlen_func( context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp, q=query_across_dcp,
k=key_cache, k=key_cache,
...@@ -1020,6 +1017,28 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -1020,6 +1017,28 @@ class FlashAttentionImpl(AttentionImpl):
) )
context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
if current_platform.is_rocm():
query_attn_out, query_lse = vllm_flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
query_attn_out, query_lse = flash_attn_varlen_func( query_attn_out, query_lse = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
...@@ -1040,6 +1059,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -1040,6 +1059,7 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=k_descale, k_descale=k_descale,
v_descale=v_descale, v_descale=v_descale,
) )
assert context_attn_out_cor.shape == query_attn_out.shape assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape assert context_lse_cor.shape == query_lse.shape
merge_attn_states( merge_attn_states(
...@@ -1094,8 +1114,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -1094,8 +1114,8 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size = ( sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None list(self.sliding_window) if self.sliding_window is not None else None
) )
if not current_platform.is_rocm(): if current_platform.is_rocm():
flash_attn_varlen_func( vllm_flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -1109,14 +1129,18 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -1109,14 +1129,18 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size, window_size=sliding_window_size,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version, # fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape), # q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape), # k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), # v_descale=layer._v_scale.expand(descale_shape),
num_splits=1 if self.batch_invariant_enabled else 0, q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache=False,
) )
else: else:
vllm_flash_attn_varlen_func( flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -1130,15 +1154,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -1130,15 +1154,11 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size, window_size=sliding_window_size,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape), q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
q_descale=None, num_splits=1 if self.batch_invariant_enabled else 0,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache=False,
) )
return output return output
...@@ -1259,11 +1279,12 @@ def cascade_attention( ...@@ -1259,11 +1279,12 @@ def cascade_attention(
assert common_prefix_len % block_size == 0 assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0 assert num_common_kv_blocks > 0
if not current_platform.is_rocm():
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process shared prefix. # Process shared prefix.
if not current_platform.is_rocm(): if current_platform.is_rocm():
prefix_output, prefix_lse = flash_attn_varlen_func( prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -1279,16 +1300,17 @@ def cascade_attention( ...@@ -1279,16 +1300,17 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata, scheduler_metadata=prefix_scheduler_metadata,
fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, q_descale=q_descale if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, k_descale=k_descale if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, v_descale=v_descale if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel, # s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge. # enabling its effect during the final attention merge.
s_aux=s_aux, s_aux=s_aux,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits, # num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
) )
else: else:
prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func( prefix_output, prefix_lse = flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -1303,22 +1325,21 @@ def cascade_attention( ...@@ -1303,22 +1325,21 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata, scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel, # s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge. # enabling its effect during the final attention merge.
s_aux=s_aux, s_aux=s_aux,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits, num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
) )
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query. # Process suffix per query.
if not current_platform.is_rocm(): if current_platform.is_rocm():
suffix_output, suffix_lse = flash_attn_varlen_func( suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -1334,13 +1355,14 @@ def cascade_attention( ...@@ -1334,13 +1355,14 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata, scheduler_metadata=suffix_scheduler_metadata,
fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, q_descale=q_descale if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, k_descale=k_descale if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, v_descale=v_descale if v_descale is not None else None,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits, # num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
) )
else: else:
suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func( suffix_output, suffix_lse = flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -1355,12 +1377,11 @@ def cascade_attention( ...@@ -1355,12 +1377,11 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata, scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits, num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
) )
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment