Commit f9408aff authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa v1 interface

parent 41e6d686
...@@ -13,6 +13,10 @@ if current_platform.is_cuda(): ...@@ -13,6 +13,10 @@ if current_platform.is_cuda():
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata) get_scheduler_metadata)
elif current_platform.is_rocm():
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
elif current_platform.is_xpu(): elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
...@@ -69,4 +73,4 @@ def flash_attn_supports_fp8() -> bool: ...@@ -69,4 +73,4 @@ def flash_attn_supports_fp8() -> bool:
def is_flash_attn_varlen_func_available() -> bool: def is_flash_attn_varlen_func_available() -> bool:
return current_platform.is_cuda() or current_platform.is_xpu() return current_platform.is_cuda() or current_platform.is_rocm() or current_platform.is_xpu()
...@@ -17,9 +17,15 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, ...@@ -17,9 +17,15 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version, get_flash_attn_version,
is_flash_attn_varlen_func_available) is_flash_attn_varlen_func_available)
from vllm.platforms import current_platform
if is_flash_attn_varlen_func_available(): if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, if not current_platform.is_rocm():
get_scheduler_metadata, from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
get_scheduler_metadata,
reshape_and_cache_flash)
else:
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
vllm_flash_attn_varlen_func,
reshape_and_cache_flash) reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
...@@ -559,28 +565,53 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -559,28 +565,53 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func( if not current_platform.is_rocm():
q=query[:num_actual_tokens], flash_attn_varlen_func(
k=key_cache, q=query[:num_actual_tokens],
v=value_cache, k=key_cache,
out=output[:num_actual_tokens], v=value_cache,
cu_seqlens_q=cu_seqlens_q, out=output[:num_actual_tokens],
max_seqlen_q=max_seqlen_q, cu_seqlens_q=cu_seqlens_q,
seqused_k=seqused_k, max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k, seqused_k=seqused_k,
softmax_scale=self.scale, max_seqlen_k=max_seqlen_k,
causal=True, softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes, causal=True,
window_size=self.sliding_window, alibi_slopes=self.alibi_slopes,
block_table=block_table, window_size=self.sliding_window,
softcap=self.logits_soft_cap, block_table=block_table,
scheduler_metadata=scheduler_metadata, softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version, scheduler_metadata=scheduler_metadata,
q_descale=layer._q_scale.expand(descale_shape), fa_version=self.vllm_flash_attn_version,
k_descale=layer._k_scale.expand(descale_shape), q_descale=layer._q_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits, v_descale=layer._v_scale.expand(descale_shape),
) num_splits=attn_metadata.max_num_splits,
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=False,
)
return output return output
assert not use_local_attn, ( assert not use_local_attn, (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment