Commit b52c0d8c authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa_version

parent 92aba825
...@@ -28,7 +28,7 @@ elif current_platform.is_xpu(): ...@@ -28,7 +28,7 @@ elif current_platform.is_xpu():
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies # import here to avoid circular dependencies
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_xpu(): if current_platform.is_rocm() or current_platform.is_xpu():
return 2 return 2
try: try:
from vllm.vllm_flash_attn.flash_attn_interface import ( from vllm.vllm_flash_attn.flash_attn_interface import (
......
...@@ -644,7 +644,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -644,7 +644,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table=block_table, block_table=block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata, scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape), # q_descale=layer._q_scale.expand(descale_shape),
q_descale=layer._q_scale, q_descale=layer._q_scale,
k_descale=layer._k_scale, k_descale=layer._k_scale,
...@@ -674,7 +674,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -674,7 +674,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap, logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len, common_prefix_len=attn_metadata.common_prefix_len,
# fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata, suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale, q_descale=layer._q_scale,
...@@ -699,7 +699,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -699,7 +699,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap=self.logits_soft_cap, logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table, block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len, common_prefix_len=attn_metadata.common_prefix_len,
# fa_version=2, #self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata, suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale, q_descale=layer._q_scale,
...@@ -778,7 +778,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -778,7 +778,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window, window_size=self.sliding_window,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape), # q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape), # k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape), # v_descale=layer._v_scale.expand(descale_shape),
...@@ -913,7 +913,7 @@ def cascade_attention( ...@@ -913,7 +913,7 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata, scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None, if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
...@@ -937,7 +937,7 @@ def cascade_attention( ...@@ -937,7 +937,7 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata, scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None, if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
...@@ -966,7 +966,7 @@ def cascade_attention( ...@@ -966,7 +966,7 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata, scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None, if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
...@@ -990,7 +990,7 @@ def cascade_attention( ...@@ -990,7 +990,7 @@ def cascade_attention(
softcap=logits_soft_cap, softcap=logits_soft_cap,
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata, scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version, fa_version=fa_version,
q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None, if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment