Unverified commit b70ae096, authored by Nicolas Patry, committed by GitHub
Browse files

Prefix caching (#2402)



* Prefix caching WIP

* Fixing prefix attention.

* Fixing flashinfer import.

* Fixing black.

* Fixing medusa (still wrong outputs, but functional).

* Just medusa values now.

* Fixing medusa without prefix caching.

* Fixing prefix caching.

* Medusa requires reshaping.

* Removing the logs.

* Remove router.nix

* Fixup:

- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when user wants prefill logprobs.

* Update flake.lock

---------
Co-authored-by: Daniël de Kok <me@danieldk.eu>
parent 38773453
......@@ -316,10 +316,15 @@ impl State {
+ self.speculate
- 1;
match block_allocator
.allocate(tokens, entry.request.input_ids.clone())
.await
{
// If the user wants the prefill logprobs, we cannot reuse the cache.
// So no input_ids for the radix tree.
let input_ids = if entry.request.decoder_input_details {
None
} else {
entry.request.input_ids.clone()
};
match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
......
......@@ -205,6 +205,7 @@ pub struct RadixTrie {
/// call that a real time lookup would require.
time: u64,
}
impl Default for RadixTrie {
fn default() -> Self {
Self::new()
......
......@@ -900,11 +900,11 @@
]
},
"locked": {
"lastModified": 1723515680,
"narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=",
"lastModified": 1723602049,
"narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3",
"rev": "ea0bf33a11a26a62c60123c49d96011da396602c",
"type": "github"
},
"original": {
......
......@@ -84,6 +84,7 @@
grpcio-status
grpcio-tools
hf-transfer
ipdb
loguru
mamba-ssm
marlin-kernels
......
......@@ -6,7 +6,12 @@ from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
from .cuda import (
attention,
paged_attention,
reshape_and_cache,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
......
......@@ -76,7 +76,7 @@ def paged_attention(
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flash_infer import decode_state
from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward(
query.contiguous(),
......@@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":
def attention(
q,
k,
v,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
......@@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
causal=True,
softcap=0.0,
):
from text_generation_server.layers.attention.flash_infer import prefill_state
assert window_size_left == -1, "Windowing is not supported with flash infer"
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
return prefill_state.get().forward(
q,
k,
v,
return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
causal=causal,
window_left=window_size_left,
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
......@@ -249,6 +252,8 @@ elif V2:
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
......@@ -289,6 +294,8 @@ else:
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
softmax_scale,
......
......@@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
"prefill_state"
)
# Currently active flashinfer prefill wrapper that works on the paged KV
# cache. `use_prefill_with_paged_kv_state` sets it for the duration of its
# `with` block and resets it on exit.
prefill_with_paged_kv_state: ContextVar[
flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")
# Currently active flashinfer decode wrapper (paged KV cache). No default is
# given, so `.get()` raises `LookupError` until a value has been set.
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
......@@ -24,6 +28,78 @@ def get_workspace(device):
return workspace
def create_prefill_with_paged_kv_state(
    *,
    device: torch.device,
):
    """
    Build a flashinfer batch-prefill wrapper that uses the paged KV cache.

    The wrapper is backed by the workspace buffer that `get_workspace`
    returns for `device`; it is configured with the "NHD" KV layout and
    with CUDA graph support turned off.
    """
    wrapper_options = {"kv_layout": "NHD", "use_cuda_graph": False}
    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        get_workspace(device), **wrapper_options
    )
@contextmanager
def use_prefill_with_paged_kv_state(
    *,
    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
    block_tables: torch.Tensor,
    cu_seqlens: torch.Tensor,
    input_lengths: torch.Tensor,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    query_dtype: str = "float16",
):
    """
    Context manager to set the active flashinfer prefill state to the given
    `state` and parameters. This state will be used by all calls to the
    `attention` function while the context manager is active.

    Args:
        state: Wrapper to activate; `begin_forward` is called on entry and
            `end_forward` on exit.
        block_tables: Forwarded to flashinfer as `paged_kv_indices`.
        cu_seqlens: Forwarded to flashinfer as `qo_indptr`.
        input_lengths: Per-sequence lengths, used to derive the number of
            pages per sequence and the fill level of each last page.
        num_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        head_size: Dimension of a single head.
        page_size: Number of slots in one KV-cache page.
        query_dtype: Forwarded to flashinfer as `q_data_type`.

    On exit the context variable is always restored to its previous value,
    even if `end_forward` raises.
    """
    batch_size = input_lengths.shape[0]

    indptr = torch.zeros(
        batch_size + 1, device=input_lengths.device, dtype=torch.int32
    )
    # Round up to page size and then calculate the cumulative sum to get
    # the indices into the block table. `indptr[0]` stays 0, so
    # `indptr[i + 1] - indptr[i]` is the page count of sequence `i`.
    torch.add(input_lengths, page_size - 1, out=indptr[1:])
    indptr[1:].div_(page_size, rounding_mode="floor")
    indptr[1:].cumsum_(-1)

    # Get the lengths of the last page in a block.
    if page_size == 1:
        # With single-slot pages every last page holds exactly one entry.
        last_page_len = torch.ones(
            batch_size, dtype=torch.int32, device=input_lengths.device
        )
    else:
        last_page_len = torch.empty(
            batch_size, dtype=torch.int32, device=input_lengths.device
        )
        # ((length - 1) mod page_size) + 1, computed in place.
        torch.sub(input_lengths, 1, out=last_page_len)
        last_page_len.remainder_(page_size)
        last_page_len += 1

    token = prefill_with_paged_kv_state.set(state)
    try:
        state.begin_forward(
            qo_indptr=cu_seqlens,
            paged_kv_indptr=indptr,
            paged_kv_indices=block_tables,
            paged_kv_last_page_len=last_page_len,
            num_qo_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            q_data_type=query_dtype,
            page_size=page_size,
        )
        yield
    finally:
        # `ContextVar.set` always returns a token, so no `None` check is
        # needed. The nested `finally` guarantees the previous value is
        # restored even when `end_forward` itself fails.
        try:
            state.end_forward()
        finally:
            prefill_with_paged_kv_state.reset(token)
def create_prefill_state(
*,
device: torch.device,
......
......@@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
)
def forward(self, x):
    """
    Apply every head in `self.heads` to `x` and stack the results.

    Returns `None` when `self.heads` is empty; otherwise the per-head
    outputs are stacked along a new dimension at index 1.
    """
    if self.heads:
        per_head_logits = [head(x) for head in self.heads]
        return torch.stack(per_head_logits, dim=1)
    return None
......
......@@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
......@@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
self.softmax_scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment