Commit 3b81dd6c authored by zhuwenwen's avatar zhuwenwen
Browse files

[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash...

[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash Attention to enable embedding models.
parent b5a9a18d
# SPDX-License-Identifier: Apache-2.0
"""Attention layer ROCm GPUs."""
import itertools
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
......@@ -351,28 +352,27 @@ def _get_seq_len_block_table_args(
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths
Encoder attn -> select encoder sequence lengths fields
Encoder-only attn -> select prefill sequence lengths with
bidirectional attention
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
encoder/decoder cross-attention, encoder-only
Returns:
* Appropriate sequence-lengths tensors for query and key
* Appropriate max sequence-length scalar
* Causal masking flag
'''
partial_prefix_sum = 0
if attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
causal_mask = False
......@@ -381,16 +381,27 @@ def _get_seq_len_block_table_args(
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_lens, causal_mask)
elif attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, we use the prefill sequence lengths
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
# Encoder-only models typically use bidirectional attention
causal_mask = False
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
max_seq_len, attn_metadata.seq_lens, causal_mask)
elif attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
......@@ -402,21 +413,14 @@ def _get_seq_len_block_table_args(
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
partial_prefix_sum = 0
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
key_seq_start_loc = torch.tensor(
[0] + [
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
causal_mask = False
......@@ -598,6 +602,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
* ENCODER_ONLY: bidirectional attention with no KV caching;
use prefill sequence attributes
Args:
query: shape = [num_tokens, num_heads * head_size]
......@@ -622,7 +628,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
assert value is None
if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if self.attn_type not in [
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
] and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
......@@ -646,6 +656,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if self.attn_type != AttentionType.ENCODER:
num_prefill_tokens = attn_metadata.num_prefill_tokens
elif self.attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, all tokens are processed in one go
num_prefill_tokens = query.shape[0]
else:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
......@@ -656,8 +669,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# QKV for prefill.
query = query[:num_prefill_tokens]
# For encoder-only and encoder models,
# we process all tokens at once
# For decoder and encoder-decoder,
# we may need to limit key/value to prefill tokens
if key is not None and value is not None \
and self.attn_type != AttentionType.ENCODER_DECODER:
and self.attn_type not in [AttentionType.ENCODER_DECODER,
AttentionType.ENCODER_ONLY]:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
......@@ -692,7 +710,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.alibi_slopes,
query.dtype,
seq_lens,
make_attn_mask=False) # type: ignore
make_attn_mask=causal_mask) # type: ignore
out, _ = self.attn_func(
query,
key,
......@@ -718,7 +736,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=True) # type: ignore
make_attn_mask=causal_mask) # type: ignore
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
......@@ -745,7 +763,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale,
causal=True,
causal=causal_mask,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
......@@ -759,25 +777,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
output = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
if decode_meta := attn_metadata.decode_metadata:
# not applicable for encoder-only models
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:
num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
# Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY):
# Decoding run.
# Whether to use rocm custom paged attention or not
num_seqs, num_heads, head_size = decode_query.shape
......@@ -906,4 +928,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment