Unverified commit b70ae096 authored by Nicolas Patry, committed by GitHub

Prefix caching (#2402)



* Prefix caching WIP

* Fixing prefix attention.

* Fixing flashinfer import.

* Fixing black.

* Fixing medusa (still wrong outputs, but functional).

* Just medusa values now.

* Fixing medusa without prefix caching.

* Fixing prefix caching.

* Medusa requires reshaping.

* Removing the logs.

* Remove router.nix

* Fixup:

- Remove logs
- Disable VLMs (they do not work with prefix caching yet)
- Disable prefix caching when the user requests prefill logprobs.

* Update flake.lock

---------
Co-authored-by: Daniël de Kok <me@danieldk.eu>
parent 38773453
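For orientation, here is a minimal sketch (not part of the diff, names and numbers are illustrative) of what prefix caching changes during prefill: tokens whose KV entries are already cached are split off as the prefix, only the remaining suffix is fed through the model, and position ids continue after the prefix, mirroring the prefix_len handling introduced below.

    import torch

    prompt = [1, 15, 27, 99, 42]
    prefix_len = 3  # KV entries for the first 3 tokens are already cached
    prefix_ids, suffix = prompt[:prefix_len], prompt[prefix_len:]
    position_ids = torch.arange(prefix_len, len(prompt), dtype=torch.int32)
    assert suffix == [99, 42]
    assert position_ids.tolist() == [3, 4]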
@@ -208,6 +208,8 @@ class FlashRWAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+               kv_cache[0],
+               kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
@@ -326,6 +328,8 @@ class FlashRWLargeAttention(torch.nn.Module):
                query,
                torch.select(kv, dim=2, index=0),
                torch.select(kv, dim=2, index=1),
+               kv_cache[0],
+               kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
...
@@ -293,6 +293,8 @@ class FlashMQAttention(torch.nn.Module):
                query,
                torch.select(key_value, dim=1, index=0),
                torch.select(key_value, dim=1, index=1),
+               kv_cache[0],
+               kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
...
@@ -242,6 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
+               kv_cache[0],
+               kv_cache[1],
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
...
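The torch.select calls above split the fused KV projection into its key and value halves; the new kv_cache[0] / kv_cache[1] arguments additionally hand the paged key and value caches to the prefill attention call. A small self-contained sketch of the select step, assuming the usual [num_tokens, 2, num_kv_heads, head_size] layout:

    import torch

    num_tokens, num_kv_heads, head_size = 3, 2, 4
    kv = torch.randn(num_tokens, 2, num_kv_heads, head_size)  # fused key/value projection
    key = torch.select(kv, dim=1, index=0)    # key part
    value = torch.select(kv, dim=1, index=1)  # value part
    assert key.shape == value.shape == (num_tokens, num_kv_heads, head_size)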
@@ -43,6 +43,7 @@ from text_generation_server.models.globals import (
    ATTENTION,
    BLOCK_SIZE,
    CUDA_GRAPHS,
+   PREFIX_CACHING,
    get_adapter_to_index,
)
from text_generation_server.layers.attention import Seqlen
@@ -138,6 +139,9 @@ class FlashCausalLMBatch(Batch):
    block_tables_tensor: torch.Tensor
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
    slots: torch.Tensor
+   # size [b], containing the number of blocks that can be retrieved from the cache
+   prefix_lens: List[int]
+   prefix_lens_tensor: torch.Tensor
    max_seqlen: int
@@ -146,6 +150,9 @@ class FlashCausalLMBatch(Batch):
    prefill_next_token_indices: Optional[torch.tensor]
    prefill_cu_outlens: Optional[List[int]]
+   # Prefixes
+   prefix_ids: List[List[int]]
    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor
@@ -213,6 +220,7 @@ class FlashCausalLMBatch(Batch):
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
+       prefix_ids = []
        requests_idx_mapping = {}
        all_prefill_logprobs = True
@@ -230,7 +238,7 @@ class FlashCausalLMBatch(Batch):
        # Cumulative length
        cumulative_length = 0
-       cumulative_max_length = 0
+       cumulative_slot_tokens = 0
        prefill_out_cumulative_length = 0
        num_blocks = 0
@@ -240,6 +248,7 @@ class FlashCausalLMBatch(Batch):
        block_tables = []
        slots = []
+       prefix_lens = []
        # Parse batch
        for i, (r, tokenized_input) in enumerate(
@@ -255,6 +264,19 @@ class FlashCausalLMBatch(Batch):
            ):
                tokenized_input = tokenized_input[1:]
+           orig_input_length = len(tokenized_input)
+
+           if PREFIX_CACHING:
+               prefix_len = r.prefix_len
+               if prefix_len == orig_input_length:
+                   assert prefix_len > 0
+                   prefix_len -= 1
+           else:
+               prefix_len = 0
+
+           prefix_ids.append(tokenized_input[:prefix_len])
+           tokenized_input = tokenized_input[prefix_len:]
            input_length = len(tokenized_input)
            input_lengths.append(input_length)
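The guard above keeps at least one token out of the cached prefix, so there is always something left to prefill even when the whole prompt is already cached. A minimal sketch of the same logic, with an illustrative helper name:

    def trim_prefix(tokenized_input, prefix_len, prefix_caching=True):
        # If the whole prompt is cached, keep the last token for prefill.
        if prefix_caching:
            if prefix_len == len(tokenized_input):
                assert prefix_len > 0
                prefix_len -= 1
        else:
            prefix_len = 0
        return tokenized_input[:prefix_len], tokenized_input[prefix_len:]

    prefix_ids, remaining = trim_prefix([10, 11, 12, 13], prefix_len=4)
    assert prefix_ids == [10, 11, 12] and remaining == [13]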
@@ -264,7 +286,9 @@ class FlashCausalLMBatch(Batch):
            all_input_ids.append(tokenized_input)
            # Position ids
-           request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
+           request_position_ids = torch.arange(
+               prefix_len, orig_input_length, dtype=torch.int32
+           )
            position_ids.append(request_position_ids)
            # Add cumulative lengths of all previous inputs
@@ -288,11 +312,17 @@ class FlashCausalLMBatch(Batch):
            # Remove one as the first token des not have a past
            speculative_length = get_speculate()
            speculative_length = 0 if speculative_length is None else speculative_length
-           total_tokens = input_length + max_new_tokens - 1 + speculative_length
+
+           # Tokens that need to be mapped to blocks.
+           block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
+
+           # Tokens that need to be mapped to slots. We don't need slots for the
+           # cached prefix (if present).
+           slot_tokens = input_length + max_new_tokens - 1 + speculative_length
            # blocks and slots can be empty (for example in warmup)
            if not r.blocks:
-               needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+               needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
                request_blocks = [
                    b for b in range(num_blocks, num_blocks + needed_blocks)
                ]
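To make the block_tokens / slot_tokens distinction concrete, a worked example with illustrative numbers: blocks must cover the whole sequence, cached prefix included, while slots are only needed for the tokens that still have to be written to the cache.

    import math

    BLOCK_SIZE = 16
    orig_input_length = 40  # full prompt, including the cached prefix
    prefix_len = 32         # tokens already present in the KV cache
    input_length = orig_input_length - prefix_len
    max_new_tokens, speculative_length = 20, 0

    block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length  # 59
    slot_tokens = input_length + max_new_tokens - 1 + speculative_length        # 27
    needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)                        # 4
    assert (block_tokens, slot_tokens, needed_blocks) == (59, 27, 4)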
@@ -303,16 +333,20 @@ class FlashCausalLMBatch(Batch):
                ]
            else:
                request_blocks = r.blocks
-               request_slots = r.slots
+               request_slots = r.slots[
+                   prefix_len: #: orig_input_length + max_new_tokens + speculative_length
+               ]
            block_tables.append(request_blocks)
-           slots.extend(request_slots[:total_tokens])
+           slots.extend(request_slots)
+           prefix_lens.append(prefix_len)
            num_blocks += len(request_blocks)
-           start_slots.append(cumulative_max_length)
+           start_slots.append(cumulative_slot_tokens)
            request_slot_indices = torch.arange(
-               cumulative_max_length,
-               cumulative_max_length + input_length,
+               cumulative_slot_tokens,
+               cumulative_slot_tokens + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)
@@ -348,7 +382,7 @@ class FlashCausalLMBatch(Batch):
            # Update
            cumulative_length += input_length
-           cumulative_max_length += total_tokens
+           cumulative_slot_tokens += slot_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, len(request_blocks))
            max_length = max(
@@ -425,12 +459,14 @@ class FlashCausalLMBatch(Batch):
        )
        slots = torch.tensor(slots, dtype=torch.int64, device=device)
        block_tables_tensor = torch.zeros(
            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
        )
        for i, request_blocks in enumerate(block_tables):
            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
        block_tables_tensor = block_tables_tensor.to(device)
+       prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
        return cls(
            batch_id=pb.id,
@@ -445,6 +481,8 @@ class FlashCausalLMBatch(Batch):
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
+           prefix_lens=prefix_lens,
+           prefix_lens_tensor=prefix_lens_tensor,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
@@ -455,6 +493,7 @@ class FlashCausalLMBatch(Batch):
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
+           prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
@@ -510,8 +549,10 @@ class FlashCausalLMBatch(Batch):
        start_slots = []
        block_tables = []
        all_input_ids = []
+       prefix_ids = []
        input_lengths = []
+       prefix_lens = []
        prefix_offsets = []
        read_offsets = []
@@ -533,11 +574,14 @@ class FlashCausalLMBatch(Batch):
            # Get length
            request_input_length = self.input_lengths[idx]
+           prefix_len = self.prefix_lens[idx]
            max_seqlen = max(max_seqlen, request_input_length)
            all_input_ids.append(self.all_input_ids[idx])
+           prefix_ids.append(self.prefix_ids[idx])
            input_lengths.append(request_input_length)
+           prefix_lens.append(prefix_len)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])
@@ -582,6 +626,7 @@ class FlashCausalLMBatch(Batch):
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
+       prefix_lens_tensor = self.prefix_lens_tensor[indices]
        next_token_chooser = self.next_token_chooser.filter(indices)
        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
        speculative_ids = (
@@ -617,10 +662,13 @@ class FlashCausalLMBatch(Batch):
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
+           prefix_lens=prefix_lens,
+           prefix_lens_tensor=prefix_lens_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
+           prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
@@ -681,6 +729,7 @@ class FlashCausalLMBatch(Batch):
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
+       prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )
@@ -698,7 +747,9 @@ class FlashCausalLMBatch(Batch):
        start_slots = []
        block_tables = []
+       prefix_lens = []
        all_input_ids = []
+       prefix_ids = []
        input_lengths = []
        prefix_offsets = []
@@ -760,10 +811,14 @@ class FlashCausalLMBatch(Batch):
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]
+           prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
            start_slots.append(batch.start_slots + cumulative_slots)
            block_tables.extend(batch.block_tables)
+           prefix_lens.extend(batch.prefix_lens)
            all_input_ids.extend(batch.all_input_ids)
+           prefix_ids.extend(batch.prefix_ids)
            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
@@ -809,6 +864,8 @@ class FlashCausalLMBatch(Batch):
            slot_indices=slot_indices,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
+           prefix_lens=prefix_lens,
+           prefix_lens_tensor=prefix_lens_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
@@ -820,6 +877,7 @@ class FlashCausalLMBatch(Batch):
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
+           prefix_ids=prefix_ids,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
@@ -970,19 +1028,22 @@ class FlashCausalLM(Model):
        self.kv_cache = []
        if ATTENTION == "flashinfer":
-           from text_generation_server.layers.attention.flash_infer import (
+           from text_generation_server.layers.attention.flashinfer import (
                create_prefill_state,
                create_decode_state,
+               create_prefill_with_paged_kv_state,
            )
            self.prefill_state = create_prefill_state(device=device)
+           self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
+               device=device
+           )
-           if not CUDA_GRAPHS:
-               self.decode_state = create_decode_state(
-                   device=device,
-                   num_heads=self.num_heads,
-                   num_kv_heads=self.num_kv_heads,
-               )
+           self.decode_state = create_decode_state(
+               device=device,
+               num_heads=self.num_heads,
+               num_kv_heads=self.num_kv_heads,
+           )
        super().__init__(
            model_id=model_id,
@@ -1074,12 +1135,23 @@ class FlashCausalLM(Model):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-       input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-       block_tables = (
-           torch.arange(max_bt, dtype=torch.int32, device=self.device)
-           .repeat(bs)
-           .reshape((bs, max_bt))
-       )
+       input_lengths = [max_s] * bs
+       prefix_lengths = [0] * bs
+       input_lengths_tensor = (
+           torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+       )
+       prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
+       block_tables = torch.arange(
+           max_bt, dtype=torch.int32, device=self.device
+       ).repeat(bs)
+       block_tables = block_tables.reshape((bs, max_bt))
+       if ATTENTION == "flashinfer":
+           block_tables = block_tables_to_ragged(
+               block_tables=block_tables,
+               input_lengths=input_lengths,
+               prefix_lens=prefix_lengths,
+           )
        self.cuda_graphs[bs] = {
            "input_ids": input_ids,
@@ -1087,14 +1159,14 @@ class FlashCausalLM(Model):
            "kv_cache": self.kv_cache,
            "block_tables": block_tables,
            "slots": slots,
-           "input_lengths": input_lengths,
+           "input_lengths": input_lengths_tensor,
        }
-       input_lengths_ = Seqlen(input_lengths=input_lengths)
+       input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph
        if ATTENTION == "flashinfer":
-           from text_generation_server.layers.attention.flash_infer import (
+           from text_generation_server.layers.attention.flashinfer import (
                create_decode_state_cuda_graphs,
            )
@@ -1104,7 +1176,7 @@ class FlashCausalLM(Model):
            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
            state = create_decode_state_cuda_graphs(
                device=input_ids.device,
-               block_tables=block_tables.view(-1),
+               block_tables=block_tables,
                block_tables_ptr=block_tables_ptr,
                last_page_len=last_page_len,
                num_heads=self.num_heads,
@@ -1120,7 +1192,10 @@ class FlashCausalLM(Model):
            block_tables=block_tables,
            cu_seqlen_prefill=None,
            input_lengths=input_lengths,
+           input_lengths_tensor=input_lengths_tensor,
            state=state,
+           prefix_lens=prefix_lengths,
+           prefix_lens_tensor=prefix_lengths_tensor,
        ):
            self.model.forward(
                input_ids=input_ids,
@@ -1138,7 +1213,7 @@ class FlashCausalLM(Model):
        torch.cuda.synchronize()
        with torch.cuda.graph(graph, pool=MEM_POOL):
-           input_lengths = Seqlen(input_lengths=input_lengths)
+           input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
@@ -1146,7 +1221,7 @@ class FlashCausalLM(Model):
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
-               input_lengths=input_lengths,
+               input_lengths=input_lengths_tensor,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
@@ -1334,6 +1409,9 @@ class FlashCausalLM(Model):
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
+           prefix_lens_tensor = (
+               batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+           ).reshape(-1)
            # Add Copy the block tables for all members
            block_tables = (
@@ -1354,6 +1432,7 @@ class FlashCausalLM(Model):
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
+           prefix_lens_tensor = batch.prefix_lens_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices
@@ -1372,10 +1451,20 @@ class FlashCausalLM(Model):
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
+           input_lengths = input_lengths + prefix_lens_tensor
+           if PREFIX_CACHING:
+               block_tables = block_tables_to_ragged(
+                   block_tables=block_tables,
+                   input_lengths=batch.input_lengths,
+                   prefix_lens=batch.prefix_lens,
+               )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
-               input_lengths=input_lengths,
+               input_lengths=batch.input_lengths,
+               input_lengths_tensor=input_lengths,
+               prefix_lens=batch.prefix_lens,
+               prefix_lens_tensor=prefix_lens_tensor,
            ):
                input_lengths = Seqlen(input_lengths=input_lengths)
                logits, speculative_logits = self.model.forward(
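The input_lengths + prefix_lens_tensor sum above is what the attention metadata sees during prefill: the total sequence length per request (cached prefix plus new tokens), not just the uncached part. A tiny sketch:

    import torch

    input_lengths = torch.tensor([5, 8], dtype=torch.int32)       # uncached tokens per request
    prefix_lens_tensor = torch.tensor([3, 0], dtype=torch.int32)  # cached prefix per request
    total_lengths = input_lengths + prefix_lens_tensor
    assert total_lengths.tolist() == [8, 8]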
@@ -1399,20 +1488,32 @@ class FlashCausalLM(Model):
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-       cuda_graph["block_tables"][
-           : block_tables.shape[0], : block_tables.shape[1]
-       ] = block_tables
+       if ATTENTION == "flashinfer":
+           block_tables = block_tables_to_ragged(
+               block_tables=block_tables,
+               input_lengths=batch.input_lengths,
+               prefix_lens=batch.prefix_lens,
+           )
+           cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+       else:
+           cuda_graph["block_tables"][
+               : block_tables.shape[0], : block_tables.shape[1]
+           ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
-       cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+       cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+           input_lengths + prefix_lens_tensor
+       )
-       state = cuda_graph.get("state")
        with self._forward_context(
-           block_tables=block_tables,
+           block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
-           input_lengths=input_lengths,
-           state=state,
+           input_lengths=batch.input_lengths,
+           input_lengths_tensor=cuda_graph["input_lengths"],
+           prefix_lens=batch.prefix_lens,
+           prefix_lens_tensor=prefix_lens_tensor,
+           state=cuda_graph.get("state"),
        ):
            # Replay the graph
            cuda_graph["graph"].replay()
@@ -1610,6 +1711,7 @@ class FlashCausalLM(Model):
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
+           batch.prefix_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
@@ -1627,6 +1729,7 @@ class FlashCausalLM(Model):
            read_offset,
            stopping_criteria,
            all_input_ids,
+           prefix_ids,
            do_sample,
            seed,
            top_n_tokens,
@@ -1701,18 +1804,18 @@ class FlashCausalLM(Model):
                out_end_index = batch.prefill_cu_outlens[i + 1]
                # Remove generated token to only have prefill and add nan for first prompt token
-               request_prefill_logprobs = [float("nan")] + prefill_logprobs[
-                   out_start_index : out_end_index - 1
-               ]
+               request_prefill_logprobs = (
+                   [float("nan")] * (len(prefix_ids) + 1)
+               ) + prefill_logprobs[out_start_index : out_end_index - 1]
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
-                   prefill_token_ids,
+                   prefix_ids + prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = Tokens(
-                   prefill_token_ids,
+                   prefix_ids + prefill_token_ids,
                    request_prefill_logprobs,
                    prefill_texts,
                    is_special=[],
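Since cached prefix tokens are never run through the model in this prefill, they have no logprob; the code above therefore prepends one NaN per prefix token (plus one for the first prompt token) before the computed prefill logprobs. A simplified sketch of that padding:

    prefix_ids = [101, 102]                # tokens served from the prefix cache
    prefill_logprobs = [-0.5, -1.2, -0.1]  # logprobs computed for the uncached prompt tokens
    request_prefill_logprobs = (
        [float("nan")] * (len(prefix_ids) + 1)
    ) + prefill_logprobs[:-1]              # drop the entry belonging to the generated token
    assert len(request_prefill_logprobs) == len(prefix_ids) + len(prefill_logprobs)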
@@ -1794,33 +1897,68 @@ class FlashCausalLM(Model):
        *,
        block_tables: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
-       input_lengths: torch.Tensor,
+       input_lengths: List[int],
+       input_lengths_tensor: torch.Tensor,
+       prefix_lens: List[int],
+       prefix_lens_tensor: torch.Tensor,
        state: Optional[Any] = None,
    ) -> ContextManager:
        if ATTENTION != "flashinfer":
            return nullcontext()
-       from text_generation_server.layers.attention.flash_infer import (
+       from text_generation_server.layers.attention.flashinfer import (
            use_decode_state,
-           use_prefill_state,
+           use_prefill_with_paged_kv_state,
        )
+       # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
        if cu_seqlen_prefill is not None:
-           return use_prefill_state(
-               state=state if state is not None else self.prefill_state,
+           return use_prefill_with_paged_kv_state(
+               state=(
+                   state if state is not None else self.prefill_with_paged_kv_state
+               ),
+               # block_tables=block_tables_to_ragged(
+               #     block_tables=block_tables,
+               #     input_lengths=input_lengths,
+               #     prefix_lens=prefix_lens,
+               # ),
+               block_tables=block_tables,
                cu_seqlens=cu_seqlen_prefill,
+               input_lengths=input_lengths_tensor,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
+               page_size=BLOCK_SIZE,
            )
        else:
-           assert input_lengths is not None
+           assert input_lengths_tensor is not None
            return use_decode_state(
                state=state if state is not None else self.decode_state,
-               input_lengths=input_lengths,
-               block_tables=block_tables.view(-1),
+               input_lengths=input_lengths_tensor,
+               block_tables=block_tables,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                page_size=BLOCK_SIZE,
            )
+
+
+def block_tables_to_ragged(
+    *, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
+) -> torch.Tensor:
+    """Convert block table to ragged format compatible with FlashInfer."""
+    assert len(input_lengths) == len(prefix_lens)
+
+    total_len = sum(input_lengths) + sum(prefix_lens)
+    block_tables_ragged = torch.empty(
+        total_len, dtype=torch.int32, device=block_tables.device
+    )
+
+    offset = 0
+    for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
+        seq_len = prefix_len + input_length
+        block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
+        offset += seq_len
+
+    return block_tables_ragged
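A small usage example of block_tables_to_ragged as defined above (assuming it is imported from text_generation_server.models.flash_causal_lm): the padded [batch, max_blocks] table is flattened so that each sequence contributes exactly prefix_len + input_length entries.

    import torch
    from text_generation_server.models.flash_causal_lm import block_tables_to_ragged

    block_tables = torch.tensor(
        [[0, 1, 2, 0],
         [3, 4, 0, 0]],
        dtype=torch.int32,
    )
    ragged = block_tables_to_ragged(
        block_tables=block_tables, input_lengths=[2, 1], prefix_lens=[1, 1]
    )
    assert ragged.tolist() == [0, 1, 2, 3, 4]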
@@ -5,9 +5,8 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
-log_master(logger.info, f"Using Attention = {PREFIX_CACHING}")
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
+log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
@@ -29,7 +28,6 @@ elif ATTENTION == "flashinfer":
else:
    BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
...
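The new parsing means prefix caching is enabled only when USE_PREFIX_CACHING is set to "1" or "true" (case-insensitive); when it is off, ATTENTION defaults to "paged" instead of "flashinfer". A sketch of the parsing rule in isolation:

    def parse_prefix_caching(value: str) -> bool:
        # Mirrors the new PREFIX_CACHING parsing: only "1" or "true" enable it.
        return value.lower() in {"1", "true"}

    assert parse_prefix_caching("1") and parse_prefix_caching("True")
    assert not parse_prefix_caching("0") and not parse_prefix_caching("")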
@@ -11,7 +11,9 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
+   block_tables_to_ragged,
)
+from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
@@ -254,6 +256,8 @@ class VlmCausalLM(FlashCausalLM):
        trust_remote_code: bool,
        **kwargs,
    ):
+       if PREFIX_CACHING:
+           raise NotImplementedError("Vlm do not work with prefix caching yet")
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
@@ -310,6 +314,9 @@ class VlmCausalLM(FlashCausalLM):
            input_lengths = (
                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
            ).view(-1)
+           prefix_lens_tensor = (
+               batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+           ).reshape(-1)
            # Add Copy the block tables for all members
            block_tables = (
@@ -330,6 +337,7 @@ class VlmCausalLM(FlashCausalLM):
            block_tables = batch.block_tables_tensor
            slots = batch.slots[batch.slot_indices]
            input_lengths = batch.input_lengths_tensor
+           prefix_lens_tensor = batch.prefix_lens_tensor
            max_s = batch.max_seqlen
            lm_head_indices = batch.prefill_head_indices
@@ -349,43 +357,68 @@ class VlmCausalLM(FlashCausalLM):
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
-           input_lengths = Seqlen(input_lengths=input_lengths)
-           logits, speculative_logits = self.model.forward(
-               input_ids=input_ids,
-               position_ids=position_ids,
-               cu_seqlen_prefill=cu_seqlen_prefill,
-               kv_cache=kv_cache,
-               block_tables=block_tables,
-               slots=slots,
-               input_lengths=input_lengths,
-               max_s=max_s,
-               prefill_cache_indices=batch.prefill_cache_indices,
-               lm_head_indices=lm_head_indices,
-               pixel_values=batch.pixel_values,
-               pixel_attention_mask=batch.pixel_attention_mask,
-               image_sizes=batch.image_sizes,
-           )
-           if batch.prefill_cache_indices is not None:
-               batch.prefill_cache_indices = None
-           if batch.pixel_values is not None:
-               batch.pixel_values = None
-           if batch.pixel_attention_mask is not None:
-               batch.pixel_attention_mask = None
-           if batch.image_sizes is not None:
-               batch.image_sizes = None
-           return logits, speculative_logits
+           input_lengths = input_lengths + prefix_lens_tensor
+           if PREFIX_CACHING:
+               block_tables = block_tables_to_ragged(
+                   block_tables=block_tables,
+                   input_lengths=batch.input_lengths,
+                   prefix_lens=batch.prefix_lens,
+               )
+           with self._forward_context(
+               block_tables=block_tables,
+               cu_seqlen_prefill=cu_seqlen_prefill,
+               input_lengths=batch.input_lengths,
+               input_lengths_tensor=input_lengths,
+               prefix_lens=batch.prefix_lens,
+               prefix_lens_tensor=prefix_lens_tensor,
+           ):
+               input_lengths = Seqlen(input_lengths=input_lengths)
+               logits, speculative_logits = self.model.forward(
+                   input_ids=input_ids,
+                   position_ids=position_ids,
+                   cu_seqlen_prefill=cu_seqlen_prefill,
+                   kv_cache=kv_cache,
+                   block_tables=block_tables,
+                   slots=slots,
+                   input_lengths=input_lengths,
+                   max_s=max_s,
+                   prefill_cache_indices=batch.prefill_cache_indices,
+                   lm_head_indices=lm_head_indices,
+                   pixel_values=batch.pixel_values,
+                   pixel_attention_mask=batch.pixel_attention_mask,
+                   image_sizes=batch.image_sizes,
+               )
+               if batch.prefill_cache_indices is not None:
+                   batch.prefill_cache_indices = None
+               if batch.pixel_values is not None:
+                   batch.pixel_values = None
+               if batch.pixel_attention_mask is not None:
+                   batch.pixel_attention_mask = None
+               if batch.image_sizes is not None:
+                   batch.image_sizes = None
+               return logits, speculative_logits
        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-       cuda_graph["block_tables"][
-           : block_tables.shape[0], : block_tables.shape[1]
-       ] = block_tables
+       if ATTENTION == "flashinfer":
+           block_tables = block_tables_to_ragged(
+               block_tables=block_tables,
+               input_lengths=batch.input_lengths,
+               prefix_lens=batch.prefix_lens,
+           )
+           cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+       else:
+           cuda_graph["block_tables"][
+               : block_tables.shape[0], : block_tables.shape[1]
+           ] = block_tables
        cuda_graph["slots"].fill_(-1)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
-       cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+       cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
+           input_lengths + prefix_lens_tensor
+       )
        # Replay the graph
        cuda_graph["graph"].replay()
...