"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "cdf65e793f96c57b64ec7523b0b3f2d6e7b9e9e6"
Unverified commit 47447ef0, authored by Daniël de Kok, committed by GitHub

Unify attention output handling (#2343)

- Always return the hidden states.
- Create the output tensor inside the `attention` and `paged_attention`
  functions.

This removes the difference in how the output is handled between
attention (output parameter) and paged attention (return value). It
also removes the assumption that the attention implementation can
write to a caller-provided output tensor (in preparation for FlashInfer).
parent 22fb1be5
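
As a rough illustration of what this means for call sites, here is a minimal, self-contained sketch of the two conventions. The function names and the toy computation are illustrative stand-ins, not TGI's real kernels or signatures:

import torch


def attention_out_param(q: torch.Tensor, out: torch.Tensor) -> None:
    # Old convention: the caller preallocates the output buffer and the
    # attention implementation writes into it in place.
    out.copy_(q * 2.0)  # placeholder for the actual attention computation


def attention_returning(q: torch.Tensor) -> torch.Tensor:
    # New convention: the function allocates its own output and returns it,
    # so callers no longer assume the kernel can write into a given buffer.
    out = torch.empty_like(q)
    out.copy_(q * 2.0)  # placeholder for the actual attention computation
    return out


query = torch.randn(4, 8)

# Before this commit, call sites looked roughly like this:
attn_output = torch.empty_like(query)
attention_out_param(query, attn_output)

# After this commit, attention and paged_attention are called the same way:
attn_output = attention_returning(query)

The per-model changes below are mechanical applications of this pattern: drop the preallocated attn_output and take the return value of attention/paged_attention instead.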
@@ -34,7 +34,6 @@ def reshape_and_cache(
 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -85,7 +84,7 @@ def paged_attention(
         # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
         if softcap is None:
             softcap = 0.0
-        out2 = flash_attn_2_cuda.varlen_fwd(
+        out = flash_attn_2_cuda.varlen_fwd(
             query,
             key_cache,
             value_cache,
@@ -108,13 +107,15 @@ def paged_attention(
             False,  # return softmax
             None,  # generator
         )
-        return out2[0]
+        return out[0]
     else:
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
         input_lengths = seqlen.input_lengths
         from vllm._C import ops

+        out = torch.empty_like(query)
+
         use_v1 = max_s <= 8192 and (
             max_num_partitions == 1 or num_seqs * num_heads > 512
         )
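
Aside: the use_v1 heuristic in the hunk above (choosing between PagedAttention V1 and V2) can be read with a concrete example. The numbers below are purely illustrative, and _PARTITION_SIZE is assumed to be 512 as in vLLM; this is a sketch, not part of the diff:

# Hypothetical shapes, for illustration only.
max_s = 4096                      # longest sequence in the batch
num_seqs, num_heads = 8, 32       # batch size and number of query heads
_PARTITION_SIZE = 512             # assumed partition size (vLLM's default)
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE  # -> 8

use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
print(use_v1)  # False: 8 partitions and only 256 (seqs * heads), so V2 is chosen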
@@ -200,13 +201,13 @@ except ImportError:
 SUPPORTS_WINDOWING = V2
 if V2:
     def attention(
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -214,6 +215,7 @@ if V2:
         causal=True,
         softcap=0.0,
     ):
+        out = torch.empty_like(q)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
@@ -238,7 +240,7 @@ if V2:
             softcap,
             False,
             None,
-        )
+        )[0]
 else:
@@ -246,7 +248,6 @@ else:
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -286,6 +287,8 @@ else:
             .reshape(original_shape[0], -1, original_shape[2])
         )

+        out = torch.empty_like(q)
+
         return flash_attn_cuda.fwd(
             q,
             k,
@@ -302,4 +305,4 @@ else:
             False,
             0,
             None,
-        )
+        )[0]
@@ -10,13 +10,14 @@ def attention(
     q,
     k,
     v,
-    out,
     cu_seqlens,
     max_s,
     softmax_scale,
     window_size_left=-1,
     causal=True,
 ):
+    out = torch.empty_like(q)
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
@@ -49,7 +50,6 @@ def reshape_and_cache(
 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -59,6 +59,7 @@ def paged_attention(
     seqlen: Seqlen,
     max_s: int,
 ):
+    out = torch.empty_like(query)
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
......
@@ -39,7 +39,6 @@ def reshape_and_cache(
 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -72,6 +71,8 @@ def paged_attention(
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
     input_lengths = input_lengths.input_lengths

+    out = torch.empty_like(query)
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -174,7 +175,6 @@ if ENGINE == "ck":
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -184,6 +184,8 @@ if ENGINE == "ck":
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")

+        out = torch.empty_like(q)
+
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -209,13 +211,14 @@ elif ENGINE == "triton":
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
         window_size_left=-1,
         causal=True,
     ):
+        out = torch.empty_like(q)
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
......
@@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module):
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module):
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

-        # Output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
             )
         # Decode
         else:
-            paged_attention(
-                attn_output,
+            attn_output = paged_attention(
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module):
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module):
         reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 qkv[:, 0],
                 kv_cache[0],
                 kv_cache[1],
......
@@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
         # Reshape key and value and cache
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             slots,
         )

-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -286,17 +286,13 @@ class FlashMQAttention(torch.nn.Module):
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -304,7 +300,6 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......
@@ -235,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -254,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
......