Unverified Commit b70ae096 authored by Nicolas Patry, committed by GitHub

Prefix caching (#2402)



* Prefix caching WIP

* Fixing prefix attention.

* Fixing flashinfer import.

* Fixing black.

* Fixing medusa (still wrong outputs, but functional).

* Just medusa values now.

* Fixing medusa without prefix caching.

* Fixing prefix caching.

* Medusa requires reshaping.

* Removing the logs.

* Remove router.nix

* Fixup:

- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when user wants prefill logprobs.

* Update flake.lock

---------
Co-authored-by: Daniël de Kok <me@danieldk.eu>
parent 38773453
@@ -316,10 +316,15 @@ impl State {
                 + self.speculate
                 - 1;
 
-            match block_allocator
-                .allocate(tokens, entry.request.input_ids.clone())
-                .await
-            {
+            // If users wants the prefill logprobs, we cannot reuse the cache.
+            // So no input_ids for the radix tree.
+            let input_ids = if entry.request.decoder_input_details {
+                None
+            } else {
+                entry.request.input_ids.clone()
+            };
+
+            match block_allocator.allocate(tokens, input_ids).await {
                 None => {
                     // Entry is over budget
                     // Add it back to the front
...
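The scheduler hunk above implements the last commit-message bullet: when a request asks for prefill logprobs (`decoder_input_details`), its `input_ids` are withheld from the block allocator, so the radix tree has nothing to match and no cached prefix is reused. Reusing a prefix would skip recomputation of those prompt tokens, and their logprobs would then be unavailable. Below is a minimal Python sketch of that decision only, with hypothetical names (`allocator_input_ids`, `wants_prefill_logprobs`); the real logic is the Rust shown above.

```python
from typing import Optional, Sequence


def allocator_input_ids(
    input_ids: Sequence[int], wants_prefill_logprobs: bool
) -> Optional[Sequence[int]]:
    """Return the token ids to seed the radix tree with, or None to opt out.

    Prefill logprobs require computing logits for every prompt token, so a
    request that asks for them must not start from a cached prefix.
    """
    if wants_prefill_logprobs:
        return None
    return input_ids


print(allocator_input_ids([1, 2, 3], wants_prefill_logprobs=False))  # [1, 2, 3]
print(allocator_input_ids([1, 2, 3], wants_prefill_logprobs=True))   # None
```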
@@ -205,6 +205,7 @@ pub struct RadixTrie {
     /// call that a real time lookup would require.
     time: u64,
 }
+
 impl Default for RadixTrie {
     fn default() -> Self {
         Self::new()
...
@@ -900,11 +900,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1723515680,
-      "narHash": "sha256-nHdKymsHCVIh0Wdm4MvSgxcTTg34FJIYHRQkQYaSuvk=",
+      "lastModified": 1723602049,
+      "narHash": "sha256-Z/noCSn9WPkv7O77dWKLcBxe4Ub4bWyNzsL5JhjaQfw=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "4ee3d9e9569f70d7bb40f28804d6fe950c81eab3",
+      "rev": "ea0bf33a11a26a62c60123c49d96011da396602c",
       "type": "github"
     },
     "original": {
...
@@ -84,6 +84,7 @@
       grpcio-status
       grpcio-tools
       hf-transfer
+      ipdb
       loguru
       mamba-ssm
       marlin-kernels
...
@@ -6,7 +6,12 @@ from .common import Seqlen
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
-    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+    from .cuda import (
+        attention,
+        paged_attention,
+        reshape_and_cache,
+        SUPPORTS_WINDOWING,
+    )
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "ipex":
...
@@ -76,7 +76,7 @@ def paged_attention(
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
     if ATTENTION == "flashinfer":
-        from text_generation_server.layers.attention.flash_infer import decode_state
+        from text_generation_server.layers.attention.flashinfer import decode_state
 
         return decode_state.get().forward(
             query.contiguous(),
@@ -221,9 +221,11 @@ SUPPORTS_WINDOWING = V2
 if ATTENTION == "flashinfer":
 
     def attention(
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -231,14 +233,15 @@ if ATTENTION == "flashinfer":
         causal=True,
         softcap=0.0,
     ):
-        from text_generation_server.layers.attention.flash_infer import prefill_state
+        assert window_size_left == -1, "Windowing is not supported with flash infer"
+        from text_generation_server.layers.attention.flashinfer import (
+            prefill_with_paged_kv_state,
+        )
 
-        return prefill_state.get().forward(
-            q,
-            k,
-            v,
+        return prefill_with_paged_kv_state.get().forward(
+            q.contiguous(),
             causal=causal,
-            window_left=window_size_left,
+            paged_kv_cache=(key_cache, value_cache),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
         )
@@ -249,6 +252,8 @@ elif V2:
         q,
         k,
         v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -289,6 +294,8 @@ else:
         q,
         k,
         v,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
         cu_seqlens,
         max_s,
         softmax_scale,
...
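The `attention` signature change above is what makes prefix reuse possible on the flashinfer path: during prefill, queries now attend against `(key_cache, value_cache)` through `BatchPrefillWithPagedKVCacheWrapper`, so keys and values written by an earlier request with the same prefix are visible, while the flash-attention fallbacks keep using the dense `k`/`v` tensors and simply accept the two extra arguments. The snippet below is a slow, pure-PyTorch reference of what such a paged-KV prefill computes, assuming an NHD-style cache layout of `[num_blocks, block_size, num_heads, head_dim]`; it is a sketch for intuition, not TGI's kernel.

```python
import math
import torch


def paged_prefill_attention_reference(
    q: torch.Tensor,            # [q_len, num_heads, head_dim], queries of the new tokens
    key_cache: torch.Tensor,    # [num_blocks, block_size, num_heads, head_dim]
    value_cache: torch.Tensor,  # same layout as key_cache
    block_table: torch.Tensor,  # block ids of this sequence (prefix blocks first)
    kv_len: int,                # total tokens already written to the cache
    softmax_scale: float,
) -> torch.Tensor:
    """Reference computation: new queries attend over all cached K/V of the
    sequence, including blocks filled by an earlier request with the same
    prefix, with a causal mask aligned to the end of the sequence."""
    # Gather this sequence's K/V blocks and trim to the real length.
    k = key_cache[block_table].reshape(-1, *key_cache.shape[2:])[:kv_len]
    v = value_cache[block_table].reshape(-1, *value_cache.shape[2:])[:kv_len]

    # [num_heads, q_len, kv_len] attention scores.
    scores = torch.einsum("qhd,khd->hqk", q, k) * softmax_scale

    # Causal mask: query i (overall position kv_len - q_len + i) may attend
    # to keys 0 .. kv_len - q_len + i.
    q_len = q.shape[0]
    q_pos = torch.arange(kv_len - q_len, kv_len).unsqueeze(1)
    k_pos = torch.arange(kv_len).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, float("-inf"))

    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v)


# Tiny smoke test: 2 cached prefix tokens, 3 new tokens being prefilled.
torch.manual_seed(0)
kc = torch.randn(4, 2, 1, 8)  # 4 blocks, block_size 2, 1 head, head_dim 8
vc = torch.randn(4, 2, 1, 8)
q = torch.randn(3, 1, 8)
out = paged_prefill_attention_reference(
    q, kc, vc, block_table=torch.tensor([2, 0, 3]), kv_len=5,
    softmax_scale=1.0 / math.sqrt(8),
)
print(out.shape)  # torch.Size([3, 1, 8])
```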
@@ -9,6 +9,10 @@ prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = Con
     "prefill_state"
 )
 
+prefill_with_paged_kv_state: ContextVar[
+    flashinfer.BatchPrefillWithPagedKVCacheWrapper
+] = ContextVar("prefill_with_paged_kv_state")
+
 decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
     "decode_state"
 )
@@ -24,6 +28,78 @@ def get_workspace(device):
     return workspace
 
 
+def create_prefill_with_paged_kv_state(
+    *,
+    device: torch.device,
+):
+    """Create a prefill state that uses the KV cache."""
+    workspace_buffer = get_workspace(device)
+    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout="NHD", use_cuda_graph=False
+    )
+
+
+@contextmanager
+def use_prefill_with_paged_kv_state(
+    *,
+    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
+    block_tables: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    input_lengths: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    page_size: int,
+    query_dtype: str = "float16",
+):
+    """
+    Context manager to set the active flashinfer prefill state to the given
+    `state` and parameters. This state will be used by all calls to the
+    `attention` function while the context manager is active.
+    """
+    indptr = torch.zeros(
+        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
+    )
+    # Round up to page size and then calculate the cumulative sum to get
+    # the indices into the block table.
+    torch.add(input_lengths, page_size - 1, out=indptr[1:])
+    indptr[1:].div_(page_size, rounding_mode="floor")
+    indptr[1:].cumsum_(-1)
+
+    # Get the lengths of the last page in a block.
+    if page_size == 1:
+        last_page_len = torch.ones(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+    else:
+        last_page_len = torch.empty(
+            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
+        )
+        torch.sub(input_lengths, 1, out=last_page_len)
+        last_page_len.remainder_(page_size)
+        last_page_len += 1
+
+    token = prefill_with_paged_kv_state.set(state)
+    try:
+        state.begin_forward(
+            qo_indptr=cu_seqlens,
+            paged_kv_indptr=indptr,
+            paged_kv_indices=block_tables,
+            paged_kv_last_page_len=last_page_len,
+            num_qo_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_size,
+            q_data_type=query_dtype,
+            page_size=page_size,
+        )
+        yield
+    finally:
+        state.end_forward()
+        if token is not None:
+            prefill_with_paged_kv_state.reset(token)
+
+
 def create_prefill_state(
     *,
     device: torch.device,
...
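A quick way to sanity-check the page-table arithmetic in `use_prefill_with_paged_kv_state` is to run it on concrete numbers. With hypothetical lengths `[5, 8, 1]` and `page_size = 4`, each sequence needs `ceil(len / page_size)` pages, the cumulative sum of which gives `indptr`, and its last page holds `((len - 1) % page_size) + 1` valid tokens (the `page_size == 1` branch above is just the special case where that is always 1). A standalone repetition of the same tensor operations:

```python
import torch

# Concrete example of the indptr / last_page_len arithmetic used above.
input_lengths = torch.tensor([5, 8, 1], dtype=torch.int32)
page_size = 4

# Pages per sequence, rounded up, then cumulative sum: indptr[i]..indptr[i+1]
# are sequence i's entries in the flattened block table.
indptr = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
print(indptr)  # -> [0, 2, 4, 5]

# Valid tokens in the last page of each sequence: a full page stays
# page_size instead of wrapping to 0.
last_page_len = ((input_lengths - 1) % page_size) + 1
print(last_page_len)  # -> [1, 4, 1]
```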
@@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
         )
 
     def forward(self, x):
+        if not self.heads:
+            return None
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
         return speculative_logits
...
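The two-line Medusa guard above also deserves a note: `torch.stack` raises on an empty list, so a `MedusaModel` configured with no heads would previously fail in `forward`; returning `None` speculative logits is the graceful fallback (presumably what several of the Medusa-related commit bullets were chasing). A self-contained stand-in, with a hypothetical `TinyMedusa` class, shows both paths:

```python
import torch


class TinyMedusa(torch.nn.Module):
    """Stripped-down stand-in for MedusaModel: one linear head per speculative token."""

    def __init__(self, hidden: int, vocab: int, n_heads: int):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            torch.nn.Linear(hidden, vocab) for _ in range(n_heads)
        )

    def forward(self, x: torch.Tensor):
        # Same guard as the diff: no heads -> no speculative logits.
        if not self.heads:
            return None
        return torch.stack([head(x) for head in self.heads], dim=1)


x = torch.randn(4, 16)
print(TinyMedusa(16, 32, n_heads=3)(x).shape)  # torch.Size([4, 3, 32])
print(TinyMedusa(16, 32, n_heads=0)(x))        # None (torch.stack([]) would raise)
```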
@@ -298,6 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -337,6 +337,8 @@ class DbrxAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -365,6 +365,8 @@ class DeepseekV2Attention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -238,6 +238,8 @@ class FlashGemma2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -232,6 +232,8 @@ class FlashGemmaAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -232,6 +232,8 @@ class FlashGPT2Attention(torch.nn.Module):
                 query,
                 key,
                 value,
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -220,6 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -219,6 +219,8 @@ class MistralAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -276,6 +276,8 @@ class MixtralAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -173,6 +173,8 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -194,6 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -137,6 +137,8 @@ class Qwen2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
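All of the remaining hunks apply the same mechanical change to every flash model (Cohere, DBRX, DeepSeek-V2, Gemma, Gemma2, GPT-2, Llama, Mistral, Mixtral, NeoX, Phi, Qwen2): the prefill call to `attention(...)` gains `kv_cache[0]` and `kv_cache[1]` right after the query/key/value arguments. On the flashinfer path the call-site ordering matters, since keys and values are read back from the cache rather than from the dense tensors, so the new tokens' K/V must already have been written into the cache (in the model code this happens via `reshape_and_cache` shortly before the attention call). Below is a condensed, hypothetical sketch of that per-layer flow, with a stub standing in for the real cache-write kernel; it is an illustration of the call pattern, not the actual TGI code.

```python
from typing import Tuple

import torch


def reshape_and_cache_stub(
    key: torch.Tensor, value: torch.Tensor,
    key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor,
) -> None:
    """Stand-in for the cache-write kernel: scatter K/V into their cache slots,
    assuming a [num_blocks, block_size, num_heads, head_dim] cache layout."""
    key_cache.view(-1, *key.shape[1:])[slots] = key
    value_cache.view(-1, *value.shape[1:])[slots] = value


def prefill_step(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor], slots: torch.Tensor,
    attention_fn, cu_seqlen_prefill, max_s: int, softmax_scale: float,
) -> torch.Tensor:
    """Hypothetical condensed per-layer prefill path after this commit."""
    # 1. Write this request's fresh K/V into the paged cache first...
    reshape_and_cache_stub(key, value, kv_cache[0], kv_cache[1], slots)
    # 2. ...then call attention with the new signature:
    #    attention(q, k, v, key_cache, value_cache, cu_seqlens, max_s, softmax_scale).
    #    Under flashinfer, K/V are read back from the cache (prefix blocks
    #    included); the flash-attention fallbacks keep using the dense tensors.
    return attention_fn(
        query, key, value, kv_cache[0], kv_cache[1],
        cu_seqlen_prefill, max_s, softmax_scale,
    )
```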