Commit 8e56ae5e authored by zhuwenwen's avatar zhuwenwen
Browse files

skip window_size and alibi_slopes because ck fa is not supported, and add triton fa

parent 0e640807
......@@ -271,11 +271,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
# from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
# flash_attn_varlen_func)
self.attn_func = triton_attention # flash_attn_varlen_func
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
# self.attn_func = triton_attention
from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func
logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
......@@ -391,17 +392,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore
# out = self.attn_func(
# query,
# key,
# value,
# prefill_meta.seq_lens,
# num_tokens,
# self.num_heads,
# self.head_size,
# self.scale,
# attn_masks,
# )
out = self.attn_func(
query,
key,
value,
prefill_meta.seq_lens,
num_tokens,
self.num_heads,
self.head_size,
self.scale,
attn_masks,
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
......@@ -439,8 +452,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
# window_size=self.sliding_window,
# alibi_slopes=self.alibi_slopes,
)
# common code for prefill
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment