Commit b52c0d8c authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa_version

parent 92aba825
......@@ -28,7 +28,7 @@ elif current_platform.is_xpu():
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
if current_platform.is_xpu():
if current_platform.is_rocm() or current_platform.is_xpu():
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
......
......@@ -644,7 +644,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
q_descale=layer._q_scale,
k_descale=layer._k_scale,
......@@ -674,7 +674,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
# fa_version=self.vllm_flash_attn_version,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
......@@ -699,7 +699,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
# fa_version=2, #self.vllm_flash_attn_version,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
......@@ -778,7 +778,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
......@@ -913,7 +913,7 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
......@@ -937,7 +937,7 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
......@@ -966,7 +966,7 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
......@@ -990,7 +990,7 @@ def cascade_attention(
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
fa_version=fa_version,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment