Unverified commit 6d097697, authored by Micah Williamson and committed by GitHub
Browse files

[ROCm] Support non-causal attention in ROCM_ATTN (#40176)


Signed-off-by: Micah Williamson <micah.williamson@amd.com>
parent 4506319a
...@@ -13,10 +13,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -13,10 +13,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
) )
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import ( from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend, RocmAttentionBackend,
RocmAttentionImpl, RocmAttentionImpl,
RocmAttentionMetadata,
RocmAttentionMetadataBuilder, RocmAttentionMetadataBuilder,
) )
...@@ -53,6 +53,10 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend): ...@@ -53,6 +53,10 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def supports_sink(cls) -> bool: def supports_sink(cls) -> bool:
return True return True
@classmethod
def supports_non_causal(cls) -> bool:
return False
forward_includes_kv_cache_update: bool = False forward_includes_kv_cache_update: bool = False
@staticmethod @staticmethod
...@@ -140,7 +144,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -140,7 +144,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: RocmAttentionMetadata,
output: torch.Tensor, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
......
...@@ -27,7 +27,6 @@ from vllm.v1.attention.backend import ( ...@@ -27,7 +27,6 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata, CommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.ops.chunked_prefill_paged_decode import ( from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode, chunked_prefill_paged_decode,
) )
...@@ -69,6 +68,9 @@ class RocmAttentionMetadata: ...@@ -69,6 +68,9 @@ class RocmAttentionMetadata:
scheduler_metadata: torch.Tensor | None = None scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None prefix_scheduler_metadata: torch.Tensor | None = None
# DFlash drafting sets this to False via CommonAttentionMetadata.
causal: bool = True
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]): class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
...@@ -154,6 +156,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat ...@@ -154,6 +156,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
prefix_kv_lens=prefix_kv_lens, prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens, suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata,
causal=common_attn_metadata.causal,
) )
return attn_metadata return attn_metadata
...@@ -200,6 +203,10 @@ class RocmAttentionBackend(AttentionBackend): ...@@ -200,6 +203,10 @@ class RocmAttentionBackend(AttentionBackend):
# kernel, which is less efficient than the proper triton backends. # kernel, which is less efficient than the proper triton backends.
return False return False
@classmethod
def supports_non_causal(cls) -> bool:
return True
forward_includes_kv_cache_update: bool = False forward_includes_kv_cache_update: bool = False
@staticmethod @staticmethod
...@@ -301,7 +308,7 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -301,7 +308,7 @@ class RocmAttentionImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
output: torch.Tensor, output: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: RocmAttentionMetadata,
layer: torch.nn.Module, layer: torch.nn.Module,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache. """Forward pass for encoder attention without KV cache.
...@@ -350,7 +357,7 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -350,7 +357,7 @@ class RocmAttentionImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: RocmAttentionMetadata,
output: torch.Tensor, output: torch.Tensor,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
...@@ -438,6 +445,7 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -438,6 +445,7 @@ class RocmAttentionImpl(AttentionImpl):
sm_scale=self.scale, sm_scale=self.scale,
output_scale=output_scale, output_scale=output_scale,
sinks=self.sinks, sinks=self.sinks,
causal=attn_metadata.causal,
) )
return output return output
......
...@@ -269,6 +269,7 @@ def chunked_prefill_paged_decode( ...@@ -269,6 +269,7 @@ def chunked_prefill_paged_decode(
# Optional tensor for sinks # Optional tensor for sinks
sinks=None, sinks=None,
is_block_table_ptr: bool = False, is_block_table_ptr: bool = False,
causal: bool = True,
): ):
if sm_scale is None: if sm_scale is None:
sm_scale = 1.0 / (query.shape[2] ** 0.5) sm_scale = 1.0 / (query.shape[2] ** 0.5)
...@@ -300,6 +301,7 @@ def chunked_prefill_paged_decode( ...@@ -300,6 +301,7 @@ def chunked_prefill_paged_decode(
skip_decode=True, skip_decode=True,
fp8_out_scale=output_scale, fp8_out_scale=output_scale,
sinks=sinks, sinks=sinks,
causal=causal,
) )
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
......
...@@ -89,6 +89,7 @@ def _fwd_kernel( ...@@ -89,6 +89,7 @@ def _fwd_kernel(
SKIP_DECODE: tl.constexpr, SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr, USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr, USE_FP8: tl.constexpr,
CAUSAL: tl.constexpr = True,
MAX_Q_LEN: tl.constexpr = 0, MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0, MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min, FP8_MIN: tl.constexpr = float8_info.min,
...@@ -283,10 +284,17 @@ def _fwd_kernel( ...@@ -283,10 +284,17 @@ def _fwd_kernel(
# block_mask is 0 when we're already past the current query length # block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask) # compute query against itself (causal among queries by default;
# CAUSAL=False for bidirectional attention over query tokens, e.g. DFlash.)
if CAUSAL:
key_range_upper = block_mask * (start_m + 1) * BLOCK_M
else:
q_len_pad = (cur_batch_query_len + BLOCK_N - 1) // BLOCK_N * BLOCK_N
key_range_upper = block_mask * q_len_pad
for start_n in tl.range( for start_n in tl.range(
0, 0,
block_mask * (start_m + 1) * BLOCK_M, key_range_upper,
BLOCK_N, BLOCK_N,
loop_unroll_factor=num_unroll_request, loop_unroll_factor=num_unroll_request,
): ):
...@@ -302,14 +310,17 @@ def _fwd_kernel( ...@@ -302,14 +310,17 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) valid_kv = (start_n + offs_n[None, :]) < cur_batch_query_len
if CAUSAL:
attn_mask = valid_kv & (offs_m[:, None] >= (start_n + offs_n[None, :]))
else:
attn_mask = valid_kv
if SLIDING_WINDOW > 0: if SLIDING_WINDOW > 0:
qk = tl.where( attn_mask = attn_mask & (
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW
qk,
float("-inf"),
) )
qk = tl.where(attn_mask, qk, float("-inf"))
# compute running maximum # compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
...@@ -656,6 +667,7 @@ def context_attention_fwd( ...@@ -656,6 +667,7 @@ def context_attention_fwd(
fp8_out_scale=None, fp8_out_scale=None,
sinks=None, sinks=None,
is_block_table_ptr: bool = False, is_block_table_ptr: bool = False,
causal: bool = True,
): ):
q_dtype_is_f32 = q.dtype is torch.float32 q_dtype_is_f32 = q.dtype is torch.float32
...@@ -722,6 +734,7 @@ def context_attention_fwd( ...@@ -722,6 +734,7 @@ def context_attention_fwd(
processed_b_loc = b_loc.to(torch.int32) processed_b_loc = b_loc.to(torch.int32)
if alibi_slopes is not None: if alibi_slopes is not None:
assert causal, "Non-causal prefix attention is not supported with alibi"
assert sinks is None, "Sinks arg is not supported with alibi" assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi" assert fp8_out_scale is None, "FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32 # need to reduce num. blocks when using fp32
...@@ -859,6 +872,7 @@ def context_attention_fwd( ...@@ -859,6 +872,7 @@ def context_attention_fwd(
num_warps=4, num_warps=4,
num_stages=1, num_stages=1,
USE_SINKS=sinks is not None, USE_SINKS=sinks is not None,
CAUSAL=causal,
**extra_kargs, **extra_kargs,
) )
return return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment