Commit f9408aff authored by zhuwenwen
Browse files

update fa v1 interface

parent 41e6d686
......@@ -13,6 +13,10 @@ if current_platform.is_cuda():
reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
elif current_platform.is_rocm():
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
......@@ -69,4 +73,4 @@ def flash_attn_supports_fp8() -> bool:
def is_flash_attn_varlen_func_available() -> bool:
    """Return whether the current platform is CUDA, ROCm or XPU.

    Returns:
        True if ``current_platform`` reports CUDA, ROCm or XPU;
        False for every other platform.
    """
    # ROCm is included because the ROCm import branch of this module also
    # brings ``flash_attn_varlen_func`` into scope. A single return replaces
    # the earlier pair of returns, where the ROCm-aware one was unreachable.
    return (current_platform.is_cuda() or current_platform.is_rocm()
            or current_platform.is_xpu())
......@@ -17,9 +17,15 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version,
is_flash_attn_varlen_func_available)
from vllm.platforms import current_platform
if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
get_scheduler_metadata,
if not current_platform.is_rocm():
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
get_scheduler_metadata,
reshape_and_cache_flash)
else:
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
vllm_flash_attn_varlen_func,
reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config
......@@ -559,28 +565,53 @@ class FlashAttentionImpl(AttentionImpl):
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
)
if not current_platform.is_rocm():
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
)
else:
vllm_flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache=False,
)
return output
assert not use_local_attn, (
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment