Unverified Commit b70ae096 authored by Nicolas Patry, committed by GitHub

Prefix caching (#2402)



* Prefix caching WIP

* Fixing prefix attention.

* Fixing flashinfer import.

* Fixing black.

* Fixing medusa (still wrong outputs, but functional).

* Just medusa values now.

* Fixing medusa without prefix caching.

* Fixing prefix caching.

* Medusa requires reshaping.

* Removing the logs.

* Remove router.nix

* Fixup:

- Remove logs
- Disable VLMs (they do not work)
- Disable prefix caching when user wants prefill logprobs.

* Update flake.lock

---------
Co-authored-by: Daniël de Kok <me@danieldk.eu>
parent 38773453
@@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
@@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
+                kv_cache[0],
+                kv_cache[1],
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
...
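All four call sites above follow the same pattern: the prefill attention call now also receives the paged KV cache (kv_cache[0] and kv_cache[1]), so when part of the prompt is already cached the kernel can read the prefix's keys and values from the cache instead of recomputing them. The toy below is plain PyTorch, not the repo's flash/paged kernels, and every name in it is invented for illustration; it only sketches the idea that suffix queries attend over cached prefix K/V concatenated with freshly computed suffix K/V.

import torch

# Conceptual toy of prefill with a cached prefix (illustration only, not the
# attention() wrapper used in the diff above).
def toy_prefill_with_prefix(query_suffix, k_cache_prefix, v_cache_prefix,
                            k_suffix, v_suffix, scale):
    # Keys/values visible to the suffix queries: cached prefix, then fresh suffix.
    k = torch.cat([k_cache_prefix, k_suffix], dim=0)   # [prefix_len + suffix_len, d]
    v = torch.cat([v_cache_prefix, v_suffix], dim=0)
    scores = (query_suffix @ k.T) * scale              # [suffix_len, prefix_len + suffix_len]
    suffix_len, total_len = scores.shape
    prefix_len = total_len - suffix_len
    # Causal mask: suffix token i sees the whole prefix plus suffix tokens <= i.
    causal = torch.arange(total_len)[None, :] <= (prefix_len + torch.arange(suffix_len))[:, None]
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v           # [suffix_len, d]

# Tiny smoke test with random tensors.
d, prefix_len, suffix_len = 8, 5, 3
out = toy_prefill_with_prefix(
    torch.randn(suffix_len, d),
    torch.randn(prefix_len, d), torch.randn(prefix_len, d),
    torch.randn(suffix_len, d), torch.randn(suffix_len, d),
    d ** -0.5,
)
assert out.shape == (suffix_len, d)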
@@ -5,9 +5,8 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
-log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
@@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer":
 else:
     BLOCK_SIZE = 16
 cuda_graphs = os.getenv("CUDA_GRAPHS")
 if cuda_graphs is not None:
     try:
...
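One note on the globals.py hunk: with the old default of False, os.getenv returns the raw string whenever the variable is set, so USE_PREFIX_CACHING=0 would still count as enabled. The sketch below shows the explicit truthiness parsing the new line uses; env_flag is an illustrative helper, not something that exists in the repo.

import os

def env_flag(name: str, default: str = "0") -> bool:
    # Same parsing rule as the new PREFIX_CACHING line above.
    return os.getenv(name, default).lower() in {"1", "true"}

os.environ["USE_PREFIX_CACHING"] = "0"
old_style = os.getenv("USE_PREFIX_CACHING", False)   # -> "0", a non-empty (truthy) string
assert bool(old_style) is True                       # surprising: caching would look enabled
assert env_flag("USE_PREFIX_CACHING") is False       # parsed correctly: disabled

os.environ["USE_PREFIX_CACHING"] = "true"
assert env_flag("USE_PREFIX_CACHING") is True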
@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
     FlashCausalLMBatch,
     FlashCausalLM,
+    block_tables_to_ragged,
 )
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
         trust_remote_code: bool,
         **kwargs,
     ):
+        if PREFIX_CACHING:
+            raise NotImplementedError("Vlm do not work with prefix caching yet")
         if processor_kwargs is None:
             processor_kwargs = {}
         self.processor = processor_class.from_pretrained(
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            ).reshape(-1)

             # Add Copy the block tables for all members
             block_tables = (
@@ -330,6 +337,7 @@
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
@@ -349,6 +357,21 @@
         else:
             cuda_graph = None
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = input_lengths + prefix_lens_tensor
+            if PREFIX_CACHING:
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths=batch.input_lengths,
+                input_lengths_tensor=input_lengths,
+                prefix_lens=batch.prefix_lens,
+                prefix_lens_tensor=prefix_lens_tensor,
+            ):
                 input_lengths = Seqlen(input_lengths=input_lengths)
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
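The length bookkeeping in the hunk above is worth spelling out: input_lengths counts only the new tokens fed through the model in this forward pass, while prefix_lens counts tokens whose keys/values already sit in the cache, so the attention metadata needs their sum. A tiny illustration with made-up values:

import torch

# Hypothetical batch of two sequences; the first reuses a 2-token cached prefix.
input_lengths_tensor = torch.tensor([3, 4])  # new tokens embedded this forward pass
prefix_lens_tensor = torch.tensor([2, 0])    # tokens already present in the KV cache
total_lengths = input_lengths_tensor + prefix_lens_tensor
assert total_lengths.tolist() == [5, 4]      # lengths the attention kernel must see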
@@ -379,13 +402,23 @@
             # Static inputs are potentially padded
             cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
             cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    prefix_lens=batch.prefix_lens,
+                )
+                cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+            else:
+                cuda_graph["block_tables"][
+                    : block_tables.shape[0], : block_tables.shape[1]
+                ] = block_tables
             cuda_graph["slots"].fill_(-1)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+                input_lengths + prefix_lens_tensor
+            )

             # Replay the graph
             cuda_graph["graph"].replay()
...
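block_tables_to_ragged, imported from flash_causal_lm above, is always called with block_tables, input_lengths and prefix_lens keyword arguments when flashinfer is the attention backend. Its implementation is not part of this diff; the sketch below is only a plausible reconstruction under the assumption of a per-token block layout (i.e. assuming the flashinfer path uses a block size of 1), flattening one block-table row per sequence into a single ragged tensor covering the cached prefix plus the new tokens.

import torch
from typing import List

# Hedged sketch, not the code from flash_causal_lm.py: flatten a padded
# [batch, max_blocks] block table into one ragged 1-D tensor, keeping
# prefix_len + input_length entries per sequence.
def block_tables_to_ragged_sketch(
    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
) -> torch.Tensor:
    assert len(input_lengths) == len(prefix_lens)
    total = sum(input_lengths) + sum(prefix_lens)
    ragged = torch.empty(total, dtype=torch.int32, device=block_tables.device)
    offset = 0
    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
        seq_len = prefix_len + input_length          # cached prefix + new tokens
        ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
        offset += seq_len
    return ragged

# Two sequences: the first reuses a 2-token cached prefix, the second has none.
bt = torch.arange(12, dtype=torch.int32).reshape(2, 6)
print(block_tables_to_ragged_sketch(block_tables=bt, input_lengths=[3, 4], prefix_lens=[2, 0]))
# tensor([0, 1, 2, 3, 4, 6, 7, 8, 9], dtype=torch.int32)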