Unverified Commit 59ea38cb authored by Daniël de Kok, committed by GitHub

Simplify the `attention` function (#2609)

* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KV_CACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).

* Fixup flashinfer support
parent 5bbe1ce0
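
For context, a minimal sketch of the consolidated, kwargs-only signature this change implies. The argument names are taken from the call sites in the diff below; the real definition lives in `text_generation_server/layers/attention` and may differ per backend, so treat this as an illustration of the keyword-only pattern rather than the actual API.

import torch

def attention(
    *,  # keyword-only: avoids mixing up the many positional Tensor arguments
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache,          # cache object exposing .key / .value tensors (assumed shape of the interface)
    seqlen,            # per-batch sequence-length metadata
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
) -> torch.Tensor:
    # Single entry point for the prefill path; depending on the backend, the
    # implementation may attend over key/value directly or read them back from
    # kv_cache. Sketch only, not the actual implementation.
    ...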
@@ -38,7 +38,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,
@@ -238,20 +237,20 @@ class Starcoder2Attention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv_to_cache[:, 0],
+                value=kv_to_cache[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
...