Unverified Commit aaec845f authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[ROCm] [Attention] Cleanup ROCm output passing (#16431)


Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
parent 7bdfd29a
...@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256 ...@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
class ROCmFlashAttentionBackend(AttentionBackend): class ROCmFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
...@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention) triton_attention)
self.attn_func = triton_attention self.triton_attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend") logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1): if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support " logger.warning("ROCm Triton FA does not currently support "
...@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func self.fa_attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend") logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
...@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"ROCm Naive FlashAttention does not support " "ROCm Naive FlashAttention does not support "
"attention logits soft capping.") "attention logits soft capping.")
self.attn_func = _sdpa_attention self.sdpa_attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend") logger.debug("Using naive (SDPA) attention in ROCmBackend")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
...@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
if key is not None: if key is not None:
assert value is not None assert value is not None
...@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert attn_metadata.num_encoder_tokens is not None assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens num_prefill_tokens = attn_metadata.num_encoder_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
...@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype, query.dtype,
seq_lens, seq_lens,
make_attn_mask=causal_mask) # type: ignore make_attn_mask=causal_mask) # type: ignore
out, _ = self.attn_func( self.triton_attn_func(
query, query,
key, key,
value, value,
None, output[:num_prefill_tokens],
query_seq_start_loc, query_seq_start_loc,
key_seq_start_loc, key_seq_start_loc,
query_max_seq_len, query_max_seq_len,
...@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key = key.movedim(0, key.dim() - 2) key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2) value = value.movedim(0, value.dim() - 2)
# sdpa math backend attention # sdpa math backend attention
out = self.attn_func( self.sdpa_attn_func(
query, query,
key, key,
value, value,
output[:num_prefill_tokens],
query_seq_start_loc, query_seq_start_loc,
num_prefill_tokens, num_prefill_tokens,
self.num_heads, self.num_heads,
...@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks, attn_masks,
) )
else: else:
out = self.attn_func( # upstream FA does not support an output arg, copy
output[:num_prefill_tokens] = self.fa_attn_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
) )
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
if output.shape[0] > num_prefill_tokens:
output[:num_prefill_tokens] = out
else:
output = out
else: else:
# prefix-enabled attention - # prefix-enabled attention -
# not applicable for encoder-only models # not applicable for encoder-only models
...@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
device=output.device, device=output.device,
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if num_prefill_tokens > 0:
out = output[num_prefill_tokens:]
else:
out = output
query_start_loc = None query_start_loc = None
ops.paged_attention_rocm( ops.paged_attention_rocm(
out, output[num_prefill_tokens:],
exp_sums, exp_sums,
max_logits, max_logits,
tmp_output, tmp_output,
...@@ -878,7 +872,8 @@ def _sdpa_attention( ...@@ -878,7 +872,8 @@ def _sdpa_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
seq_lens: List[int], output: torch.Tensor,
seq_lens: torch.Tensor,
num_tokens: int, num_tokens: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
...@@ -886,9 +881,9 @@ def _sdpa_attention( ...@@ -886,9 +881,9 @@ def _sdpa_attention(
attn_masks: Optional[List[torch.Tensor]] = None, attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
start = 0 start = 0
output = torch.empty((num_tokens, num_heads, head_size), assert output.shape == (num_tokens, num_heads, head_size)
dtype=query.dtype, assert output.dtype == query.dtype
device=query.device) assert output.device == query.device
for i, seq_len in enumerate(seq_lens): for i, seq_len in enumerate(seq_lens):
end = start + seq_len end = start + seq_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment