Commit 8e56ae5e authored by zhuwenwen
Browse files

Skip window_size and alibi_slopes because CK FA does not support them, and add Triton FA

parent 0e640807
...@@ -271,11 +271,12 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -271,11 +271,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 # from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention) # triton_attention)
# from vllm.attention.ops.flash_attn_triton_mqa_gqa import ( # self.attn_func = triton_attention
# flash_attn_varlen_func) from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
self.attn_func = triton_attention # flash_attn_varlen_func flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func
logger.debug("Using Triton FA in ROCmBackend") logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1): if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support " logger.warning("ROCm Triton FA does not currently support "
...@@ -391,17 +392,29 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -391,17 +392,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype, query.dtype,
attn_metadata.seq_lens, attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore make_attn_mask=False) # type: ignore
# out = self.attn_func(
# query,
# key,
# value,
# prefill_meta.seq_lens,
# num_tokens,
# self.num_heads,
# self.head_size,
# self.scale,
# attn_masks,
# )
out = self.attn_func( out = self.attn_func(
query, q=query,
key, k=key,
value, v=value,
prefill_meta.seq_lens, cu_seqlens_q=prefill_meta.seq_start_loc,
num_tokens, cu_seqlens_k=prefill_meta.seq_start_loc,
self.num_heads, max_seqlens_q=prefill_meta.max_prefill_seq_len,
self.head_size, max_seqlens_k=prefill_meta.max_prefill_seq_len,
self.scale, softmax_scale=self.scale,
attn_masks, causal=True,
) )
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround. # Interleave for MQA workaround.
...@@ -439,8 +452,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -439,8 +452,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len, max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window, # window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, # alibi_slopes=self.alibi_slopes,
) )
# common code for prefill # common code for prefill
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment