Commit 3b81dd6c authored by zhuwenwen's avatar zhuwenwen
Browse files

[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash...

[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash Attention to enable embedding models.
parent b5a9a18d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Attention layer ROCm GPUs.""" """Attention layer ROCm GPUs."""
import itertools
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
...@@ -351,28 +352,27 @@ def _get_seq_len_block_table_args( ...@@ -351,28 +352,27 @@ def _get_seq_len_block_table_args(
Decoder attn -> select entirely decoder self-attention-related fields Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths Encoder/decoder cross-attn -> select encoder sequence lengths
Encoder attn -> select encoder sequence lengths fields Encoder attn -> select encoder sequence lengths fields
Encoder-only attn -> select prefill sequence lengths with
bidirectional attention
Arguments: Arguments:
* attn_metadata: Attention metadata structure associated with attention op * attn_metadata: Attention metadata structure associated with attention op
* attn_type: encoder attention, decoder self-attention, * attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention encoder/decoder cross-attention, encoder-only
Returns: Returns:
* Appropriate sequence-lengths tensors for query and key * Appropriate sequence-lengths tensors for query and key
* Appropriate max sequence-length scalar * Appropriate max sequence-length scalar
* Causal masking flag
''' '''
partial_prefix_sum = 0
if attn_type == AttentionType.ENCODER: if attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None assert attn_metadata.encoder_seq_lens_tensor is not None
query_seq_start_loc = torch.tensor( query_seq_start_loc = torch.tensor(
[0] + [ list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
device=attn_metadata.encoder_seq_lens_tensor.device, device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype) dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
causal_mask = False causal_mask = False
...@@ -381,16 +381,27 @@ def _get_seq_len_block_table_args( ...@@ -381,16 +381,27 @@ def _get_seq_len_block_table_args(
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
query_seq_start_loc, attn_metadata.max_encoder_seq_len, query_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_lens, causal_mask) attn_metadata.encoder_seq_lens, causal_mask)
elif attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, we use the prefill sequence lengths
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
# Encoder-only models typically use bidirectional attention
causal_mask = False
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
max_seq_len, attn_metadata.seq_lens, causal_mask)
elif attn_type == AttentionType.DECODER: elif attn_type == AttentionType.DECODER:
# Decoder self-attention # Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run # Choose max_seq_len based on whether we are in prompt_run
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor( query_seq_start_loc = torch.tensor(
[0] + [ list(itertools.accumulate([0] + attn_metadata.seq_lens)),
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
device=attn_metadata.seq_lens_tensor.device, device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype) dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len max_seq_len = attn_metadata.max_prefill_seq_len
...@@ -402,21 +413,14 @@ def _get_seq_len_block_table_args( ...@@ -402,21 +413,14 @@ def _get_seq_len_block_table_args(
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None assert attn_metadata.encoder_seq_lens_tensor is not None
query_start_loc = torch.tensor( query_start_loc = torch.tensor(
[0] + [ list(itertools.accumulate([0] + attn_metadata.seq_lens)),
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.seq_lens
],
device=attn_metadata.encoder_seq_lens_tensor.device, device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype) dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
partial_prefix_sum = 0
assert attn_metadata.encoder_seq_lens is not None assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None assert attn_metadata.seq_lens_tensor is not None
key_seq_start_loc = torch.tensor( key_seq_start_loc = torch.tensor(
[0] + [ list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
partial_prefix_sum := partial_prefix_sum + i
for i in attn_metadata.encoder_seq_lens
],
device=attn_metadata.seq_lens_tensor.device, device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype) dtype=attn_metadata.seq_lens_tensor.dtype)
causal_mask = False causal_mask = False
...@@ -598,6 +602,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -598,6 +602,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
will match encoder sequence lengths, pass encoder sequence will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) max_encoder_seq_len)
* ENCODER_ONLY: bidirectional attention with no KV caching;
use prefill sequence attributes
Args: Args:
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
...@@ -622,7 +628,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -622,7 +628,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
assert value is None assert value is None
if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: # Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if self.attn_type not in [
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
] and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size)
...@@ -646,6 +656,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -646,6 +656,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if self.attn_type != AttentionType.ENCODER: if self.attn_type != AttentionType.ENCODER:
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
elif self.attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, all tokens are processed in one go
num_prefill_tokens = query.shape[0]
else: else:
assert attn_metadata.num_encoder_tokens is not None assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens num_prefill_tokens = attn_metadata.num_encoder_tokens
...@@ -656,8 +669,13 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -656,8 +669,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# QKV for prefill. # QKV for prefill.
query = query[:num_prefill_tokens] query = query[:num_prefill_tokens]
# For encoder-only and encoder models,
# we process all tokens at once
# For decoder and encoder-decoder,
# we may need to limit key/value to prefill tokens
if key is not None and value is not None \ if key is not None and value is not None \
and self.attn_type != AttentionType.ENCODER_DECODER: and self.attn_type not in [AttentionType.ENCODER_DECODER,
AttentionType.ENCODER_ONLY]:
key = key[:num_prefill_tokens] key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens] value = value[:num_prefill_tokens]
...@@ -692,7 +710,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -692,7 +710,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.alibi_slopes, self.alibi_slopes,
query.dtype, query.dtype,
seq_lens, seq_lens,
make_attn_mask=False) # type: ignore make_attn_mask=causal_mask) # type: ignore
out, _ = self.attn_func( out, _ = self.attn_func(
query, query,
key, key,
...@@ -718,7 +736,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -718,7 +736,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.alibi_slopes, self.alibi_slopes,
query.dtype, query.dtype,
attn_metadata.seq_lens, attn_metadata.seq_lens,
make_attn_mask=True) # type: ignore make_attn_mask=causal_mask) # type: ignore
query = query.movedim(0, query.dim() - 2) query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2) key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2) value = value.movedim(0, value.dim() - 2)
...@@ -745,7 +763,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -745,7 +763,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_q=prefill_meta.max_prefill_seq_len, max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len, max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=causal_mask,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
...@@ -759,25 +777,29 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -759,25 +777,29 @@ class ROCmFlashAttentionImpl(AttentionImpl):
output = out output = out
else: else:
# prefix-enabled attention # prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix( # not applicable for encoder-only models
query, if self.attn_type != AttentionType.ENCODER_ONLY:
key, output[:
value, num_prefill_tokens] = PagedAttention.forward_prefix(
self.kv_cache_dtype, query,
key_cache, key,
value_cache, value,
prefill_meta.block_tables, self.kv_cache_dtype,
prefill_meta.query_start_loc, key_cache,
prefill_meta.seq_lens_tensor, value_cache,
prefill_meta.context_lens_tensor, prefill_meta.block_tables,
prefill_meta.max_query_len, prefill_meta.query_start_loc,
self.alibi_slopes, prefill_meta.seq_lens_tensor,
self.sliding_window[0], prefill_meta.max_query_len,
layer._k_scale, self.alibi_slopes,
layer._v_scale, self.sliding_window[0],
) layer._k_scale,
layer._v_scale,
if decode_meta := attn_metadata.decode_metadata: )
# Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY):
# Decoding run. # Decoding run.
# Whether to use rocm custom paged attention or not # Whether to use rocm custom paged attention or not
num_seqs, num_heads, head_size = decode_query.shape num_seqs, num_heads, head_size = decode_query.shape
...@@ -906,4 +928,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, ...@@ -906,4 +928,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and (qtype == torch.half or qtype == torch.bfloat16) and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128) and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32) and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment