"vscode:/vscode.git/clone" did not exist on "a23f0158bb1aeb2c4078a032647c51f03c03a166"
Unverified Commit 59ea38cb authored by Daniël de Kok, committed by GitHub

Simplify the `attention` function (#2609)

* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KV_CACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args);
  a sketch of the resulting signature follows below.

* Fixup flashinfer support
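
The hunks below show the effect at a call site: the caller now passes everything by keyword and hands the whole `kv_cache` object to `attention`, which decides internally whether to read from the cache or from the freshly computed `key`/`value` tensors. As a rough illustration, a kwargs-only signature along these lines would achieve that; the parameter list, the `KVCache` stand-in, and the defaults here are assumptions for illustration, not the exact definitions from the repository:

```python
from dataclasses import dataclass

import torch


@dataclass
class KVCache:
    # Hypothetical stand-in; the real cache class lives in
    # text_generation_server.layers.attention and carries more state.
    key: torch.Tensor
    value: torch.Tensor


def attention(
    *,  # kwargs-only: callers cannot mix up the Tensor arguments positionally
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    seqlen: torch.Tensor,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
) -> torch.Tensor:
    """Single definition shared by all backends (flash-attn, flashinfer, ...);
    the implementation chooses whether to attend over `kv_cache` or over the
    freshly computed `key`/`value`, so callers no longer need
    `PREFILL_IN_KV_CACHE`."""
    raise NotImplementedError("backend-specific implementation goes here")
```
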
parent 5bbe1ce0
@@ -38,7 +38,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,
@@ -238,20 +237,20 @@ class Starcoder2Attention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv_to_cache[:, 0],
+                value=kv_to_cache[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
......