"docs/vscode:/vscode.git/clone" did not exist on "42bb201fd6f79d6ed2e28e0263ffa891cd993c4c"
Unverified Commit 6d097697 authored by Micah Williamson's avatar Micah Williamson Committed by GitHub
Browse files

[ROCm] Support non-causal attention in ROCM_ATTN (#40176)


Signed-off-by: default avatarMicah Williamson <micah.williamson@amd.com>
parent 4506319a
......@@ -13,10 +13,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
RocmAttentionImpl,
RocmAttentionMetadata,
RocmAttentionMetadataBuilder,
)
......@@ -53,6 +53,10 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def supports_sink(cls) -> bool:
return True
@classmethod
def supports_non_causal(cls) -> bool:
return False
forward_includes_kv_cache_update: bool = False
@staticmethod
......@@ -140,7 +144,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
attn_metadata: RocmAttentionMetadata,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
......
......@@ -27,7 +27,6 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
)
......@@ -69,6 +68,9 @@ class RocmAttentionMetadata:
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
# DFlash drafting sets this to False via CommonAttentionMetadata.
causal: bool = True
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
......@@ -154,6 +156,7 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
causal=common_attn_metadata.causal,
)
return attn_metadata
......@@ -200,6 +203,10 @@ class RocmAttentionBackend(AttentionBackend):
# kernel, which is less efficient than the proper triton backends.
return False
@classmethod
def supports_non_causal(cls) -> bool:
return True
forward_includes_kv_cache_update: bool = False
@staticmethod
......@@ -301,7 +308,7 @@ class RocmAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
attn_metadata: RocmAttentionMetadata,
layer: torch.nn.Module,
) -> torch.Tensor:
"""Forward pass for encoder attention without KV cache.
......@@ -350,7 +357,7 @@ class RocmAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
attn_metadata: RocmAttentionMetadata,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
......@@ -438,6 +445,7 @@ class RocmAttentionImpl(AttentionImpl):
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
causal=attn_metadata.causal,
)
return output
......
......@@ -269,6 +269,7 @@ def chunked_prefill_paged_decode(
# Optional tensor for sinks
sinks=None,
is_block_table_ptr: bool = False,
causal: bool = True,
):
if sm_scale is None:
sm_scale = 1.0 / (query.shape[2] ** 0.5)
......@@ -300,6 +301,7 @@ def chunked_prefill_paged_decode(
skip_decode=True,
fp8_out_scale=output_scale,
sinks=sinks,
causal=causal,
)
block_size = value_cache.shape[3]
......
......@@ -89,6 +89,7 @@ def _fwd_kernel(
SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr,
CAUSAL: tl.constexpr = True,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
......@@ -283,10 +284,17 @@ def _fwd_kernel(
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
# compute query against itself (causal among queries by default;
# CAUSAL=False for bidirectional attention over query tokens, e.g. DFlash.)
if CAUSAL:
key_range_upper = block_mask * (start_m + 1) * BLOCK_M
else:
q_len_pad = (cur_batch_query_len + BLOCK_N - 1) // BLOCK_N * BLOCK_N
key_range_upper = block_mask * q_len_pad
for start_n in tl.range(
0,
block_mask * (start_m + 1) * BLOCK_M,
key_range_upper,
BLOCK_N,
loop_unroll_factor=num_unroll_request,
):
......@@ -302,14 +310,17 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
valid_kv = (start_n + offs_n[None, :]) < cur_batch_query_len
if CAUSAL:
attn_mask = valid_kv & (offs_m[:, None] >= (start_n + offs_n[None, :]))
else:
attn_mask = valid_kv
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
qk,
float("-inf"),
attn_mask = attn_mask & (
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW
)
qk = tl.where(attn_mask, qk, float("-inf"))
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
......@@ -656,6 +667,7 @@ def context_attention_fwd(
fp8_out_scale=None,
sinks=None,
is_block_table_ptr: bool = False,
causal: bool = True,
):
q_dtype_is_f32 = q.dtype is torch.float32
......@@ -722,6 +734,7 @@ def context_attention_fwd(
processed_b_loc = b_loc.to(torch.int32)
if alibi_slopes is not None:
assert causal, "Non-causal prefix attention is not supported with alibi"
assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32
......@@ -859,6 +872,7 @@ def context_attention_fwd(
num_warps=4,
num_stages=1,
USE_SINKS=sinks is not None,
CAUSAL=causal,
**extra_kargs,
)
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment