"docs/MODEL_ZOO.md" did not exist on "48647f794b2dbf641f98df22cd17bab8b8afe8a9"
Unverified commit a6a0c97e authored by OlivierDehaene, committed by GitHub

feat: prefill chunking (#2600)



* wip

* rollback

* refactor to use prefix/postfix naming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot of the time).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 704a58c8
import pytest
import base64
import asyncio
......@@ -15,22 +14,8 @@ async def mllama(mllama_handle):
return mllama_handle.client
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,
......
......@@ -68,7 +68,7 @@ fn get_config(
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
......@@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
}
}
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching on paged attention");
prefix_caching = Some("0".to_string());
}
let attention = attention.unwrap_or("flashinfer".to_string());
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
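For orientation, a minimal Python sketch of the resolution shown above; the function name is illustrative, the config-derived step elided by this hunk is skipped, and only the `PREFIX_CACHING`/`ATTENTION` environment variables, the paged-attention special case, and the final defaults come from the diff.

import os

def resolve_attention_sketch():
    # `USE_PREFIX_CACHING` was renamed to `PREFIX_CACHING` in this change.
    prefix_caching = os.environ.get("PREFIX_CACHING")
    attention = os.environ.get("ATTENTION")

    # (config-derived adjustments elided in this hunk)

    # Explicitly requesting paged attention without setting PREFIX_CACHING
    # disables prefix caching.
    if attention == "paged" and prefix_caching is None:
        prefix_caching = "0"

    return prefix_caching or "true", attention or "flashinfer"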
......@@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> {
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = {
......@@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
......@@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> {
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
......
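With prefill chunking, a prompt larger than `--max-batch-prefill-tokens` can be prefilled over several forward passes, which is presumably why the two strict `max_batch_prefill_tokens` checks above could be dropped. Rough, illustrative arithmetic (both values are made up):

import math

max_input_tokens = 32_768          # hypothetical maximum prompt length
max_batch_prefill_tokens = 4_096   # hypothetical per-forward prefill budget

# Previously the launcher rejected this combination outright; with chunking
# the prompt is simply prefilled in ceil(32768 / 4096) = 8 chunks.
print(math.ceil(max_input_tokens / max_batch_prefill_tokens))  # 8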
......@@ -34,6 +34,10 @@ message InfoResponse {
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
bool support_chunking = 6;
bool use_prefix_caching = 7;
string attention_impl = 8;
uint32 block_size = 9;
}
/// Empty request
......@@ -135,10 +139,14 @@ message Request {
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Tokens that can be retrieved from the KV cache.
/// This value is set for the first prefill and never reset
uint32 cache_len = 12;
/// Whether to add special tokens
bool add_special_tokens = 13;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 14;
}
message Batch {
......@@ -163,6 +171,8 @@ message CachedBatch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
}
enum FinishReason {
......@@ -220,6 +230,8 @@ message FilterBatchResponse {
message PrefillRequest {
/// Batch
Batch batch = 1;
/// Optional cached batch
CachedBatch cached_batch = 2;
}
message PrefillResponse {
......@@ -233,6 +245,8 @@ message PrefillResponse {
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message DecodeRequest {
......
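A minimal sketch of how a client of this proto might populate the new `Request` fields from Python; only fields visible in this diff are set, and the surrounding plumbing (request id, inputs, parameters) is omitted, so treat it as illustrative rather than a real call site.

from text_generation_server.pb import generate_pb2

# Hypothetical values: the prompt has 512 tokens, 128 are already in the KV
# cache, and this prefill pass should compute the next 256 tokens.
request = generate_pb2.Request(
    cache_len=128,           # tokens retrievable from the KV cache, never reset
    chunk_len=256,           # only set when the model supports chunking
    add_special_tokens=True,
)

# `chunk_len` is declared `optional`, so presence can be tested explicitly,
# which is what the Python server does before slicing the prompt.
assert request.HasField("chunk_len")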
......@@ -18,45 +18,6 @@ use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
#[derive(PartialEq)]
pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value")
}
}
impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError),
}
}
}
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
......
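The attention-to-block-size mapping encoded in the removed enum above presumably now reaches the router through the new `block_size` field of `InfoResponse`; for reference, the same mapping as a small Python table (the constant name is illustrative).

# Block size per attention implementation, as listed in the removed Rust enum.
BLOCK_SIZE_BY_ATTENTION = {
    "flashinfer": 1,
    "flashdecoding": 256,
    "paged": 16,
}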
......@@ -2,7 +2,7 @@ import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
......
......@@ -9,6 +9,9 @@ from typing import Callable, Any
class ExceptionInterceptor(AsyncServerInterceptor):
def __init__(self, shutdown_callback):
self.shutdown_callback = shutdown_callback
async def intercept(
self,
method: Callable,
......@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
# Runtime Error cannot be recovered from
if isinstance(err, RuntimeError):
exit(1)
self.shutdown_callback()
if torch.cuda.is_available():
torch.cuda.empty_cache()
......
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max):
if SYSTEM == "rocm":
return self
self.input_lengths = torch.clamp(self.input_lengths, max=max)
return self
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(
self,
input_lengths,
cache_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
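A small decode-time usage of the unified `Seqlen` above, showing how `cu_seqlen_k` accumulates input and cache lengths; the tensors and values are made up for illustration.

import torch

# Two sequences decoding one new token each, with 7 and 3 tokens already cached.
input_lengths = torch.tensor([1, 1], dtype=torch.int32)
cache_lengths = torch.tensor([7, 3], dtype=torch.int32)

seqlen = Seqlen(input_lengths=input_lengths, cache_lengths=cache_lengths, max_k=8)
# seqlen.cu_seqlen_q -> tensor([0, 1, 2])   one query token per sequence
# seqlen.cu_seqlen_k -> tensor([0, 8, 12])  cumulative (input + cache) lengths
# seqlen.max_q       -> 1                   set automatically when cu_seqlen_q is None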
......@@ -123,7 +123,7 @@ def paged_attention(
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
from vllm._C import ops
out = torch.empty_like(query)
......@@ -244,117 +244,232 @@ if ATTENTION == "flashinfer":
window_left=window_size_left,
)
elif V2:
elif ATTENTION == "flashdecoding":
if V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
def attention(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
elif ATTENTION == "paged":
if V2:
def attention(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
None, # block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
else:
raise RuntimeError(f"Unknwon attention {ATTENTION}")
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
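The new expression narrows when prefill writes directly into the KV cache: `flashinfer` always does, `flashdecoding` only with flash-attention v2, and the `paged` path never does (previously `paged` with v2 also qualified). A tiny enumeration, purely for illustration:

def prefill_in_kv_cache(attention: str, v2: bool) -> bool:
    # Mirrors the constant above.
    return attention == "flashinfer" or (attention == "flashdecoding" and v2)

for impl in ("flashinfer", "flashdecoding", "paged"):
    for v2 in (True, False):
        print(impl, v2, prefill_in_kv_cache(impl, v2))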
__all__ = [
"PREFILL_IN_KV_CACHE",
......
......@@ -699,7 +699,6 @@ def check_args(
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
......
......@@ -66,6 +66,7 @@ def paged_attention(
softcap: Optional[float] = None,
):
out = torch.empty_like(query)
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
......@@ -74,7 +75,7 @@ def paged_attention(
kv_head_mapping,
softmax_scale,
block_tables,
seqlen.input_lengths,
input_lengths,
BLOCK_SIZE,
max_s,
None,
......
......@@ -104,7 +104,7 @@ def paged_attention(
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = seqlen.input_lengths
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
out = torch.empty_like(query)
......
......@@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self.input_ids),
)
@classmethod
......
......@@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
pixel_values.shape
)
(
batch_size,
num_concurrent_media,
num_tiles,
num_channels,
height,
width,
) = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
......
......@@ -16,7 +16,17 @@ from transformers import (
AutoTokenizer,
GenerationConfig,
)
from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
from typing import (
Any,
ContextManager,
Iterable,
Optional,
Tuple,
List,
Type,
Dict,
Union,
)
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
......@@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import (
get_support_chunking,
get_max_prefill_tokens,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
......@@ -60,7 +74,6 @@ from text_generation_server.utils.import_utils import (
tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
......@@ -117,45 +130,48 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping: Dict[int, int]
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
# Can be a list for easy filtering
# If `input_ids` is a list, it needs to be materialized to a tensor first
input_ids: Union[torch.Tensor, List[List[int]]]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
position_ids: Optional[torch.Tensor]
speculative_ids: Optional[torch.Tensor]
# Flash Attention values
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Paged Attention values
# Set when creating the batch
# CPU tensor of length b indicating the start of each sequence in slots
start_slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slot_indices: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size
block_tables: List[List[int]]
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
# size [b], containing the number of blocks that can be retrieved from the cache
prefix_lens: List[int]
prefix_lens_tensor: torch.Tensor
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slots: Optional[torch.Tensor]
max_seqlen: int
max_input_length: int
max_current_length: int
# Whether this batch contains at least one request that is prefilling
prefilling: bool
# Whether each request is prefilling
prefilling_mask: List[bool]
# Prefill metadata tensors to efficiently compute logprobs
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_head_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_next_token_indices: Optional[torch.tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_cu_outlens: Optional[List[int]]
# Prefixes
prefix_ids: List[List[int]]
# Will be set by `generate_token` and reset after each prefill forward
prefill_logprob_tokens: List[Optional[Tokens]]
# All tokens
all_input_ids: List[List[int]]
......@@ -163,7 +179,14 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
input_lengths_tensor: torch.Tensor
# size [b], containing the number of blocks that can be retrieved from the cache
cache_lengths: List[int]
prompt_lengths: List[int]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
input_lengths_tensor: Optional[torch.Tensor]
cache_lengths_tensor: Optional[torch.Tensor]
prompt_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
......@@ -174,7 +197,8 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request
adapter_meta: AdapterBatchMetadata
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
adapter_meta: Optional[AdapterBatchMetadata]
# Number of blocks in this batch
num_blocks: int
......@@ -187,6 +211,11 @@ class FlashCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.num_blocks * BLOCK_SIZE,
current_tokens=(
sum([len(i) for i in self.input_ids])
if isinstance(self.input_ids, list)
else len(self.input_ids)
),
)
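`current_tokens` counts the tokens that will go through the next forward pass: while prefilling, `input_ids` is a list of per-request chunks, while decoding it is a flat tensor with one token per request. A quick illustration with made-up values:

import torch

# Prefilling batch: two requests whose next chunks are 3 and 2 tokens long.
prefill_input_ids = [[101, 102, 103], [7, 8]]
assert sum(len(i) for i in prefill_input_ids) == 5   # current_tokens == 5

# Decoding batch: one token per request, stored as a flat tensor.
decode_input_ids = torch.tensor([104, 9])
assert len(decode_input_ids) == 2                    # current_tokens == 2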
@classmethod
......@@ -218,46 +247,28 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
sliding_window = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
start_slots = []
slot_indices = []
prefill_cache_indices = []
speculate = get_speculate()
cache_lengths = []
input_lengths = []
prompt_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
prefix_ids = []
all_postfix_ids = []
requests_idx_mapping = {}
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
adapter_indices_list = []
adapter_set = set()
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
num_blocks = 0
max_seqlen = 0
max_input_length = 0
max_current_length = 0
max_length = 0
max_blocks = 0
block_tables = []
slots = []
prefix_lens = []
# Parse batch
for i, (r, tokenized_input) in enumerate(
......@@ -266,38 +277,47 @@ class FlashCausalLMBatch(Batch):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
orig_input_length = len(tokenized_input)
prompt_length = len(tokenized_input)
prompt_lengths.append(prompt_length)
cache_length = r.cache_len
prefix_len = r.prefix_len
assert (
prefix_len <= orig_input_length
), f"Prefix {prefix_len} vs input {orig_input_length}"
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
# Commented as it's costly.
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:]
input_length = len(tokenized_input)
cache_length <= prompt_length
), f"Prefix {cache_length} vs input {prompt_length}"
if cache_length == prompt_length:
assert False, "unreachable"
# `chunk_len` is an optional field in the protobuf
# It is only set if the model supports chunking
if r.HasField("chunk_len"):
input_length = r.chunk_len
if cache_length + input_length < prompt_length:
# FIXME: speculate is not supported for context chunking at the moment
assert speculate == 0
assert get_support_chunking()
assert input_length > 0
postfix_ids = tokenized_input[
cache_length : cache_length + input_length
]
assert (
len(postfix_ids) == input_length
), "Rust and Python tokenizers are not aligned"
else:
# Use all the remaining ids
postfix_ids = tokenized_input[cache_length:]
input_length = len(postfix_ids)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
prefix_offsets.append(prompt_length - 5)
read_offsets.append(prompt_length)
all_postfix_ids.append(postfix_ids)
all_input_ids.append(tokenized_input)
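A worked example of the slicing above, with made-up values: a 10-token prompt, 4 tokens already in the KV cache and `chunk_len == 3`, so only the next 3 tokens are computed in this prefill pass.

tokenized_input = list(range(10))    # hypothetical 10-token prompt
cache_length, chunk_len = 4, 3

postfix_ids = tokenized_input[cache_length : cache_length + chunk_len]
assert postfix_ids == [4, 5, 6]                               # this prefill chunk
assert tokenized_input[cache_length:] == [4, 5, 6, 7, 8, 9]   # without chunk_len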
# Position ids
request_position_ids = torch.arange(
prefix_len, orig_input_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
......@@ -307,22 +327,13 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((input_length,), adapter_index))
adapter_set.add(adapter_index)
# Paged attention
# Remove one as the first token does not have a past
speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length
# Tokens that need to be mapped to blocks.
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length
# Tokens that need to be mapped to slots. We don't need slots for the
# cached prefix (if present).
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
# blocks and slots can be empty (for example in warmup)
if not r.blocks:
......@@ -330,77 +341,26 @@ class FlashCausalLMBatch(Batch):
request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks)
]
request_slots = [
s
for b in request_blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_blocks = r.blocks
request_slots = r.slots[
prefix_len: #: orig_input_length + max_new_tokens + speculative_length
]
block_tables.append(request_blocks)
slots.extend(request_slots)
prefix_lens.append(prefix_len)
cache_lengths.append(cache_length)
num_blocks += len(request_blocks)
start_slots.append(cumulative_slot_tokens)
request_slot_indices = torch.arange(
cumulative_slot_tokens,
cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1], dtype=torch.int32
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
cumulative_slot_tokens += slot_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks))
max_input_length = max(max_input_length, input_length)
max_current_length = max(max_current_length, cache_length + input_length)
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
max_length,
prompt_length + max_new_tokens + speculative_length,
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
......@@ -414,103 +374,59 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor, dtype=torch.int64, device=device
)
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
slots = torch.tensor(slots, dtype=torch.int64, device=device)
block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device)
prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
prompt_lengths_tensor = torch.tensor(
prompt_lengths, dtype=torch.int32, device=device
)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
prefill_cache_indices=prefill_cache_indices,
start_slots=start_slots,
slot_indices=slot_indices,
input_ids=all_postfix_ids,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
cache_lengths=cache_lengths,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=True,
prefilling_mask=[True] * len(pb.requests),
prefill_logprob_tokens=[None] * len(pb.requests),
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prompt_lengths=prompt_lengths,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks,
max_blocks=max_blocks,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
speculative_ids=None,
prompt_lengths_tensor=prompt_lengths_tensor,
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids=None,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
slot_indices=None,
slots=None,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
cache_lengths_tensor=None,
input_lengths_tensor=None,
adapter_meta=None,
)
@classmethod
......@@ -533,7 +449,7 @@ class FlashCausalLMBatch(Batch):
if len(request_ids) == len(self):
return self
device = self.input_ids.device
device = self.block_tables_tensor.device
# New values after filtering
requests_idx_mapping = {}
......@@ -548,19 +464,23 @@ class FlashCausalLMBatch(Batch):
# Create on CPU to only move to GPU once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_seqlen = 0
max_input_length = 0
max_current_length = 0
requests = []
start_slots = []
block_tables = []
all_input_ids = []
prefix_ids = []
input_ids = []
prompt_lengths = []
input_lengths = []
prefix_lens = []
cache_lengths = []
prefix_offsets = []
read_offsets = []
prefilling_mask = []
prefill_logprob_tokens = []
stopping_criterias = []
top_n_tokens = []
adapter_set = set()
......@@ -577,16 +497,23 @@ class FlashCausalLMBatch(Batch):
requests.append(self.requests[idx])
# Prefilling
request_prefilling = self.prefilling_mask[idx]
prefilling_mask.append(request_prefilling)
# Get length
request_input_length = self.input_lengths[idx]
prefix_len = self.prefix_lens[idx]
max_seqlen = max(max_seqlen, request_input_length)
request_cache_length = self.cache_lengths[idx]
max_input_length = max(max_input_length, request_input_length)
max_current_length = max(
max_current_length, request_cache_length + request_input_length
)
all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
prompt_lengths.append(self.prompt_lengths[idx])
input_lengths.append(request_input_length)
prefix_lens.append(prefix_len)
cache_lengths.append(request_cache_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
......@@ -594,60 +521,79 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
adapter_set.add(adapter_index)
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
request_block_table = self.block_tables[idx]
num_blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
# Copy to tensor (CPU)
slot_indices[i] = cumulative_max_length + request_input_length - 1
# Input ids if the request was part of a prefilling batch
# If the batch was decoding we can index into the tensor directly later
if self.prefilling:
input_ids.append(self.input_ids[idx])
else:
# Copy to tensor (CPU)
slot_indices[i] = cumulative_max_length
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
# Set slice
slot_filtering_indices[
self.start_slots[idx] : self.start_slots[idx]
+ request_input_length
+ remaining_tokens
- 1
] = True
# Set slice
slot_filtering_indices[
self.slot_indices[idx] : self.slot_indices[idx]
+ request_input_length
+ remaining_tokens
- 1
] = True
cumulative_max_length += request_input_length + remaining_tokens - 1
cumulative_max_length += request_input_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table))
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
prefix_lens_tensor = self.prefix_lens_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = (
self.speculative_ids[indices] if self.speculative_ids is not None else None
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
if self.prefilling:
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slot_indices = None
slots = None
cache_lengths_tensor = None
input_lengths_tensor = None
adapter_meta = None
else:
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
return type(self)(
batch_id=self.batch_id,
......@@ -657,24 +603,28 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
max_seqlen=max_seqlen,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=self.prefilling,
prefilling_mask=prefilling_mask,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
cache_lengths=cache_lengths,
cache_lengths_tensor=cache_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
......@@ -682,12 +632,7 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
adapter_meta=adapter_meta,
)
@classmethod
......@@ -697,74 +642,98 @@ class FlashCausalLMBatch(Batch):
requests = []
requests_idx_mapping = {}
prefilling = False
num_blocks = 0
total_batch_size = 0
total_slots = 0
max_blocks = 0
max_length = 0
max_seqlen = 0
max_input_length = 0
max_current_length = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
max_blocks = max(max_blocks, b.max_blocks)
# If `b` is prefilling and was just filtered, `b.slots` is None
# `total_slots` is not used if any of the batches is prefilling
total_slots += len(b.slots) if not b.prefilling else 0
num_blocks += b.num_blocks
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_input_length = max(max_input_length, b.max_input_length)
max_current_length = max(max_current_length, b.max_current_length)
max_length = max(
max_length,
max(
input_length
prompt_length
+ stopping_criteria.max_new_tokens
+ speculative_length
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
for prompt_length, stopping_criteria in zip(
b.prompt_lengths, b.stopping_criterias
)
),
)
prefilling = prefilling or b.prefilling
if prefilling:
input_ids = []
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slots = None
slot_indices = None
cache_lengths_tensor = None
input_lengths_tensor = None
adapter_meta = None
adapter_segment_builder = None
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
)
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size
)
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_segment_builder = SegmentConcatBuilder()
adapter_set = set()
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
total_batch_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_set = set()
adapter_segment_builder = SegmentConcatBuilder()
start_slots = []
block_tables = []
prefix_lens = []
cache_lengths = []
all_input_ids = []
prefix_ids = []
prompt_lengths = []
input_lengths = []
prefix_offsets = []
read_offsets = []
prefill_logprob_tokens = []
next_token_chooser_parameters = []
fsm_grammar_states = []
stopping_criterias = []
top_n_tokens = []
prefilling_mask = []
# Cumulative length
cumulative_batch_size = 0
......@@ -783,32 +752,9 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
......@@ -816,20 +762,56 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks]
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor
if not prefilling:
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slots[slots_start_index:slots_end_index] = batch.slots
slot_indices[start_index:end_index] = (
batch.slot_indices + cumulative_slots
)
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments,
batch.adapter_meta.segment_indices,
)
start_slots.append(batch.start_slots + cumulative_slots)
# Update
cumulative_slots += len(batch.slots)
else:
if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
input_ids.extend(batch.input_ids)
prefilling_mask.extend(batch.prefilling_mask)
block_tables.extend(batch.block_tables)
prefix_lens.extend(batch.prefix_lens)
cache_lengths.extend(batch.cache_lengths)
all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids)
prompt_lengths.extend(batch.prompt_lengths)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
stopping_criterias.extend(batch.stopping_criterias)
......@@ -838,11 +820,6 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots)
start_slots = torch.concat(start_slots)
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
......@@ -858,7 +835,14 @@ class FlashCausalLMBatch(Batch):
else None
)
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
if adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
return cls(
batch_id=batches[0].batch_id,
......@@ -868,24 +852,28 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
cache_lengths=cache_lengths,
cache_lengths_tensor=cache_lengths_tensor,
slots=slots,
max_seqlen=max_seqlen,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=prefilling,
prefilling_mask=prefilling_mask,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
......@@ -893,12 +881,195 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
adapter_meta=adapter_meta,
)
def prepare_for_prefill(self):
# Prepare values if we need to continue prefilling.
# Speculation must be ignored while we prefill, even with chunking:
# it simplifies everything
assert self.speculative_ids is None
sliding_window = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
slots = []
adapter_indices_list = []
adapter_set = set()
for i, (
r,
cache_length,
input_length,
prompt_length,
request_prefilling,
blocks,
) in enumerate(
zip(
self.requests,
self.cache_lengths,
self.input_lengths,
self.prompt_lengths,
self.prefilling_mask,
self.block_tables,
)
):
next_chunk_length = input_length
# Position ids
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
if not r.slots:
request_slots = [
s
for b in blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_slots = r.slots
request_slots = request_slots[cache_length:]
request_slot_indices = torch.arange(
cumulative_slot_tokens,
cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
# Prefill logprobs is ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_head_indices.append(
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64,
)
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
slots.extend(request_slots)
slot_indices.append(request_slot_indices)
if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
adapter_set.add(adapter_index)
# Update
cumulative_length += next_chunk_length
cumulative_slot_tokens += len(request_slots)
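To make the slot bookkeeping in this loop concrete, a small illustrative example; `BLOCK_SIZE` is assumed to be 16 here (the paged-attention value) and the block ids and lengths are made up.

BLOCK_SIZE = 16                      # assumed; actual value depends on the attention impl
blocks = [2, 5]                      # hypothetical block ids for one request
cache_length, input_length = 4, 3    # 4 tokens cached, 3 tokens in this chunk

# Expand blocks into slots, then drop the slots already covered by the cache.
request_slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
request_slots = request_slots[cache_length:]
assert request_slots[:input_length] == [36, 37, 38]    # slots written by this chunk
# The matching slot_indices are a contiguous arange of length input_length starting
# at cumulative_slot_tokens, i.e. offsets into the concatenated slots tensor.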
device = self.block_tables_tensor.device
if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
if len(self) > 1:
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
self.prefill_cu_outlens = prefill_cu_outlens
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
self.cu_seqlen_prefill = cu_seqlen_prefill
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.cat(prefill_head_indices).to(device)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
def __len__(self):
......@@ -938,6 +1109,7 @@ class FlashCausalLM(Model):
head_size: Optional[int] = None,
skip_special_tokens: bool = True,
kv_cache_dtype: Optional[torch.dtype] = None,
support_chunking: bool = True,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
......@@ -1065,6 +1237,7 @@ class FlashCausalLM(Model):
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
support_chunking=support_chunking,
)
@property
......@@ -1101,11 +1274,11 @@ class FlashCausalLM(Model):
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = [max_s] * bs
prefix_lengths = [0] * bs
cache_lengths = [0] * bs
input_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
)
prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
......@@ -1115,7 +1288,7 @@ class FlashCausalLM(Model):
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
prefix_lens=prefix_lengths,
cache_lengths=cache_lengths,
)
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
......@@ -1144,7 +1317,7 @@ class FlashCausalLM(Model):
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
"cache_lengths": cache_lengths_tensor,
"state": state,
"graph": graph,
}
......@@ -1156,11 +1329,11 @@ class FlashCausalLM(Model):
cu_seqlen_prefill=None,
input_lengths_tensor=input_lengths_tensor,
state=state,
prefix_lens_tensor=prefix_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
......@@ -1184,7 +1357,7 @@ class FlashCausalLM(Model):
with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
......@@ -1207,6 +1380,7 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache()
try:
......@@ -1226,7 +1400,7 @@ class FlashCausalLM(Model):
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
......@@ -1341,14 +1515,16 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.zeros(
seqlen, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
max_s = seqlen
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=1,
max_k=seqlen,
......@@ -1380,7 +1556,7 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
......@@ -1399,8 +1575,8 @@ class FlashCausalLM(Model):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Copy the block tables for all members
......@@ -1422,10 +1598,12 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor
max_s = batch.max_seqlen
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
......@@ -1445,21 +1623,20 @@ class FlashCausalLM(Model):
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
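# With chunked prefill, input_lengths only counts the tokens fed in this forward pass and
# cache_lengths counts the tokens already in the KV cache, so max_k == max_current_length
# (the largest cache_length + input_length in the batch) while max_q == max_input_length
# only covers the newly fed tokens.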
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
......@@ -1486,7 +1663,7 @@ class FlashCausalLM(Model):
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
# assert block_tables.shape[0] >= slots.shape[0]
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
......@@ -1501,14 +1678,16 @@ class FlashCausalLM(Model):
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["prefix_lengths"].zero_()
cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
prefix_lens_tensor=cuda_graph["prefix_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
......@@ -1528,7 +1707,10 @@ class FlashCausalLM(Model):
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
prefill = batch.cu_seqlen_prefill is not None
prefill = batch.prefilling
if prefill:
batch.prepare_for_prefill()
prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present)
......@@ -1570,14 +1752,62 @@ class FlashCausalLM(Model):
if prefill_logprobs
else speculative_logits
)
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
else:
prefill_logprobs = None
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices
finished_prefilling = True
next_chunk_lengths = []
current_prefilling_mask = batch.prefilling_mask
if prefill:
if get_support_chunking():
next_prefilling_mask = []
# Budget in tokens for the next batch
# We subtract (len(batch) - 1) so that every other request keeps room for at least a single
# decode token; one request is left out of the count so that a batch with a single request
# can use the full budget rather than budget - 1
batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
# We reverse to prioritize older requests
# zip() is not reversible so reverse the underlying lists instead
for cache_length, input_length, prompt_length in zip(
reversed(batch.cache_lengths),
reversed(batch.input_lengths),
reversed(batch.prompt_lengths),
):
remaining_prefill_tokens = max(
prompt_length - cache_length - input_length, 0
)
if remaining_prefill_tokens > 0:
next_chunk_length = max(
min(remaining_prefill_tokens, batch_budget), 1
)
batch_budget -= next_chunk_length
finished_prefilling = False
next_prefilling_mask.append(True)
else:
# FIXME: use the true number of accepted tokens instead of 1
# Since speculation will be turned off, 1 is always correct here
next_chunk_length = 1
next_prefilling_mask.append(False)
next_chunk_lengths.append(next_chunk_length)
# Reverse back the obtained values
next_chunk_lengths.reverse()
next_prefilling_mask.reverse()
else:
# The model does not support chunking
# We know we only do a single prefill
finished_prefilling = True
next_prefilling_mask = [False] * len(batch)
batch.prefilling = not finished_prefilling
batch.prefilling_mask = next_prefilling_mask
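# A minimal sketch of the chunk-budget rule above as a standalone helper (illustrative
# only: the helper name is made up, and the server uses get_max_prefill_tokens() and the
# batch lists directly rather than such a function):
#
# def split_prefill_budget(prompt_lengths, cache_lengths, input_lengths, max_prefill_tokens):
#     # Keep room for one decode token per other request in the batch
#     budget = max_prefill_tokens - (len(prompt_lengths) - 1)
#     chunk_lengths, prefilling = [], []
#     # Older requests (lower index) get budget first, hence the reversed walk
#     for prompt, cache, inp in zip(
#         reversed(prompt_lengths), reversed(cache_lengths), reversed(input_lengths)
#     ):
#         remaining = max(prompt - cache - inp, 0)
#         if remaining > 0:
#             chunk = max(min(remaining, budget), 1)
#             budget -= chunk
#             prefilling.append(True)
#         else:
#             chunk = 1  # done prefilling: room for a single decode token
#             prefilling.append(False)
#         chunk_lengths.append(chunk)
#     return list(reversed(chunk_lengths)), list(reversed(prefilling))
#
# split_prefill_budget([100, 40], [0, 0], [30, 40], 64) -> ([63, 1], [True, False])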
speculate = get_speculate()
(
next_input_ids,
......@@ -1586,7 +1816,7 @@ class FlashCausalLM(Model):
accepted_ids,
speculative_ids,
) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen],
batch.all_input_ids_tensor[:, : batch.max_current_length],
next_token_logits,
speculate,
batch.speculative_ids,
......@@ -1597,29 +1827,28 @@ class FlashCausalLM(Model):
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Since we are done prefilling, the tensors that concatenated values across all requests
# now collapse to shape [BATCH_SIZE]
if prefill and finished_prefilling:
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
else:
prefill_logprobs = None
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
elif not prefill:
next_position_ids = batch.position_ids
# Cumulative length
cumulative_length = 0
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.cache_lengths,
batch.input_lengths,
batch.all_input_ids,
accepted_ids,
current_prefilling_mask,
batch.prefilling_mask,
)
# We do two for loops: the first one can run completely asynchronously from the GPU, while
# the second one first needs a GPU <-> CPU sync
......@@ -1627,16 +1856,22 @@ class FlashCausalLM(Model):
# For each member of the batch
index = 0
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
if prefill:
# Cumulative length
cumulative_length = 0
for i, (
request,
prompt_length,
cache_length,
input_length,
all_input_ids,
n_accepted_ids,
request_was_prefilling,
request_is_prefilling,
) in enumerate(iterator):
if prefill and finished_prefilling:
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index
_start_index = cumulative_length
end_index = cumulative_length + input_length
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
......@@ -1648,41 +1883,43 @@ class FlashCausalLM(Model):
end_index - 1
]
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
if len(batch) > 1:
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
batch.input_ids[start_index + 1 : start_index + out_length]
)
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : start_index + out_length
]
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
# Used to gather prefill logprobs
# Copy batch.all_input_ids_tensor to prefill_tokens_indices
if request.prefill_logprobs and request_was_prefilling:
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
ids = batch.all_input_ids_tensor[
i, cache_length + 1 : cache_length + input_length + 1
]
if len(batch) > 1:
prefill_tokens_indices[out_start_index:out_end_index] = ids
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = ids
if not request_is_prefilling:
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
next_input_ids[index + j]
)
index += n_accepted_ids
cumulative_length += input_length
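# To illustrate the shift above: for a prompt chunk [t0, t1, t2, t3], the logits at flattened
# positions 0..3 are predictions for t1..t4, so the logprob of prompt token i + 1 is gathered
# at output position i; slicing all_input_ids_tensor at
# cache_length + 1 : cache_length + input_length + 1 lines the ids up with the model output.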
# Update values
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.input_lengths_tensor += accepted_ids
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
# These values can be updated without a GPU -> CPU sync
if not prefill or (prefill and finished_prefilling):
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.cache_lengths_tensor += batch.input_lengths_tensor
batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
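# For instance, with accepted_ids == [2, 1] the flattened next_input_ids holds two candidate
# tokens for the first request and one for the second; accepted_ids.cumsum(-1) - 1 == [1, 2]
# keeps the last accepted token of each request as its next input id.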
if prefill and prefill_logprobs:
# Get prefill logprobs
......@@ -1693,183 +1930,292 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
prefill_logprobs = prefill_logprobs.view(-1).tolist()
# Does a GPU <-> CPU sync internally
if prefill and finished_prefilling:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()
# Update values if we need to continue prefilling
# This represents the `else` case of the `Update values` if above,
# but since this requires `next_token_ids` to be on CPU, it is better to do it here
if prefill and not finished_prefilling:
# Speculation must be ignored while we prefill, even with chunking;
# it simplifies everything
assert batch.speculative_ids is None
all_postfix_ids = []
for i, (
request_prefilling,
next_token_id,
all_input_ids,
cache_length,
input_length,
next_chunk_length,
) in enumerate(
zip(
batch.prefilling_mask,
next_token_ids,
batch.all_input_ids,
batch.cache_lengths,
batch.input_lengths,
next_chunk_lengths,
)
):
if request_prefilling:
next_cache_length = cache_length + input_length
# Get new prompt IDs to prefill
postfix_ids = all_input_ids[
next_cache_length : next_cache_length + next_chunk_length
]
else:
# This request is done prefilling, the new id is the one selected by the sampling method
postfix_ids = [next_token_id]
all_postfix_ids.append(postfix_ids)
batch.input_ids = all_postfix_ids
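# Continuing the earlier example: a request with prompt_length == 100, cache_length == 0,
# input_length == 30 and next_chunk_length == 63 will prefill all_input_ids[30:93] next,
# while a request that finished prefilling simply feeds back its single sampled token id.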
start_decode = time.time_ns()
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator
iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.cache_lengths,
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
batch.stopping_criterias,
batch.all_input_ids,
batch.prefix_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
current_prefilling_mask,
batch.prefilling_mask,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
)
# Reset max_input_length
batch.max_input_length = 0
# For each member of the batch
index = 0
for i, (
request,
prompt_length,
cache_length,
input_length,
prefix_offset,
read_offset,
stopping_criteria,
all_input_ids,
prefix_ids,
do_sample,
seed,
top_n_tokens,
request_was_prefilling,
request_is_prefilling,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
next_token_texts = []
left = 0
if n_accepted_ids > 1:
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
index += n_accepted_ids
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
# Compute logprobs first as, even though we might skip the token,
# it can still be required to compute the logprobs
# We shard on request.id modulo world_size because it is robust to batch.filter, whereas the
# index in the batch is not, and we need this state to be stable
if request.id % self.world_size == self.rank:
# Prefill
if prefill and request.prefill_logprobs:
if request_was_prefilling and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
if not request_is_prefilling:
# The request is done prefilling, meaning that we started generating new tokens
# The last logprob is a logprob for a generated token that was not part of the prompt
# We need to remove it
out_end_index -= 1
request_prefill_logprobs = prefill_logprobs[
out_start_index:out_end_index
]
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
prefill_token_ids = all_input_ids[
cache_length + 1 : cache_length + input_length + 1
]
past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
if past_prefill_logprob_tokens is None:
# add nan for cached prompt tokens/first token
request_prefill_logprobs = [float("nan")] * (
cache_length + 1
) + request_prefill_logprobs
prefill_token_ids = (
all_input_ids[: cache_length + 1] + prefill_token_ids
)
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = (
[float("nan")] * (len(prefix_ids) + 1)
) + prefill_logprobs[out_start_index : out_end_index - 1]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefix_ids + prefill_token_ids,
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = Tokens(
prefix_ids + prefill_token_ids,
prefill_logprob_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
prefill_texts,
is_special=[],
)
if past_prefill_logprob_tokens is not None:
prefill_logprob_tokens = (
past_prefill_logprob_tokens + prefill_logprob_tokens
)
batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
else:
prefill_tokens = None
if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
batch.prefill_logprob_tokens[i] = None
# If the request is still prefilling, the tokens we decoded should be ignored
if request_is_prefilling:
# Make sure that we do not stop: even though this request did not produce a token,
# it is still being processed
stopped = False
new_input_length = next_chunk_lengths[i]
else:
new_input_length = n_accepted_ids
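# In other words: a request that is still prefilling will next see its upcoming prompt chunk,
# while a finished request shrinks to the tokens it just accepted (a single token while
# speculation is disabled during chunked prefill).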
# Append next token to all tokens
next_token_texts = []
left = 0
if n_accepted_ids > 1:
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
# Shard generations
# All generations will be appended in the rust sharded client
if request.id % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
special_toptokens = [
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None
generation = Generation(
request.id,
prefill_tokens,
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
else:
generated_text = None
if top_n_tokens > 0:
all_top_tokens = []
for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs
):
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids
for token_id in top_token_ids
]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None
generation = Generation(
request.id,
batch.prefill_logprob_tokens[i],
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
generations.append(generation)
generations.append(generation)
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
)
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(
i, next_token_id
)
)
# Update values
batch.input_lengths[i] = input_length + n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
index += n_accepted_ids
current_cache_length = cache_length + input_length
batch.cache_lengths[i] = current_cache_length
current_input_length = new_input_length
batch.max_input_length = max(batch.max_input_length, current_input_length)
batch.input_lengths[i] = current_input_length
current_length = current_cache_length + current_input_length
batch.max_current_length = max(batch.max_current_length, current_length)
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
......@@ -1880,9 +2226,13 @@ class FlashCausalLM(Model):
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
if prefill and finished_prefilling:
# We do not need prefill tensors anymore
batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
......@@ -1894,7 +2244,7 @@ class FlashCausalLM(Model):
block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
cache_lengths_tensor: torch.Tensor,
state: Optional[Any] = None,
) -> ContextManager:
if ATTENTION != "flashinfer":
......@@ -1905,8 +2255,6 @@ class FlashCausalLM(Model):
use_prefill_with_paged_kv_state,
)
# has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
if cu_seqlen_prefill is not None:
return use_prefill_with_paged_kv_state(
state=(
......@@ -1915,11 +2263,11 @@ class FlashCausalLM(Model):
# block_tables=block_tables_to_ragged(
# block_tables=block_tables,
# input_lengths=input_lengths,
# prefix_lens=prefix_lens,
# cache_lengths=cache_lengths,
# ),
block_tables=block_tables,
cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor + prefix_lens_tensor,
input_lengths=input_lengths_tensor + cache_lengths_tensor,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
......@@ -1931,7 +2279,7 @@ class FlashCausalLM(Model):
assert input_lengths_tensor is not None
return use_decode_state(
state=state if state is not None else self.decode_state,
input_lengths=input_lengths_tensor + prefix_lens_tensor,
input_lengths=input_lengths_tensor + cache_lengths_tensor,
block_tables=block_tables,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
......@@ -1943,19 +2291,19 @@ class FlashCausalLM(Model):
def block_tables_to_ragged(
*, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int]
*, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
) -> torch.Tensor:
"""Convert block table to ragged format compatible with FlashInfer."""
assert len(input_lengths) == len(prefix_lens)
assert len(input_lengths) == len(cache_lengths)
total_len = sum(input_lengths) + sum(prefix_lens)
total_len = sum(input_lengths) + sum(cache_lengths)
block_tables_ragged = torch.empty(
total_len, dtype=torch.int32, device=block_tables.device
)
offset = 0
for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)):
seq_len = prefix_len + input_length
for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
seq_len = cache_length + input_length
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
offset += seq_len
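# For example, assuming one slot per block (BLOCK_SIZE == 1, as with flashinfer):
#   block_tables  = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)
#   input_lengths = [2, 1]
#   cache_lengths = [1, 1]
# gives per-sequence lengths of [3, 2] and a ragged result of tensor([10, 11, 12, 20, 21]),
# i.e. the first cache_length + input_length entries of each row laid out back to back.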
......
......@@ -5,9 +5,14 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
"1",
"true",
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
ATTENTION in _expected
......@@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
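# Rough usage sketch (example values; the launcher normally sets these variables itself):
#   PREFIX_CACHING=1 ATTENTION=flashinfer PREFILL_CHUNKING=1 TGI_WIGGLE_ROOM=0.90
# parses to PREFIX_CACHING == True and PREFILL_CHUNKING == True, while TGI_WIGGLE_ROOM acts
# as a safety factor on the free-memory estimate when sizing the KV cache (~10% headroom here).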
......
......@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod
......
......@@ -116,6 +116,7 @@ class MambaBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod
......
from io import BytesIO
from PIL import Image
import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
......@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1
)
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None:
......@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
class MllamaCausalLM(VlmCausalLM):
def forward(
self,
batch: VlmCausalLMBatch,
batch: MllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
......@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
......@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Copy the block tables for all members
......@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor
max_s = batch.max_seqlen
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
......@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
......@@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM):
# Only run cuda graphs when there's no images.
or batch.cross_attention_states is not None
):
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
if batch.pixel_values is not None:
......@@ -330,22 +337,34 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
# Replay the graph
cuda_graph["graph"].replay()
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
......
......@@ -5,8 +5,17 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger
from text_generation_server.models.globals import (
ATTENTION,
PREFIX_CACHING,
BLOCK_SIZE,
PREFILL_CHUNKING,
)
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
......@@ -31,6 +40,7 @@ class Model(ABC):
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID,
support_chunking: bool = False,
):
self.model_id = model_id
self.model = model.eval()
......@@ -60,6 +70,29 @@ class Model(ABC):
speculate = get_speculate()
self.speculate = speculate
support_chunking = support_chunking and PREFILL_CHUNKING
if speculate != 0 and support_chunking:
log_master(
logger.warning,
"Prefill chunking does not support speculation yet. "
"Prefill chunking will be turned off",
)
support_chunking = False
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
log_master(
logger.warning,
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
)
support_chunking = False
log_master(
logger.info, f"Using experimental prefill chunking = {support_chunking}"
)
self.support_chunking = support_chunking
set_support_chunking(support_chunking)
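# Summarizing the gating above: chunking ends up enabled only when the model opts in
# (support_chunking=True), PREFILL_CHUNKING is set, speculation is off, and ATTENTION is
# flashinfer or flashdecoding; any other combination falls back to un-chunked prefill.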
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
......@@ -78,6 +111,10 @@ class Model(ABC):
device_type=self.device.type,
window_size=self.sliding_window,
speculate=self.speculate,
support_chunking=self.support_chunking,
use_prefix_caching=PREFIX_CACHING,
attention_impl=ATTENTION,
block_size=BLOCK_SIZE,
)
@property
......
......@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self.decoder_input_ids),
)
@classmethod
......