Commit bdd33b3f authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa interface and kvcache

add prepare_so_files to prepare so
parent 63053820
......@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop)
```
若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1
3.跳过编译(适用于未改变csrc目录kernel并多次编译情况)
将编译后的so文件拷贝至csrc目录,并设置环境变量: export SKIP_VLLM_BUILD=1
#### 运行基础环境准备
1、使用上面基于光源pytorch2.9.0基础镜像环境
......
......@@ -13,6 +13,8 @@ import sys
import sysconfig
from pathlib import Path
from shutil import which
import tarfile
import shutil
import torch
from packaging.version import Version, parse
......@@ -36,6 +38,37 @@ skip_vllm_build = False
if int(os.environ.get('SKIP_VLLM_BUILD', '0')) == 1:
skip_vllm_build = True
def prepare_so_files():
    """Copy prebuilt extension libraries from ``csrc/so.tar.gz`` into ``vllm/``.

    Called when the native build is skipped, so the package still ships the
    compiled ``.so`` files. Only regular-file members whose basename is one of
    the known extension libraries are taken from the archive; they are
    streamed straight into ``vllm/`` (flattened, directory components of the
    member name are ignored).

    Nothing else in the archive is ever written to disk, so a member with a
    name such as ``../x`` cannot escape the destination, and no scratch
    directory is created or removed in the working tree.

    If the archive does not exist, a warning is printed and the function
    returns without doing anything.
    """
    archive_path = "csrc/so.tar.gz"
    target_dir = "vllm"
    wanted = frozenset(
        ("_C.abi3.so", "_moe_C.abi3.so", "cumem_allocator.abi3.so")
    )

    if not os.path.exists(archive_path):
        print(f"Warning: {archive_path} not found, skipping extraction")
        return

    print(f"Preparing C extension files from {archive_path}...")

    with tarfile.open(archive_path, "r:*") as tar:
        for member in tar:
            file_name = os.path.basename(member.name)
            # Skip directories, links, devices and any file we do not ship.
            if not member.isfile() or file_name not in wanted:
                continue

            dst_path = os.path.join(target_dir, file_name)
            os.makedirs(target_dir, exist_ok=True)

            # Stream the member's bytes directly to the destination instead of
            # extracting the whole archive to a temporary location first.
            with tar.extractfile(member) as src, open(dst_path, "wb") as dst:
                shutil.copyfileobj(src, dst)

            # Keep the permission bits recorded in the archive.
            os.chmod(dst_path, member.mode & 0o777)
            print(f"Copied {file_name} to {dst_path}")
def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
......@@ -1109,6 +1142,7 @@ if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
if skip_vllm_build:
prepare_so_files()
package_data = {
"vllm": [
"py.typed",
......
......@@ -848,7 +848,10 @@ def unified_kv_cache_update(
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
if current_platform.is_rocm():
return torch.empty(0, device=key.device, dtype=key.dtype)
else:
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def unified_kv_cache_update_fake(
......
......@@ -27,18 +27,18 @@ from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.platforms import current_platform
if is_flash_attn_varlen_func_available():
if not current_platform.is_rocm():
if current_platform.is_rocm():
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
flash_attn_varlen_func,
get_scheduler_metadata,
reshape_and_cache_flash,
vllm_flash_attn_varlen_func,
reshape_and_cache_cuda,
)
else:
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
vllm_flash_attn_varlen_func,
reshape_and_cache_cuda,
flash_attn_varlen_func,
get_scheduler_metadata,
reshape_and_cache_flash,
)
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
......@@ -113,7 +113,7 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
if not current_platform.is_rocm():
if current_platform.is_rocm():
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -121,31 +121,36 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
) -> tuple[tuple[int, ...], tuple[int, ...]]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
) -> tuple[tuple[int, ...], tuple[int, ...]]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (2, 0, 1, 3, 4, 5)
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
return key_stride_order, value_stride_order
else:
@staticmethod
def get_kv_cache_shape(
......@@ -154,36 +159,32 @@ class FlashAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[tuple[int, ...], tuple[int, ...]]:
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (
(num_blocks, num_kv_heads, block_size, head_size),
(num_blocks, num_kv_heads, head_size, block_size),
)
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[tuple[int, ...], tuple[int, ...]]:
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, block_size, num_kv_heads, head_size)
return (1, 0, 3, 2, 5), (1, 0, 4, 2, 3)
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (2, 0, 1, 3, 4, 5)
elif cache_layout == "NHD":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 2, 3)
stride_order = (0, 1, 2, 3, 4)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND":
key_stride_order = (0, 1, 2, 3)
value_stride_order = (0, 1, 3, 2)
stride_order = (0, 1, 3, 2, 4)
else:
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
return stride_order
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
......@@ -724,10 +725,10 @@ class FlashAttentionImpl(AttentionImpl):
)
# For decoder and cross-attention, use KV cache as before
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
if current_platform.is_rocm():
key_cache, value_cache = kv_cache
else:
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
......@@ -745,12 +746,16 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
if not current_platform.is_rocm():
if current_platform.is_rocm():
q_descale = None
k_descale = layer._k_scale
v_descale = layer._v_scale
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
if self.dcp_world_size > 1:
self._forward_with_dcp(
......@@ -772,8 +777,13 @@ class FlashAttentionImpl(AttentionImpl):
if self.sliding_window is not None
else None
)
if not current_platform.is_rocm():
flash_attn_varlen_func(
if current_platform.is_rocm():
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
......@@ -793,16 +803,12 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
)
else:
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
vllm_flash_attn_varlen_func(
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
......@@ -818,76 +824,42 @@ class FlashAttentionImpl(AttentionImpl):
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
)
return output
# Cascade attention (rare case).
if not current_platform.is_rocm():
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
max_num_splits=attn_metadata.max_num_splits,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
s_aux=self.sinks,
)
else:
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=2, #self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
s_aux=self.sinks,
)
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
max_num_splits=attn_metadata.max_num_splits,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=None if current_platform.is_rocm() else layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
s_aux=self.sinks,
)
return output
def do_kv_cache_update(
......@@ -913,10 +885,10 @@ class FlashAttentionImpl(AttentionImpl):
):
return
if not current_platform.is_rocm():
key_cache, value_cache = kv_cache.unbind(0)
else:
if current_platform.is_rocm():
key_cache, value_cache = kv_cache
else:
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
......@@ -925,18 +897,7 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if not current_platform.is_rocm():
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
if current_platform.is_rocm():
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
......@@ -961,6 +922,17 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
else:
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def _forward_with_dcp(
......@@ -989,28 +961,53 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=attn_metadata.scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
if current_platform.is_rocm():
context_attn_out, context_lse = vllm_flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=attn_metadata.scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
is_prefix_cache=True,
)
else:
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
v=value_cache,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens,
max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
softmax_scale=self.scale,
causal=False,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=attn_metadata.scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
context_attn_out,
......@@ -1020,26 +1017,49 @@ class FlashAttentionImpl(AttentionImpl):
)
context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
query_attn_out, query_lse = flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
if current_platform.is_rocm():
query_attn_out, query_lse = vllm_flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
query_attn_out, query_lse = flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=None,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape
merge_attn_states(
......@@ -1094,8 +1114,8 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
if not current_platform.is_rocm():
flash_attn_varlen_func(
if current_platform.is_rocm():
vllm_flash_attn_varlen_func(
q=query,
k=key,
v=value,
......@@ -1109,14 +1129,18 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=1 if self.batch_invariant_enabled else 0,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache=False,
)
else:
vllm_flash_attn_varlen_func(
flash_attn_varlen_func(
q=query,
k=key,
v=value,
......@@ -1130,15 +1154,11 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache=False,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=1 if self.batch_invariant_enabled else 0,
)
return output
......@@ -1259,11 +1279,12 @@ def cascade_attention(
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
assert num_common_kv_blocks > 0
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
if not current_platform.is_rocm():
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process shared prefix.
if not current_platform.is_rocm():
prefix_output, prefix_lse = flash_attn_varlen_func(
if current_platform.is_rocm():
prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
......@@ -1279,16 +1300,17 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
q_descale=q_descale if q_descale is not None else None,
k_descale=k_descale if k_descale is not None else None,
v_descale=v_descale if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
)
else:
prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func(
prefix_output, prefix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
......@@ -1303,22 +1325,21 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
# Process suffix per query.
if not current_platform.is_rocm():
suffix_output, suffix_lse = flash_attn_varlen_func(
if current_platform.is_rocm():
suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
......@@ -1334,13 +1355,14 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
q_descale=q_descale if q_descale is not None else None,
k_descale=k_descale if k_descale is not None else None,
v_descale=v_descale if v_descale is not None else None,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
)
else:
suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func(
suffix_output, suffix_lse = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
......@@ -1355,12 +1377,11 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
# num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
is_prefix_cache=True,
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
)
# Merge prefix and suffix outputs, and store the result in output.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment