Commit bdc70899 authored by zhuwenwen
Browse files

support cutlass prefix-cache

parent 09e372e7
......@@ -575,6 +575,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
......@@ -843,24 +847,68 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
# prefix-enabled attention -
# not applicable for encoder-only models
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
# if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
# self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
else:
assert self.attn_type != AttentionType.ENCODER_ONLY, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
assert prefill_meta.query_start_loc is not None
max_seq_len = max(prefill_meta.seq_lens)
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
key.shape[1])
'''
k_cache
triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] --->
cutlass: num_blocks x page_block_size x num_heads_k x head_size i
'''
num_blocks, num_kv_heads, head_size_div_x, block_size, x = key_cache.shape
head_size = head_size_div_x * x
key_cache_flash = key_cache.permute(0, 3, 1, 2, 4) # [num_blocks, block_size, num_kv_heads, head_size//x, x]
key_cache_flash = key_cache_flash.reshape(num_blocks, block_size, num_kv_heads, head_size)
# value_cache
value_cache_flash = value_cache.permute(0, 3, 1, 2) # [num_blocks, block_size, num_kv_heads, head_size]
output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa
q=query,
k=key_cache_flash,
v=value_cache_flash,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
# Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY):
......
......@@ -17,6 +17,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False
......@@ -272,6 +273,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control if vllm should use triton prefix flash attention
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment