Commit 1863c926 authored by zhuwenwen
Browse files

Use triton fa by default

parent b6247705
...@@ -229,9 +229,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -229,9 +229,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# NOTE: Allow for switching between Triton and CK. Defaulting to triton. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 # from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention) # triton_attention)
self.attn_func = triton_attention from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
flash_attn_varlen_func)
self.attn_func = flash_attn_varlen_func # triton_attention
logger.debug("Using Triton FA in ROCmBackend") logger.debug("Using Triton FA in ROCmBackend")
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
...@@ -325,17 +327,27 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -325,17 +327,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
out, _ = self.attn_func( # out, _ = self.attn_func(
query, # query,
key, # key,
value, # value,
None, # None,
prefill_meta.seq_start_loc, # prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc, # prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len, # prefill_meta.max_prefill_seq_len,
prefill_meta.max_prefill_seq_len, # prefill_meta.max_prefill_seq_len,
True, # True,
self.scale, # self.scale,
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlens_q=prefill_meta.max_prefill_seq_len,
max_seqlens_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
) )
elif self.use_naive_attn: elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
......
This diff is collapsed.
...@@ -130,7 +130,7 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -130,7 +130,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention # flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": "VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")), ("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment