Unverified Commit 47447ef0 authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Unify attention output handling (#2343)

- Always return the hidden states.
- Create the output tensor inside the `attention` and `paged_attention`
  functions.

This removes the difference between how the output is handled between
attention (output parameter) and paged attention (return value). This
also removes the assumption that the attention implementation can
write to an output tensor (in preparation of FlashInfer).
parent 22fb1be5
......@@ -34,7 +34,6 @@ def reshape_and_cache(
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
......@@ -85,7 +84,7 @@ def paged_attention(
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
if softcap is None:
softcap = 0.0
out2 = flash_attn_2_cuda.varlen_fwd(
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
......@@ -108,13 +107,15 @@ def paged_attention(
False, # return softmax
None, # generator
)
return out2[0]
return out[0]
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
from vllm._C import ops
out = torch.empty_like(query)
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
......@@ -200,13 +201,13 @@ except ImportError:
SUPPORTS_WINDOWING = V2
if V2:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
......@@ -214,6 +215,7 @@ if V2:
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
......@@ -238,7 +240,7 @@ if V2:
softcap,
False,
None,
)
)[0]
else:
......@@ -246,7 +248,6 @@ else:
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
......@@ -286,6 +287,8 @@ else:
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
return flash_attn_cuda.fwd(
q,
k,
......@@ -302,4 +305,4 @@ else:
False,
0,
None,
)
)[0]
......@@ -10,13 +10,14 @@ def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return ipex.llm.functional.varlen_attention(
q,
......@@ -49,7 +50,6 @@ def reshape_and_cache(
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
......@@ -59,6 +59,7 @@ def paged_attention(
seqlen: Seqlen,
max_s: int,
):
out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
......
......@@ -39,7 +39,6 @@ def reshape_and_cache(
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
......@@ -72,6 +71,8 @@ def paged_attention(
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths
out = torch.empty_like(query)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
......@@ -174,7 +175,6 @@ if ENGINE == "ck":
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
......@@ -184,6 +184,8 @@ if ENGINE == "ck":
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
......@@ -209,13 +211,14 @@ elif ENGINE == "triton":
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
......
......@@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention(
attn_output,
attn_output = paged_attention(
query,
kv_cache[0],
kv_cache[1],
......
......@@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module):
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
......
......@@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
# Reshape key and value and cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......@@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module):
slots,
)
# output
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -286,17 +286,13 @@ class FlashMQAttention(torch.nn.Module):
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
# output
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -304,7 +300,6 @@ class FlashMQAttention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
......@@ -235,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......@@ -254,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module):
# Decode
else:
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment