Commit f7beb354 authored by zhuwenwen's avatar zhuwenwen
Browse files

set v1 attention use fa

parent d3632a8b
...@@ -275,10 +275,9 @@ class RocmPlatform(Platform): ...@@ -275,10 +275,9 @@ class RocmPlatform(Platform):
# logger.info_once("Using Triton backend on V1 engine.") # logger.info_once("Using Triton backend on V1 engine.")
# return TRITON_ATTN_VLLM_V1 # return TRITON_ATTN_VLLM_V1
if envs.is_set("VLLM_USE_FLASH_ATTN_PA") and envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64: if envs.VLLM_USE_FLASH_ATTN_PA and block_size == 64:
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)") logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return FLASH_ATTN_V1 return FLASH_ATTN_V1
else: else:
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1 return TRITON_ATTN_VLLM_V1
......
...@@ -131,7 +131,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -131,7 +131,7 @@ class FlashAttentionBackend(AttentionBackend):
value_stride_order = (0, 1, 2, 3) value_stride_order = (0, 1, 2, 3)
elif cache_layout == "HND": elif cache_layout == "HND":
key_stride_order = (0, 2, 1, 3) key_stride_order = (0, 2, 1, 3)
value_stride_order = (0, 3, 1, 2) value_stride_order = (0, 2, 1, 3)
else: else:
raise ValueError(f"Unknown cache layout format {cache_layout}.") raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order return key_stride_order, value_stride_order
...@@ -637,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -637,7 +637,7 @@ class FlashAttentionImpl(AttentionImpl):
else: else:
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:") print("PA SIZE:")
print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}") print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}") print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment