Commit ead74dfa authored by zhuwenwen
Browse files

Support cascade_attention on ROCm

parent d6dc122f
...@@ -669,30 +669,56 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -669,30 +669,56 @@ class FlashAttentionImpl(AttentionImpl):
assert not use_local_attn, ( assert not use_local_attn, (
"Cascade attention does not support local attention.") "Cascade attention does not support local attention.")
# Cascade attention (rare case). # Cascade attention (rare case).
cascade_attention( if not current_platform.is_rocm():
output[:num_actual_tokens], cascade_attention(
query[:num_actual_tokens], output[:num_actual_tokens],
key_cache, query[:num_actual_tokens],
value_cache, key_cache,
cu_query_lens=attn_metadata.query_start_loc, value_cache,
max_query_len=attn_metadata.max_query_len, cu_query_lens=attn_metadata.query_start_loc,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens, max_query_len=attn_metadata.max_query_len,
prefix_kv_lens=attn_metadata.prefix_kv_lens, cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens, prefix_kv_lens=attn_metadata.prefix_kv_lens,
max_kv_len=attn_metadata.max_seq_len, suffix_kv_lens=attn_metadata.suffix_kv_lens,
softmax_scale=self.scale, max_kv_len=attn_metadata.max_seq_len,
alibi_slopes=self.alibi_slopes, softmax_scale=self.scale,
sliding_window=self.sliding_window, alibi_slopes=self.alibi_slopes,
logits_soft_cap=self.logits_soft_cap, sliding_window=self.sliding_window,
block_table=attn_metadata.block_table, logits_soft_cap=self.logits_soft_cap,
common_prefix_len=attn_metadata.common_prefix_len, block_table=attn_metadata.block_table,
fa_version=self.vllm_flash_attn_version, common_prefix_len=attn_metadata.common_prefix_len,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, fa_version=self.vllm_flash_attn_version,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
q_descale=layer._q_scale, suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
k_descale=layer._k_scale, q_descale=layer._q_scale,
v_descale=layer._v_scale, k_descale=layer._k_scale,
) v_descale=layer._v_scale,
)
else:
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=2, #self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
)
return output return output
...@@ -825,6 +851,31 @@ def cascade_attention( ...@@ -825,6 +851,31 @@ def cascade_attention(
v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None, if v_descale is not None else None,
) )
else:
prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
seqused_k=prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale,
causal=False,
window_size=sliding_window,
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
is_prefix_cache=True,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
...@@ -853,6 +904,31 @@ def cascade_attention( ...@@ -853,6 +904,31 @@ def cascade_attention(
v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None, if v_descale is not None else None,
) )
else:
suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
seqused_k=suffix_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window,
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
is_prefix_cache=True,
)
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment