Unverified Commit a6a0c97e authored by OlivierDehaene, committed by GitHub

feat: prefill chunking (#2600)



* wip

* rollback

* refactor to use prefix/postfix naming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot of the time).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 704a58c8
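For readers skimming the diff below: prefill chunking splits the prefill of a long prompt into several smaller forward passes bounded by the prefill token budget, instead of requiring the whole prompt to fit in one pass. A minimal, illustrative Python sketch of the idea (the name max_prefill_tokens is an assumption here, not an actual launcher flag):

# Illustrative sketch only, not the actual TGI scheduler.
# A prompt is prefilled in slices of at most `max_prefill_tokens` tokens;
# `cache_len` counts the prompt tokens already present in the KV cache.
def chunk_prefill_plan(prompt_len: int, cache_len: int, max_prefill_tokens: int):
    chunks = []
    while cache_len < prompt_len:
        chunk_len = min(max_prefill_tokens, prompt_len - cache_len)
        chunks.append((cache_len, chunk_len))  # mirrors cache_len / chunk_len in the proto below
        cache_len += chunk_len
    return chunks

# A 10_000-token prompt with a 4096-token prefill budget takes three passes:
assert chunk_prefill_plan(10_000, 0, 4096) == [(0, 4096), (4096, 4096), (8192, 1808)]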
import pytest
-import base64
import asyncio
...@@ -15,22 +14,8 @@ async def mllama(mllama_handle):
    return mllama_handle.client

-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"

@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
-    # chicken = get_chicken()
    response = await mllama.chat(
        max_tokens=10,
        temperature=0.0,
......
...@@ -68,7 +68,7 @@ fn get_config(
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
    let compute_capability = gpu::get_cuda_capability();
-    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
+    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
    let mut attention: Option<String> = std::env::var("ATTENTION").ok();
    if let Some(config) = config {
        if prefix_caching.is_none() {
...@@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
            }
        }
    }
+    if attention == Some("paged".to_string()) && prefix_caching.is_none() {
+        tracing::info!("Disabling prefix caching on paged attention");
+        prefix_caching = Some("0".to_string());
+    }
    let attention = attention.unwrap_or("flashinfer".to_string());
    let prefix_caching = prefix_caching.unwrap_or("true".to_string());
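The added branch above gives ATTENTION=paged a sensible default: if the user did not set PREFIX_CACHING explicitly, prefix caching is turned off. A Python restatement of that precedence, for illustration only (the Rust code is authoritative; the config-based heuristics are omitted here):

import os

# Sketch of the resolution order: explicit env vars win, ATTENTION=paged
# implies prefix caching off, otherwise the defaults apply.
def resolve_attention_sketch() -> tuple:
    prefix_caching = os.environ.get("PREFIX_CACHING")
    attention = os.environ.get("ATTENTION")
    if attention == "paged" and prefix_caching is None:
        prefix_caching = "0"  # "Disabling prefix caching on paged attention"
    return (prefix_caching or "true", attention or "flashinfer")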
...@@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> {
    };
    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
    tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
-    std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
+    std::env::set_var("PREFIX_CACHING", prefix_caching);
    std::env::set_var("ATTENTION", attention);
    let max_input_tokens = {
...@@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> {
            "`max_input_tokens must be < `max_total_tokens`".to_string(),
        ));
    }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-            max_batch_prefill_tokens, max_input_tokens
-        )));
-    }
    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
        tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
...@@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> {
    }
    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if max_batch_prefill_tokens > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_batch_prefill_tokens, max_batch_total_tokens
-            )));
-        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(LauncherError::ArgumentValidation(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
......
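The two removed validations are consistent with chunking: a prompt larger than max_batch_prefill_tokens no longer has to fit in a single prefill, it is simply split over several passes. Rough arithmetic, as a sketch:

import math

# With prefill chunking, the prefill budget no longer caps the prompt length;
# it only determines how many prefill passes a long prompt needs.
def prefill_passes(input_tokens: int, max_batch_prefill_tokens: int) -> int:
    return math.ceil(input_tokens / max_batch_prefill_tokens)

# e.g. an 8192-token prompt with max_batch_prefill_tokens=2048 -> 4 passes
assert prefill_passes(8192, 2048) == 4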
...@@ -34,6 +34,10 @@ message InfoResponse {
  string device_type = 3;
  optional uint32 window_size = 4;
  uint32 speculate = 5;
+  bool support_chunking = 6;
+  bool use_prefix_caching = 7;
+  string attention_impl = 8;
+  uint32 block_size = 9;
}

/// Empty request
...@@ -135,10 +139,14 @@ message Request {
  repeated uint32 slots = 10;
  /// LORA adapter index
  optional string adapter_id = 11;
-  /// Prefix length that can be retrieved from the KV cache.
-  uint32 prefix_len = 12;
+  /// Tokens that can be retrieved from the KV cache.
+  /// This value is set for the first prefill and never reset
+  uint32 cache_len = 12;
  /// Context truncation
  bool add_special_tokens = 13;
+  /// Chunk of tokens that must be computed for the first prefill
+  /// This value is set for the first prefill and never reset
+  optional uint32 chunk_len = 14;
}

message Batch {
...@@ -163,6 +171,8 @@ message CachedBatch {
  uint32 size = 3;
  /// Maximum number of tokens this batch will grow to
  uint32 max_tokens = 4;
+  /// Number of tokens in the next forward
+  uint32 current_tokens = 5;
}

enum FinishReason {
...@@ -220,6 +230,8 @@ message FilterBatchResponse {
message PrefillRequest {
  /// Batch
  Batch batch = 1;
+  /// Optional cached batch
+  CachedBatch cached_batch = 2;
}

message PrefillResponse {
...@@ -233,6 +245,8 @@ message PrefillResponse {
  uint64 decode_ns = 4;
  /// Total elapsed time in nanoseconds
  uint64 total_ns = 5;
+  /// Concatenate elapsed time in nanoseconds
+  optional uint64 concat_ns = 6;
}

message DecodeRequest {
......
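Assuming the generated generate_pb2 stubs, a chunked first prefill for a single request could be encoded roughly as below; the values are made up and several unrelated fields are omitted for brevity. cache_len counts tokens already in the KV cache, and chunk_len bounds how much of the remaining prompt this pass computes.

from text_generation_server.pb import generate_pb2

# Hypothetical request: 1000 prompt tokens, 256 already cached, and only 512
# of the remaining tokens computed in this prefill pass.
request = generate_pb2.Request(
    id=0,
    prefill_logprobs=False,
    add_special_tokens=True,
    cache_len=256,   # tokens that can be retrieved from the KV cache
    chunk_len=512,   # tokens to prefill now; the rest comes in a later pass
)
batch = generate_pb2.Batch(id=0, requests=[request], size=1, max_tokens=1024)
# Later chunks arrive alongside the in-flight batch via PrefillRequest.cached_batch,
# and the server reports the merge cost through PrefillResponse.concat_ns.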
...@@ -18,45 +18,6 @@ use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;

-#[derive(PartialEq)]
-pub enum Attention {
-    Paged,
-    FlashDecoding,
-    FlashInfer,
-}
-
-impl Attention {
-    pub fn block_size(&self) -> u32 {
-        match self {
-            Attention::FlashDecoding => 256,
-            Attention::FlashInfer => 1,
-            Attention::Paged => 16,
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct ParseError;
-
-impl std::fmt::Display for ParseError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Cannot parse attention value")
-    }
-}
-impl std::error::Error for ParseError {}
-
-impl std::str::FromStr for Attention {
-    type Err = ParseError;
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "paged" => Ok(Attention::Paged),
-            "flashdecoding" => Ok(Attention::FlashDecoding),
-            "flashinfer" => Ok(Attention::FlashInfer),
-            _ => Err(ParseError),
-        }
-    }
-}

/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
......
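With the Attention enum gone from the router, the attention implementation and its block size are reported by the Python server through the new InfoResponse fields (attention_impl, block_size, support_chunking, use_prefix_caching). For reference, the block sizes the removed Rust code hard-coded, and how they translate into KV-cache blocks:

# Block sizes previously hard-coded in the router's Attention::block_size(),
# now reported by the server via InfoResponse.block_size.
BLOCK_SIZE_BY_ATTENTION = {
    "flashinfer": 1,
    "flashdecoding": 256,
    "paged": 16,
}

def blocks_needed(total_tokens: int, attention_impl: str) -> int:
    # Number of KV-cache blocks needed to hold `total_tokens` tokens.
    block_size = BLOCK_SIZE_BY_ATTENTION[attention_impl]
    return -(-total_tokens // block_size)  # ceiling division

assert blocks_needed(1000, "paged") == 63  # 1000 / 16 = 62.5 -> 63 blocks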
...@@ -2,7 +2,7 @@ import pytest
import os

from text_generation_server.pb import generate_pb2

-os.environ["USE_PREFIX_CACHING"] = "1"
+os.environ["PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
......
...@@ -9,6 +9,9 @@ from typing import Callable, Any
class ExceptionInterceptor(AsyncServerInterceptor):
+    def __init__(self, shutdown_callback):
+        self.shutdown_callback = shutdown_callback
+
    async def intercept(
        self,
        method: Callable,
...@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
            # Runtime Error cannot be recovered from
            if isinstance(err, RuntimeError):
-                exit(1)
+                self.shutdown_callback()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
......
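The interceptor no longer hard-exits the process; the caller that builds the gRPC server now supplies the shutdown behaviour. A rough sketch of how it might be wired (the event-based shutdown here is illustrative, not necessarily how the actual serve() does it):

import asyncio
from grpc import aio
from text_generation_server.interceptor import ExceptionInterceptor

# Illustrative wiring only: on an unrecoverable RuntimeError the interceptor
# now triggers a graceful shutdown instead of calling exit(1).
def build_server(shutdown_event: asyncio.Event) -> aio.Server:
    interceptor = ExceptionInterceptor(shutdown_callback=shutdown_event.set)
    return aio.server(interceptors=[interceptor])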
from dataclasses import dataclass
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional

-if ATTENTION in {"flashinfer", "flashdecoding"}:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
-
-        def __init__(
-            self,
-            input_lengths,
-            prefix_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
-
-else:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            self.input_lengths = torch.clamp(self.input_lengths, max=max)
-            return self
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cache_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+    max_q: int
+    max_k: int
+
+    def __init__(
+        self,
+        input_lengths,
+        cache_lengths,
+        cu_seqlen_q=None,
+        max_q=None,
+        max_k=None,
+    ):
+        self.input_lengths = input_lengths
+        self.cache_lengths = cache_lengths
+        device = self.input_lengths.device
+        shape = self.input_lengths.shape
+        if cu_seqlen_q is None:
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            max_q = 1
+        else:
+            assert max_q is not None
+            assert max_k is not None
+        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+
+        # cuda graphs don't like this and this is necessary to clamp within mistral
+        # Although FA2 might not want the clamping
+        # cu_seqlen_k[0] = 0
+        total = self.input_lengths + self.cache_lengths
+        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
+
+        self.cu_seqlen_q = cu_seqlen_q
+        self.cu_seqlen_k = cu_seqlen_k
+        self.max_q = max_q
+        self.max_k = max_k
+
+    def clamp(self, max):
+        # Flash decoding doesn't need to clamp
+        return self
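With chunking, keys/values span the cached tokens plus the tokens computed in the current pass, which is why cu_seqlen_k is built from input_lengths + cache_lengths above while cu_seqlen_q only covers the new tokens. A small worked example using the Seqlen added above (CPU tensors, hypothetical lengths; the import path assumes the usual re-export):

import torch
from text_generation_server.layers.attention import Seqlen

# Two sequences in a chunked prefill:
#  - seq 0: 256 tokens already cached, 128 new tokens in this pass
#  - seq 1: nothing cached, 64 new tokens
input_lengths = torch.tensor([128, 64], dtype=torch.int32)
cache_lengths = torch.tensor([256, 0], dtype=torch.int32)
cu_seqlen_q = torch.tensor([0, 128, 192], dtype=torch.int32)  # cumsum of input_lengths

seqlen = Seqlen(
    input_lengths=input_lengths,
    cache_lengths=cache_lengths,
    cu_seqlen_q=cu_seqlen_q,
    max_q=128,  # longest chunk being prefilled
    max_k=384,  # longest cached + new sequence
)
assert seqlen.cu_seqlen_k.tolist() == [0, 384, 448]  # cumsum of (256+128, 0+64)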
...@@ -123,7 +123,7 @@ def paged_attention(
    else:
        if softcap is not None:
            raise RuntimeError("Paged attention doesn't support softcapping")
-        input_lengths = seqlen.input_lengths
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        from vllm._C import ops

        out = torch.empty_like(query)
...@@ -244,117 +244,232 @@ if ATTENTION == "flashinfer":
            window_left=window_size_left,
        )
elif V2: elif ATTENTION == "flashdecoding":
if V2:
def attention( def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q, q,
key_cache, key_cache: torch.Tensor,
value_cache, value_cache: torch.Tensor,
out, seqlen: Seqlen,
seqlen.cu_seqlen_q, block_tables: torch.Tensor,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale, softmax_scale,
False, window_size_left=-1,
causal, causal=True,
window_size_left, softcap=0.0,
0, ):
softcap, out = torch.empty_like(q)
False, if window_size_left <= 0 and window_size_left != -1:
None, raise ValueError("`window_size_left` must be > 0 or -1")
)[0] return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else: else:
def attention( def attention(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale: float, softmax_scale: float,
window_size_left: int = -1, window_size_left: int = -1,
causal: bool = True, causal: bool = True,
softcap=None, softcap=None,
): ):
if window_size_left != -1: if window_size_left != -1:
raise NotImplementedError( raise NotImplementedError(
"window_size_left is only available with flash attn v2" "window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
) )
if v.shape[1] != q.shape[1]: if softcap is not None:
# MQA expand raise NotImplementedError(
if v.shape[1] == 1: "softcap is only available with flash attn v2"
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
) )
out = torch.empty_like(q) # Flash attention v1 requires q, k and v to have the same number of heads
flash_attn_cuda.fwd( if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
elif ATTENTION == "paged":
if V2:
def attention(
q, q,
k, key_cache: torch.Tensor,
v, value_cache: torch.Tensor,
out, seqlen: Seqlen,
seqlen.cu_seqlen_q, block_tables: torch.Tensor,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale, softmax_scale,
False, window_size_left=-1,
causal, causal=True,
False, softcap=0.0,
0, ):
None, out = torch.empty_like(q)
) if window_size_left <= 0 and window_size_left != -1:
return out raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
None, # block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
else:
raise RuntimeError(f"Unknwon attention {ATTENTION}")
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
-PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)

__all__ = [
    "PREFILL_IN_KV_CACHE",
......
...@@ -699,7 +699,6 @@ def check_args(
class _attention(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
......
...@@ -66,6 +66,7 @@ def paged_attention(
    softcap: Optional[float] = None,
):
    out = torch.empty_like(query)
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
...@@ -74,7 +75,7 @@ def paged_attention(
        kv_head_mapping,
        softmax_scale,
        block_tables,
-        seqlen.input_lengths,
+        input_lengths,
        BLOCK_SIZE,
        max_s,
        None,
......
...@@ -104,7 +104,7 @@ def paged_attention(
        _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = seqlen.input_lengths
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
    out = torch.empty_like(query)
......
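All paged-attention backends (CUDA, ipex, ROCm) now hand the kernels the total per-sequence KV length rather than just the length of the current input, since with chunking the two differ during prefill as well as decode. The arithmetic is simply:

import torch

# Per-sequence KV length seen by the paged attention kernel:
# tokens already cached + tokens appended by the current forward pass.
input_lengths = torch.tensor([128, 1, 64])   # current pass (prefill chunk or a single decode token)
cache_lengths = torch.tensor([256, 511, 0])  # tokens already in the KV cache
kv_lengths = input_lengths + cache_lengths
assert kv_lengths.tolist() == [384, 512, 64]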
...@@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+            current_tokens=len(self.input_ids),
        )

    @classmethod
......
...@@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
-        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
-            pixel_values.shape
-        )
+        (
+            batch_size,
+            num_concurrent_media,
+            num_tiles,
+            num_channels,
+            height,
+            width,
+        ) = pixel_values.shape
        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
......
...@@ -16,7 +16,17 @@ from transformers import (
    AutoTokenizer,
    GenerationConfig,
)
-from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
+from typing import (
+    Any,
+    ContextManager,
+    Iterable,
+    Optional,
+    Tuple,
+    List,
+    Type,
+    Dict,
+    Union,
+)

from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
...@@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
+from text_generation_server.utils.prefill_chunking import (
+    get_support_chunking,
+    get_max_prefill_tokens,
+)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
...@@ -60,7 +74,6 @@ from text_generation_server.utils.import_utils import (
tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
...@@ -117,45 +130,48 @@ class FlashCausalLMBatch(Batch): ...@@ -117,45 +130,48 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping: Dict[int, int] requests_idx_mapping: Dict[int, int]
# Decoder values # Decoder values
input_ids: torch.Tensor # Can be a list for easy filtering
position_ids: torch.Tensor # If `input_ids` is a list, it needs to be materialized to a tensor first
input_ids: Union[torch.Tensor, List[List[int]]]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
position_ids: Optional[torch.Tensor]
speculative_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor]
# Flash Attention values
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Paged Attention values
# Set when creating the batch # Set when creating the batch
# CPU tensor of length b indicating the start of each sequence in slots
start_slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slot_indices: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size # list of length b of list of length s_i // block_size
block_tables: List[List[int]] block_tables: List[List[int]]
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
# size [b], containing the number of blocks that can be retrieved from the cache slots: Optional[torch.Tensor]
prefix_lens: List[int]
prefix_lens_tensor: torch.Tensor
max_seqlen: int max_input_length: int
max_current_length: int
# Whether this batch contains at least one request that is prefilling
prefilling: bool
# Whether each request is prefilling
prefilling_mask: List[bool]
# Prefill metadata tensors to efficiently compute logprobs # Prefill metadata tensors to efficiently compute logprobs
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_head_indices: Optional[torch.Tensor] prefill_head_indices: Optional[torch.Tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_next_token_indices: Optional[torch.tensor] prefill_next_token_indices: Optional[torch.tensor]
# Will be set by `generate_token` and reset after each prefill forward
prefill_cu_outlens: Optional[List[int]] prefill_cu_outlens: Optional[List[int]]
# Will be set by `generate_token` and reset after each prefill forward
# Prefixes prefill_logprob_tokens: List[Optional[Tokens]]
prefix_ids: List[List[int]]
# All tokens # All tokens
all_input_ids: List[List[int]] all_input_ids: List[List[int]]
...@@ -163,7 +179,14 @@ class FlashCausalLMBatch(Batch): ...@@ -163,7 +179,14 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch # Lengths of all generations present in the batch
input_lengths: List[int] input_lengths: List[int]
input_lengths_tensor: torch.Tensor # size [b], containing the number of blocks that can be retrieved from the cache
cache_lengths: List[int]
prompt_lengths: List[int]
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
input_lengths_tensor: Optional[torch.Tensor]
cache_lengths_tensor: Optional[torch.Tensor]
prompt_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]] prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]] read_offsets: List[Optional[int]]
...@@ -174,7 +197,8 @@ class FlashCausalLMBatch(Batch): ...@@ -174,7 +197,8 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor: torch.Tensor top_n_tokens_tensor: torch.Tensor
# Adapter metadata for each request # Adapter metadata for each request
adapter_meta: AdapterBatchMetadata # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
adapter_meta: Optional[AdapterBatchMetadata]
# Number of blocks in this batch # Number of blocks in this batch
num_blocks: int num_blocks: int
...@@ -187,6 +211,11 @@ class FlashCausalLMBatch(Batch): ...@@ -187,6 +211,11 @@ class FlashCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.num_blocks * BLOCK_SIZE, max_tokens=self.num_blocks * BLOCK_SIZE,
current_tokens=(
sum([len(i) for i in self.input_ids])
if isinstance(self.input_ids, list)
else len(self.input_ids)
),
) )
@classmethod @classmethod
...@@ -218,46 +247,28 @@ class FlashCausalLMBatch(Batch): ...@@ -218,46 +247,28 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "FlashCausalLMBatch": ) -> "FlashCausalLMBatch":
sliding_window = get_sliding_windows() speculate = get_speculate()
position_ids = []
cu_seqlen_prefill = [0]
start_slots = []
slot_indices = []
prefill_cache_indices = []
cache_lengths = []
input_lengths = [] input_lengths = []
prompt_lengths = []
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
all_input_ids = [] all_input_ids = []
prefix_ids = [] all_postfix_ids = []
requests_idx_mapping = {} requests_idx_mapping = {}
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
next_token_chooser_parameters = [] next_token_chooser_parameters = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = [] top_n_tokens = []
adapter_indices_list = []
adapter_set = set()
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
num_blocks = 0 num_blocks = 0
max_seqlen = 0 max_input_length = 0
max_current_length = 0
max_length = 0 max_length = 0
max_blocks = 0 max_blocks = 0
block_tables = [] block_tables = []
slots = []
prefix_lens = []
# Parse batch # Parse batch
for i, (r, tokenized_input) in enumerate( for i, (r, tokenized_input) in enumerate(
...@@ -266,38 +277,47 @@ class FlashCausalLMBatch(Batch): ...@@ -266,38 +277,47 @@ class FlashCausalLMBatch(Batch):
# request id -> idx in list mapping # request id -> idx in list mapping
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
orig_input_length = len(tokenized_input) prompt_length = len(tokenized_input)
prompt_lengths.append(prompt_length)
cache_length = r.cache_len
prefix_len = r.prefix_len
assert ( assert (
prefix_len <= orig_input_length cache_length <= prompt_length
), f"Prefix {prefix_len} vs input {orig_input_length}" ), f"Prefix {cache_length} vs input {prompt_length}"
if prefix_len == orig_input_length: if cache_length == prompt_length:
assert prefix_len > 0 assert False, "unreachable"
prefix_len -= 1
# `chunk_len` is an optional field in the protobuf
# Commented as it's costly. # It is only set if the model support chunking
# log_master(logger.debug, "Tokenized input ids {tokenized_input}") if r.HasField("chunk_len"):
prefix_ids.append(tokenized_input[:prefix_len]) input_length = r.chunk_len
tokenized_input = tokenized_input[prefix_len:]
if cache_length + input_length < prompt_length:
input_length = len(tokenized_input) # FIXME: speculate is not supported for context chunking at the moment
assert speculate == 0
assert get_support_chunking()
assert input_length > 0
postfix_ids = tokenized_input[
cache_length : cache_length + input_length
]
assert (
len(postfix_ids) == input_length
), "Rust and Python tokenizers are not aligned"
else:
# Use all the remaining ids
postfix_ids = tokenized_input[cache_length:]
input_length = len(postfix_ids)
input_lengths.append(input_length) input_lengths.append(input_length)
prefix_offsets.append(input_length - 5) prefix_offsets.append(prompt_length - 5)
read_offsets.append(input_length) read_offsets.append(prompt_length)
all_postfix_ids.append(postfix_ids)
all_input_ids.append(tokenized_input) all_input_ids.append(tokenized_input)
# Position ids
request_position_ids = torch.arange(
prefix_len, orig_input_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
next_token_chooser_parameters.append(r.parameters) next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
...@@ -307,22 +327,13 @@ class FlashCausalLMBatch(Batch): ...@@ -307,22 +327,13 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens) top_n_tokens.append(r.top_n_tokens)
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((input_length,), adapter_index))
adapter_set.add(adapter_index)
# Paged attention # Paged attention
# Remove one as the first token des not have a past # Remove one as the first token des not have a past
speculative_length = get_speculate() speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length speculative_length = 0 if speculative_length is None else speculative_length
# Tokens that need to be mapped to blocks. # Tokens that need to be mapped to blocks.
block_tokens = orig_input_length + max_new_tokens - 1 + speculative_length block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
# Tokens that need to be mapped to slots. We don't need slots for the
# cached prefix (if present).
slot_tokens = input_length + max_new_tokens - 1 + speculative_length
# blocks and slots can be empty (for example in warmup) # blocks and slots can be empty (for example in warmup)
if not r.blocks: if not r.blocks:
...@@ -330,77 +341,26 @@ class FlashCausalLMBatch(Batch): ...@@ -330,77 +341,26 @@ class FlashCausalLMBatch(Batch):
request_blocks = [ request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks) b for b in range(num_blocks, num_blocks + needed_blocks)
] ]
request_slots = [
s
for b in request_blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else: else:
request_blocks = r.blocks request_blocks = r.blocks
request_slots = r.slots[
prefix_len: #: orig_input_length + max_new_tokens + speculative_length
]
block_tables.append(request_blocks) block_tables.append(request_blocks)
slots.extend(request_slots) cache_lengths.append(cache_length)
prefix_lens.append(prefix_len)
num_blocks += len(request_blocks) num_blocks += len(request_blocks)
start_slots.append(cumulative_slot_tokens)
request_slot_indices = torch.arange(
cumulative_slot_tokens,
cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1], dtype=torch.int32
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
# Update # Update
cumulative_length += input_length
cumulative_slot_tokens += slot_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks)) max_blocks = max(max_blocks, len(request_blocks))
max_input_length = max(max_input_length, input_length)
max_current_length = max(max_current_length, cache_length + input_length)
max_length = max( max_length = max(
max_length, input_length + max_new_tokens + speculative_length max_length,
prompt_length + max_new_tokens + speculative_length,
) )
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer next_token_chooser_parameters, dtype, device, tokenizer
) )
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded all_input_ids_tensor # Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros( all_input_ids_tensor = np.zeros(
...@@ -414,103 +374,59 @@ class FlashCausalLMBatch(Batch): ...@@ -414,103 +374,59 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor, dtype=torch.int64, device=device all_input_ids_tensor, dtype=torch.int64, device=device
) )
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor( top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64 top_n_tokens, device=device, dtype=torch.int64
) )
slots = torch.tensor(slots, dtype=torch.int64, device=device)
block_tables_tensor = torch.zeros( block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu" (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
) )
for i, request_blocks in enumerate(block_tables): for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device) block_tables_tensor = block_tables_tensor.to(device)
prefix_lens_tensor = torch.tensor(prefix_lens, dtype=torch.int32, device=device) prompt_lengths_tensor = torch.tensor(
prompt_lengths, dtype=torch.int32, device=device
)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
requests=pb.requests, requests=pb.requests,
requests_idx_mapping=requests_idx_mapping, requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids, input_ids=all_postfix_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
prefill_cache_indices=prefill_cache_indices,
start_slots=start_slots,
slot_indices=slot_indices,
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
slots=slots, cache_lengths=cache_lengths,
prefix_lens=prefix_lens, max_input_length=max_input_length,
prefix_lens_tensor=prefix_lens_tensor, max_current_length=max_current_length,
max_seqlen=max_seqlen, prefilling=True,
prefill_head_indices=prefill_head_indices, prefilling_mask=[True] * len(pb.requests),
prefill_next_token_indices=prefill_next_token_indices, prefill_logprob_tokens=[None] * len(pb.requests),
prefill_cu_outlens=prefill_cu_outlens,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor, prompt_lengths=prompt_lengths,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor, top_n_tokens_tensor=top_n_tokens_tensor,
num_blocks=num_blocks, num_blocks=num_blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
adapter_meta=AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
speculative_ids=None, speculative_ids=None,
prompt_lengths_tensor=prompt_lengths_tensor,
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids=None,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
slot_indices=None,
slots=None,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
cache_lengths_tensor=None,
input_lengths_tensor=None,
adapter_meta=None,
) )
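The slicing in from_pb above is the batch-side core of chunked prefill: cache_length prompt tokens are already in the KV cache, and at most chunk_len of the remaining tokens become the postfix_ids computed in this pass. A standalone illustration with plain lists and made-up values:

# Illustration of the postfix slicing done in from_pb above.
tokenized_input = list(range(1000))  # a 1000-token prompt
cache_length = 256                   # already cached (prefix cache or an earlier chunk)
chunk_len = 512                      # budget for this prefill pass

if cache_length + chunk_len < len(tokenized_input):
    postfix_ids = tokenized_input[cache_length : cache_length + chunk_len]
else:
    postfix_ids = tokenized_input[cache_length:]  # last chunk: take everything left

input_length = len(postfix_ids)
assert input_length == 512 and postfix_ids[0] == 256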
@classmethod @classmethod
...@@ -533,7 +449,7 @@ class FlashCausalLMBatch(Batch): ...@@ -533,7 +449,7 @@ class FlashCausalLMBatch(Batch):
if len(request_ids) == len(self): if len(request_ids) == len(self):
return self return self
device = self.input_ids.device device = self.block_tables_tensor.device
# New values after filtering # New values after filtering
requests_idx_mapping = {} requests_idx_mapping = {}
...@@ -548,19 +464,23 @@ class FlashCausalLMBatch(Batch): ...@@ -548,19 +464,23 @@ class FlashCausalLMBatch(Batch):
# Create on CPU to only move to GPU once instead of at every copy # Create on CPU to only move to GPU once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64) slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_seqlen = 0 max_input_length = 0
max_current_length = 0
requests = [] requests = []
start_slots = []
block_tables = [] block_tables = []
all_input_ids = [] all_input_ids = []
prefix_ids = [] input_ids = []
prompt_lengths = []
input_lengths = [] input_lengths = []
prefix_lens = [] cache_lengths = []
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
prefilling_mask = []
prefill_logprob_tokens = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = [] top_n_tokens = []
adapter_set = set() adapter_set = set()
...@@ -577,16 +497,23 @@ class FlashCausalLMBatch(Batch): ...@@ -577,16 +497,23 @@ class FlashCausalLMBatch(Batch):
requests.append(self.requests[idx]) requests.append(self.requests[idx])
# Prefilling
request_prefilling = self.prefilling_mask[idx]
prefilling_mask.append(request_prefilling)
# Get length # Get length
request_input_length = self.input_lengths[idx] request_input_length = self.input_lengths[idx]
prefix_len = self.prefix_lens[idx] request_cache_length = self.cache_lengths[idx]
max_seqlen = max(max_seqlen, request_input_length) max_input_length = max(max_input_length, request_input_length)
max_current_length = max(
max_current_length, request_cache_length + request_input_length
)
all_input_ids.append(self.all_input_ids[idx]) all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
prompt_lengths.append(self.prompt_lengths[idx])
input_lengths.append(request_input_length) input_lengths.append(request_input_length)
prefix_lens.append(prefix_len) cache_lengths.append(request_cache_length)
prefix_offsets.append(self.prefix_offsets[idx]) prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx]) read_offsets.append(self.read_offsets[idx])
...@@ -594,60 +521,79 @@ class FlashCausalLMBatch(Batch): ...@@ -594,60 +521,79 @@ class FlashCausalLMBatch(Batch):
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx]) top_n_tokens.append(self.top_n_tokens[idx])
prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])
ADAPTER_TO_INDEX = get_adapter_to_index() ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0) adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
adapter_set.add(adapter_index) adapter_set.add(adapter_index)
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
request_block_table = self.block_tables[idx] request_block_table = self.block_tables[idx]
num_blocks += len(request_block_table) num_blocks += len(request_block_table)
block_tables.append(request_block_table) block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
# Copy to tensor (CPU) # Input ids if the request was part of a prefilling batch
slot_indices[i] = cumulative_max_length + request_input_length - 1 # If the batch was decoding we can index into the tensor directly later
if self.prefilling:
input_ids.append(self.input_ids[idx])
else:
# Copy to tensor (CPU)
slot_indices[i] = cumulative_max_length
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
# Set slice # Set slice
slot_filtering_indices[ slot_filtering_indices[
self.start_slots[idx] : self.start_slots[idx] self.slot_indices[idx] : self.slot_indices[idx]
+ request_input_length + request_input_length
+ remaining_tokens + remaining_tokens
- 1 - 1
] = True ] = True
cumulative_max_length += request_input_length + remaining_tokens - 1 cumulative_max_length += request_input_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table)) max_blocks = max(max_blocks, len(request_block_table))
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
prefix_lens_tensor = self.prefix_lens_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices] top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = ( speculative_ids = (
self.speculative_ids[indices] if self.speculative_ids is not None else None self.speculative_ids[indices] if self.speculative_ids is not None else None
) )
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
if self.prefilling:
# Move to GPU now that we have the whole tensor # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
slot_indices = slot_indices.to(device) position_ids = None
slot_indices = None
adapter_segments, adapter_segment_indices = find_segments(adapter_indices) slots = None
adapter_segments = torch.tensor( cache_lengths_tensor = None
adapter_segments, dtype=torch.int32, device=device input_lengths_tensor = None
) adapter_meta = None
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum() else:
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to GPU now that we have the whole tensor
slot_indices = slot_indices.to(device)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
return type(self)( return type(self)(
batch_id=self.batch_id, batch_id=self.batch_id,
...@@ -657,24 +603,28 @@ class FlashCausalLMBatch(Batch): ...@@ -657,24 +603,28 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
prefill_cache_indices=None, prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices, slot_indices=slot_indices,
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
slots=slots, slots=slots,
max_seqlen=max_seqlen, max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=self.prefilling,
prefilling_mask=prefilling_mask,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor, input_lengths_tensor=input_lengths_tensor,
prefix_lens=prefix_lens, cache_lengths=cache_lengths,
prefix_lens_tensor=prefix_lens_tensor, cache_lengths_tensor=cache_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
...@@ -682,12 +632,7 @@ class FlashCausalLMBatch(Batch): ...@@ -682,12 +632,7 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks, num_blocks=num_blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
speculative_ids=speculative_ids, speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata( adapter_meta=adapter_meta,
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
),
) )
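Note the contract filter relies on: for a batch that is still prefilling, the per-token tensors are deliberately left as None and rebuilt by prepare_for_prefill, while a decoding batch keeps indexing into its existing tensors. A small sketch of that invariant for batches as returned by from_pb/filter/concatenate (before prepare_for_prefill runs):

# Invariant sketch: these fields are only materialized once the batch is
# past its prefill phase (or after prepare_for_prefill has run).
def check_lazy_fields(batch) -> None:
    lazy = (batch.position_ids, batch.slots, batch.slot_indices,
            batch.input_lengths_tensor, batch.cache_lengths_tensor,
            batch.adapter_meta)
    if batch.prefilling:
        assert all(f is None for f in lazy)
    else:
        assert all(f is not None for f in lazy)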
@classmethod @classmethod
...@@ -697,74 +642,98 @@ class FlashCausalLMBatch(Batch): ...@@ -697,74 +642,98 @@ class FlashCausalLMBatch(Batch):
requests = [] requests = []
requests_idx_mapping = {} requests_idx_mapping = {}
prefilling = False
num_blocks = 0 num_blocks = 0
total_batch_size = 0 total_batch_size = 0
total_slots = 0 total_slots = 0
max_blocks = 0 max_blocks = 0
max_length = 0 max_length = 0
max_seqlen = 0 max_input_length = 0
max_current_length = 0
for b in batches: for b in batches:
total_batch_size += len(b) total_batch_size += len(b)
total_slots += len(b.slots) max_blocks = max(max_blocks, b.max_blocks)
# If `b` is prefilling and was just filtered, `b.slots` is None
# `total_slots` is not used if any of the batches is prefilling
total_slots += len(b.slots) if not b.prefilling else 0
num_blocks += b.num_blocks num_blocks += b.num_blocks
speculative_length = ( speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
) )
max_blocks = max(max_blocks, b.max_blocks) max_input_length = max(max_input_length, b.max_input_length)
max_seqlen = max(max_seqlen, b.max_seqlen) max_current_length = max(max_current_length, b.max_current_length)
max_length = max( max_length = max(
max_length, max_length,
max( max(
input_length prompt_length
+ stopping_criteria.max_new_tokens + stopping_criteria.max_new_tokens
+ speculative_length + speculative_length
- stopping_criteria.current_tokens for prompt_length, stopping_criteria in zip(
for input_length, stopping_criteria in zip( b.prompt_lengths, b.stopping_criterias
b.input_lengths, b.stopping_criterias
) )
), ),
) )
prefilling = prefilling or b.prefilling
if prefilling:
input_ids = []
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slots = None
slot_indices = None
cache_lengths_tensor = None
input_lengths_tensor = None
adapter_meta = None
adapter_segment_builder = None
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
)
cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(
total_batch_size
)
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_segment_builder = SegmentConcatBuilder()
adapter_set = set()
input_ids = batches[0].input_ids.new_empty(total_batch_size) prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size total_batch_size
) )
block_tables_tensor = batches[0].block_tables_tensor.new_zeros( block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks) (total_batch_size, max_blocks)
) )
prefix_lens_tensor = batches[0].prefix_lens_tensor.new_empty(total_batch_size)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length) (total_batch_size, max_length)
) )
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size, total_batch_size,
) )
total_indices_size = sum(
b.adapter_meta.adapter_indices.shape[0] for b in batches
)
adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
total_indices_size
)
adapter_set = set()
adapter_segment_builder = SegmentConcatBuilder()
start_slots = []
block_tables = [] block_tables = []
prefix_lens = [] cache_lengths = []
all_input_ids = [] all_input_ids = []
prefix_ids = []
prompt_lengths = []
input_lengths = [] input_lengths = []
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
prefill_logprob_tokens = []
next_token_chooser_parameters = [] next_token_chooser_parameters = []
fsm_grammar_states = [] fsm_grammar_states = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = [] top_n_tokens = []
prefilling_mask = []
# Cumulative length # Cumulative length
cumulative_batch_size = 0 cumulative_batch_size = 0
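To make the branching above concrete, here is a minimal, self-contained sketch (not the real `FlashCausalLMBatch.concatenate`) of how tensor allocation is skipped while any batch is still prefilling and only done on the decode path; `TinyBatch` and its fields are hypothetical stand-ins.

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch


@dataclass
class TinyBatch:
    input_ids: torch.Tensor        # [num_requests] during decode
    slots: Optional[torch.Tensor]  # may be None while a batch is prefilling
    prefilling: bool

    def __len__(self) -> int:
        return self.input_ids.shape[0]


def concatenate_sketch(
    batches: List[TinyBatch],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
    prefilling = any(b.prefilling for b in batches)
    total_batch_size = sum(len(b) for b in batches)
    if prefilling:
        # Defer tensor creation: prepare_for_prefill will rebuild position_ids,
        # slots, slot_indices, ... from the per-request Python lists instead.
        return None, None, prefilling
    # Decode-only path: pre-allocate once and copy each batch into its slice.
    input_ids = batches[0].input_ids.new_empty(total_batch_size)
    slots = torch.cat([b.slots for b in batches])
    offset = 0
    for b in batches:
        input_ids[offset : offset + len(b)] = b.input_ids
        offset += len(b)
    return input_ids, slots, prefilling


decode_a = TinyBatch(torch.tensor([1, 2]), torch.tensor([0, 1]), prefilling=False)
decode_b = TinyBatch(torch.tensor([3]), torch.tensor([2]), prefilling=False)
print(concatenate_sketch([decode_a, decode_b]))
# (tensor([1, 2, 3]), tensor([0, 1, 2]), False)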
...@@ -783,32 +752,9 @@ class FlashCausalLMBatch(Batch): ...@@ -783,32 +752,9 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch) end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU) # Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
)
all_input_ids_tensor[ all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1] start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length] ] = batch.all_input_ids_tensor[:, :max_length]
...@@ -816,20 +762,56 @@ class FlashCausalLMBatch(Batch): ...@@ -816,20 +762,56 @@ class FlashCausalLMBatch(Batch):
block_tables_tensor[ block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1] start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks] ] = batch.block_tables_tensor[:, :max_blocks]
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
prefix_lens_tensor[start_index:end_index] = batch.prefix_lens_tensor if not prefilling:
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slots[slots_start_index:slots_end_index] = batch.slots
slot_indices[start_index:end_index] = (
batch.slot_indices + cumulative_slots
)
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor
# Copy over adapter indices
adapter_start_index = cumulative_adapter_indices_size
adapter_end_index = (
cumulative_adapter_indices_size
+ batch.adapter_meta.adapter_indices.shape[0]
)
adapter_indices[adapter_start_index:adapter_end_index] = (
batch.adapter_meta.adapter_indices
)
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments,
batch.adapter_meta.segment_indices,
)
start_slots.append(batch.start_slots + cumulative_slots) # Update
cumulative_slots += len(batch.slots)
else:
if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
input_ids.extend(batch.input_ids)
prefilling_mask.extend(batch.prefilling_mask)
block_tables.extend(batch.block_tables) block_tables.extend(batch.block_tables)
prefix_lens.extend(batch.prefix_lens) cache_lengths.extend(batch.cache_lengths)
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids)
prompt_lengths.extend(batch.prompt_lengths)
input_lengths.extend(batch.input_lengths) input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets) prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets) read_offsets.extend(batch.read_offsets)
prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states) fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
...@@ -838,11 +820,6 @@ class FlashCausalLMBatch(Batch): ...@@ -838,11 +820,6 @@ class FlashCausalLMBatch(Batch):
# Update # Update
cumulative_batch_size += len(batch) cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots)
start_slots = torch.concat(start_slots)
# assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, next_token_chooser_parameters,
...@@ -858,7 +835,14 @@ class FlashCausalLMBatch(Batch): ...@@ -858,7 +835,14 @@ class FlashCausalLMBatch(Batch):
else None else None
) )
adapter_segments, adapter_segment_indices = adapter_segment_builder.build() if adapter_segment_builder is not None:
adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
)
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
...@@ -868,24 +852,28 @@ class FlashCausalLMBatch(Batch): ...@@ -868,24 +852,28 @@ class FlashCausalLMBatch(Batch):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
prefill_cache_indices=None, prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices, slot_indices=slot_indices,
block_tables=block_tables, block_tables=block_tables,
block_tables_tensor=block_tables_tensor, block_tables_tensor=block_tables_tensor,
prefix_lens=prefix_lens, cache_lengths=cache_lengths,
prefix_lens_tensor=prefix_lens_tensor, cache_lengths_tensor=cache_lengths_tensor,
slots=slots, slots=slots,
max_seqlen=max_seqlen, max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=prefilling,
prefilling_mask=prefilling_mask,
prefill_head_indices=None, prefill_head_indices=None,
prefill_next_token_indices=None, prefill_next_token_indices=None,
prefill_cu_outlens=None, prefill_cu_outlens=None,
prefill_logprob_tokens=prefill_logprob_tokens,
prompt_lengths=prompt_lengths,
prompt_lengths_tensor=prompt_lengths_tensor,
input_lengths=input_lengths, input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor, input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets, prefix_offsets=prefix_offsets,
read_offsets=read_offsets, read_offsets=read_offsets,
all_input_ids=all_input_ids, all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
...@@ -893,12 +881,195 @@ class FlashCausalLMBatch(Batch): ...@@ -893,12 +881,195 @@ class FlashCausalLMBatch(Batch):
num_blocks=num_blocks, num_blocks=num_blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
speculative_ids=speculative_ids, speculative_ids=speculative_ids,
adapter_meta=AdapterBatchMetadata( adapter_meta=adapter_meta,
adapter_indices=adapter_indices, )
adapter_set=adapter_set,
adapter_segments=adapter_segments, def prepare_for_prefill(self):
segment_indices=adapter_segment_indices, # Prepare values if we need to continue prefilling
), # Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert self.speculative_ids is None
sliding_window = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
# Cumulative length
cumulative_length = 0
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
slots = []
adapter_indices_list = []
adapter_set = set()
for i, (
r,
cache_length,
input_length,
prompt_length,
request_prefilling,
blocks,
) in enumerate(
zip(
self.requests,
self.cache_lengths,
self.input_lengths,
self.prompt_lengths,
self.prefilling_mask,
self.block_tables,
)
):
next_chunk_length = input_length
# Position ids
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
if not r.slots:
request_slots = [
s
for b in blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_slots = r.slots
request_slots = request_slots[cache_length:]
request_slot_indices = torch.arange(
cumulative_slot_tokens,
cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
# Prefill logprobs are ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_head_indices.append(
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64,
)
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
slots.extend(request_slots)
slot_indices.append(request_slot_indices)
if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
adapter_set.add(adapter_index)
# Update
cumulative_length += next_chunk_length
cumulative_slot_tokens += len(request_slots)
device = self.block_tables_tensor.device
if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
if len(self) > 1:
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
self.prefill_cu_outlens = prefill_cu_outlens
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
self.cu_seqlen_prefill = cu_seqlen_prefill
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.cat(prefill_head_indices).to(device)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
adapter_segments=adapter_segments,
segment_indices=adapter_segment_indices,
) )
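As a standalone illustration of the indexing built by `prepare_for_prefill` (made-up lengths, and without the slot or sliding-window handling), positions continue from what is already cached and `cu_seqlen_prefill` accumulates only the tokens prefilled in this step:

import torch

cache_lengths = [0, 3]    # tokens already in the KV cache per request
input_lengths = [4, 2]    # tokens being prefilled in this step per request

cu_seqlen_prefill = [0]
position_ids = []
for cache_length, input_length in zip(cache_lengths, input_lengths):
    # Positions continue from what is already cached
    position_ids.append(
        torch.arange(cache_length, cache_length + input_length, dtype=torch.int32)
    )
    cu_seqlen_prefill.append(cu_seqlen_prefill[-1] + input_length)

position_ids = torch.cat(position_ids)
cu_seqlen_prefill = torch.tensor(cu_seqlen_prefill, dtype=torch.int32)

print(position_ids)       # tensor([0, 1, 2, 3, 3, 4], dtype=torch.int32)
print(cu_seqlen_prefill)  # tensor([0, 4, 6], dtype=torch.int32)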
def __len__(self): def __len__(self):
...@@ -938,6 +1109,7 @@ class FlashCausalLM(Model): ...@@ -938,6 +1109,7 @@ class FlashCausalLM(Model):
head_size: Optional[int] = None, head_size: Optional[int] = None,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
kv_cache_dtype: Optional[torch.dtype] = None, kv_cache_dtype: Optional[torch.dtype] = None,
support_chunking: bool = True,
): ):
self.quantize = quantize self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
...@@ -1065,6 +1237,7 @@ class FlashCausalLM(Model): ...@@ -1065,6 +1237,7 @@ class FlashCausalLM(Model):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
sliding_window=config.sliding_window, sliding_window=config.sliding_window,
support_chunking=support_chunking,
) )
@property @property
...@@ -1101,11 +1274,11 @@ class FlashCausalLM(Model): ...@@ -1101,11 +1274,11 @@ class FlashCausalLM(Model):
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device) slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = [max_s] * bs input_lengths = [max_s] * bs
prefix_lengths = [0] * bs cache_lengths = [0] * bs
input_lengths_tensor = ( input_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
) )
prefix_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
block_tables = torch.arange( block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device max_bt, dtype=torch.int32, device=self.device
).repeat(bs) ).repeat(bs)
...@@ -1115,7 +1288,7 @@ class FlashCausalLM(Model): ...@@ -1115,7 +1288,7 @@ class FlashCausalLM(Model):
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lens=prefix_lengths, cache_lengths=cache_lengths,
) )
from text_generation_server.layers.attention.flashinfer import ( from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs, create_decode_state_cuda_graphs,
...@@ -1144,7 +1317,7 @@ class FlashCausalLM(Model): ...@@ -1144,7 +1317,7 @@ class FlashCausalLM(Model):
"block_tables": block_tables, "block_tables": block_tables,
"slots": slots, "slots": slots,
"input_lengths": input_lengths_tensor, "input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor, "cache_lengths": cache_lengths_tensor,
"state": state, "state": state,
"graph": graph, "graph": graph,
} }
...@@ -1156,11 +1329,11 @@ class FlashCausalLM(Model): ...@@ -1156,11 +1329,11 @@ class FlashCausalLM(Model):
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths_tensor=input_lengths_tensor, input_lengths_tensor=input_lengths_tensor,
state=state, state=state,
prefix_lens_tensor=prefix_lengths_tensor, cache_lengths_tensor=cache_lengths_tensor,
): ):
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths_tensor, input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None, cu_seqlen_q=None,
max_q=1, max_q=1,
max_k=max_s, max_k=max_s,
...@@ -1184,7 +1357,7 @@ class FlashCausalLM(Model): ...@@ -1184,7 +1357,7 @@ class FlashCausalLM(Model):
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths_tensor, input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None, cu_seqlen_q=None,
max_q=1, max_q=1,
max_k=max_s, max_k=max_s,
...@@ -1207,6 +1380,7 @@ class FlashCausalLM(Model): ...@@ -1207,6 +1380,7 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch): def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive # The warmup batch is the biggest batch we could ever receive
self.kv_cache = []
empty_cache() empty_cache()
try: try:
...@@ -1226,7 +1400,7 @@ class FlashCausalLM(Model): ...@@ -1226,7 +1400,7 @@ class FlashCausalLM(Model):
_, batch, _ = self.generate_token(batch) _, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e: except torch.cuda.OutOfMemoryError as e:
raise RuntimeError( raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
f"You need to decrease `--max-batch-prefill-tokens`" f"You need to decrease `--max-batch-prefill-tokens`"
) from e ) from e
...@@ -1341,14 +1515,16 @@ class FlashCausalLM(Model): ...@@ -1341,14 +1515,16 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`. # Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) cache_lengths_tensor = torch.zeros(
seqlen, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32 [0, seqlen], device=self.device, dtype=torch.int32
) )
max_s = seqlen max_s = seqlen
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill, cu_seqlen_q=cu_seqlen_prefill,
max_q=1, max_q=1,
max_k=seqlen, max_k=seqlen,
...@@ -1380,7 +1556,7 @@ class FlashCausalLM(Model): ...@@ -1380,7 +1556,7 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
...@@ -1399,8 +1575,8 @@ class FlashCausalLM(Model): ...@@ -1399,8 +1575,8 @@ class FlashCausalLM(Model):
input_lengths = ( input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1) ).view(-1)
prefix_lens_tensor = ( cache_lengths_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1) ).reshape(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
...@@ -1422,10 +1598,12 @@ class FlashCausalLM(Model): ...@@ -1422,10 +1598,12 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None: if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache # In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode. # in a circular buffer mode.
...@@ -1445,21 +1623,20 @@ class FlashCausalLM(Model): ...@@ -1445,21 +1623,20 @@ class FlashCausalLM(Model):
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths, input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor, cache_lengths_tensor=cache_lengths_tensor,
): ):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill, cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s, max_q=batch.max_input_length,
max_k=max_k, max_k=batch.max_current_length,
) )
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
...@@ -1486,7 +1663,7 @@ class FlashCausalLM(Model): ...@@ -1486,7 +1663,7 @@ class FlashCausalLM(Model):
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
# assert block_tables.shape[0] >= slots.shape[0] # assert block_tables.shape[0] >= slots.shape[0]
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
...@@ -1501,14 +1678,16 @@ class FlashCausalLM(Model): ...@@ -1501,14 +1678,16 @@ class FlashCausalLM(Model):
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["prefix_lengths"].zero_() cuda_graph["cache_lengths"].zero_()
cuda_graph["prefix_lengths"][: prefix_lens_tensor.shape[0]] = prefix_lens_tensor cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
with self._forward_context( with self._forward_context(
block_tables=cuda_graph["block_tables"], block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"], input_lengths_tensor=cuda_graph["input_lengths"],
prefix_lens_tensor=cuda_graph["prefix_lengths"], cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"], state=cuda_graph["state"],
): ):
# Replay the graph # Replay the graph
...@@ -1528,7 +1707,10 @@ class FlashCausalLM(Model): ...@@ -1528,7 +1707,10 @@ class FlashCausalLM(Model):
self, batch: FlashCausalLMBatch self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
start = time.time_ns() start = time.time_ns()
prefill = batch.cu_seqlen_prefill is not None prefill = batch.prefilling
if prefill:
batch.prepare_for_prefill()
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
# Update adapter indices for speculative tokens (if present) # Update adapter indices for speculative tokens (if present)
...@@ -1570,14 +1752,62 @@ class FlashCausalLM(Model): ...@@ -1570,14 +1752,62 @@ class FlashCausalLM(Model):
if prefill_logprobs if prefill_logprobs
else speculative_logits else speculative_logits
) )
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( if len(batch) > 1 and prefill_logprobs:
len(batch) # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
) # When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
else: else:
prefill_logprobs = None
next_token_logits = out next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices next_adapter_indices = batch.adapter_meta.adapter_indices
finished_prefilling = True
next_chunk_lengths = []
current_prefilling_mask = batch.prefilling_mask
if prefill:
if get_support_chunking():
next_prefilling_mask = []
# Budget in tokens for the next batch
# We remove (len(batch) - 1) to always have enough space for at least a single decode
# token for each of the other requests; -1 because the first request does not need to
# be removed from the budget (ex: with a single request in the batch, it should take
# the full budget, not budget - 1)
batch_budget = get_max_prefill_tokens() - (len(batch) - 1)
# We reverse to prioritize older requests
# zip() is not reversible so reverse the underlying lists instead
for cache_length, input_length, prompt_length in zip(
reversed(batch.cache_lengths),
reversed(batch.input_lengths),
reversed(batch.prompt_lengths),
):
remaining_prefill_tokens = max(
prompt_length - cache_length - input_length, 0
)
if remaining_prefill_tokens > 0:
next_chunk_length = max(
min(remaining_prefill_tokens, batch_budget), 1
)
batch_budget -= next_chunk_length
finished_prefilling = False
next_prefilling_mask.append(True)
else:
# FIXME: use true number of accepted tokens instead of 1
# Since speculation will be turned off, this is always true
next_chunk_length = 1
next_prefilling_mask.append(False)
next_chunk_lengths.append(next_chunk_length)
# Reverse back the obtained values
next_chunk_lengths.reverse()
next_prefilling_mask.reverse()
else:
# The model does not support chunking
# We know we only do a single prefill
finished_prefilling = True
next_prefilling_mask = [False] * len(batch)
batch.prefilling = not finished_prefilling
batch.prefilling_mask = next_prefilling_mask
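As a rough, standalone sketch of the chunking budget above (`max_prefill_tokens` here is a made-up stand-in for the value returned by `get_max_prefill_tokens()`):

def plan_next_chunks(prompt_lengths, cache_lengths, input_lengths, max_prefill_tokens):
    # Keep room for at least one decode token per other request
    batch_budget = max_prefill_tokens - (len(prompt_lengths) - 1)
    next_chunk_lengths = []
    # Walk the batch in reverse, mirroring the loop above
    for prompt_length, cache_length, input_length in zip(
        reversed(prompt_lengths), reversed(cache_lengths), reversed(input_lengths)
    ):
        remaining = max(prompt_length - cache_length - input_length, 0)
        if remaining > 0:
            # Still prefilling: take at most what is left of the budget, at least 1 token
            chunk = max(min(remaining, batch_budget), 1)
            batch_budget -= chunk
        else:
            # Done prefilling: only a single decode token is needed
            chunk = 1
        next_chunk_lengths.append(chunk)
    next_chunk_lengths.reverse()
    return next_chunk_lengths


# Two requests: the first has 100 prompt tokens left to prefill, the second 10.
print(plan_next_chunks([120, 20], [0, 0], [20, 10], max_prefill_tokens=64))
# [53, 10]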
speculate = get_speculate() speculate = get_speculate()
( (
next_input_ids, next_input_ids,
...@@ -1586,7 +1816,7 @@ class FlashCausalLM(Model): ...@@ -1586,7 +1816,7 @@ class FlashCausalLM(Model):
accepted_ids, accepted_ids,
speculative_ids, speculative_ids,
) = batch.next_token_chooser( ) = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], batch.all_input_ids_tensor[:, : batch.max_current_length],
next_token_logits, next_token_logits,
speculate, speculate,
batch.speculative_ids, batch.speculative_ids,
...@@ -1597,29 +1827,28 @@ class FlashCausalLM(Model): ...@@ -1597,29 +1827,28 @@ class FlashCausalLM(Model):
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
) )
if prefill: # Since we are done prefilling, all the tensors that concatenated values across all the requests
if len(batch) > 1 and prefill_logprobs: # instantly become of shape [BATCH_SIZE]
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs if prefill and finished_prefilling:
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
next_position_ids = batch.position_ids.new_empty(len(batch)) next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
batch.cu_seqlen_prefill = None len(batch)
else: )
prefill_logprobs = None elif not prefill:
next_position_ids = batch.position_ids next_position_ids = batch.position_ids
# Cumulative length
cumulative_length = 0
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator # Zipped iterator
iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids) iterator = zip(
batch.requests,
batch.prompt_lengths,
batch.cache_lengths,
batch.input_lengths,
batch.all_input_ids,
accepted_ids,
current_prefilling_mask,
batch.prefilling_mask,
)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
# one, we need to first do a GPU <-> CPU sync # one, we need to first do a GPU <-> CPU sync
...@@ -1627,16 +1856,22 @@ class FlashCausalLM(Model): ...@@ -1627,16 +1856,22 @@ class FlashCausalLM(Model):
# For each member of the batch # For each member of the batch
index = 0 index = 0
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator): # Cumulative length
# Indexing metadata cumulative_length = 0
start_index = cumulative_length for i, (
end_index = cumulative_length + input_length request,
prompt_length,
if prefill: cache_length,
input_length,
all_input_ids,
n_accepted_ids,
request_was_prefilling,
request_is_prefilling,
) in enumerate(iterator):
if prefill and finished_prefilling:
# Indexing metadata # Indexing metadata
out_start_index = batch.prefill_cu_outlens[i] _start_index = cumulative_length
out_end_index = batch.prefill_cu_outlens[i + 1] end_index = cumulative_length + input_length
out_length = out_end_index - out_start_index
# Initialize position_ids # Initialize position_ids
# In decode, we do not need this as we can just increment position ids # In decode, we do not need this as we can just increment position ids
...@@ -1648,41 +1883,43 @@ class FlashCausalLM(Model): ...@@ -1648,41 +1883,43 @@ class FlashCausalLM(Model):
end_index - 1 end_index - 1
] ]
# Used to gather prefill logprobs # Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices # Copy batch.all_input_ids_tensor to prefill_token_indices
if prefill_logprobs: if request.prefill_logprobs and request_was_prefilling:
if len(batch) > 1: # Indexing metadata
prefill_tokens_indices[out_start_index : out_end_index - 1] = ( out_start_index = batch.prefill_cu_outlens[i]
batch.input_ids[start_index + 1 : start_index + out_length] out_end_index = batch.prefill_cu_outlens[i + 1]
)
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : start_index + out_length
]
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
ids = batch.all_input_ids_tensor[
i, cache_length + 1 : cache_length + input_length + 1
]
if len(batch) > 1:
prefill_tokens_indices[out_start_index:out_end_index] = ids
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = ids
if not request_is_prefilling:
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
next_input_ids[index + j]
)
index += n_accepted_ids
cumulative_length += input_length cumulative_length += input_length
# Update values # Update values
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] # These values can be updated without a GPU -> CPU sync
batch.speculative_ids = speculative_ids if not prefill or (prefill and finished_prefilling):
batch.position_ids = next_position_ids + accepted_ids batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.input_lengths_tensor += accepted_ids batch.speculative_ids = speculative_ids
batch.slot_indices += accepted_ids batch.position_ids = next_position_ids + accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices batch.cache_lengths_tensor += batch.input_lengths_tensor
batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
if prefill: batch.slot_indices += accepted_ids
# adjust segment lengths to account for all request lengths being 1 during decoding batch.adapter_meta.adapter_indices = next_adapter_indices
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
if prefill and prefill_logprobs: if prefill and prefill_logprobs:
# Get prefill logprobs # Get prefill logprobs
...@@ -1693,183 +1930,292 @@ class FlashCausalLM(Model): ...@@ -1693,183 +1930,292 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync # GPU <-> CPU sync
prefill_logprobs = prefill_logprobs.view(-1).tolist() prefill_logprobs = prefill_logprobs.view(-1).tolist()
# Does a GPU <-> CPU sync internally
if prefill and finished_prefilling:
# adjust segment lengths to account for all request lengths being 1 during decoding
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
batch.adapter_meta.adapter_segments = torch.tensor(
adapter_segments,
dtype=torch.int32,
device=batch.adapter_meta.adapter_segments.device,
)
# GPU <-> CPU sync # GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist() next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist() next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist() accepted_ids = accepted_ids.tolist()
# Update values if we need to continue prefilling
# This represents the `else` case of the `Update values` if above
# but since this requires the `next_token_ids` to be on CPU, it is better to do it here
if prefill and not finished_prefilling:
# Speculation must be ignored while we prefill even with chunking
# it simplifies everything
assert batch.speculative_ids is None
all_postfix_ids = []
for i, (
request_prefilling,
next_token_id,
all_input_ids,
cache_length,
input_length,
next_chunk_length,
) in enumerate(
zip(
batch.prefilling_mask,
next_token_ids,
batch.all_input_ids,
batch.cache_lengths,
batch.input_lengths,
next_chunk_lengths,
)
):
if request_prefilling:
next_cache_length = cache_length + input_length
# Get new prompt IDs to prefill
postfix_ids = all_input_ids[
next_cache_length : next_cache_length + next_chunk_length
]
else:
# This request is done prefilling, the new id is the one selected by the sampling method
postfix_ids = [next_token_id]
all_postfix_ids.append(postfix_ids)
batch.input_ids = all_postfix_ids
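A small, standalone example of the slicing above (values are made up): a request that is still prefilling gets the next chunk of its prompt, while a finished one gets the single sampled token as its next input.

all_input_ids = list(range(100, 112))  # full prompt ids of one request
cache_length, input_length = 4, 4      # 8 tokens prefilled so far
next_chunk_length = 3

# Still prefilling: take the next chunk of the prompt
next_cache_length = cache_length + input_length
postfix_ids = all_input_ids[next_cache_length : next_cache_length + next_chunk_length]
print(postfix_ids)  # [108, 109, 110]

# Done prefilling: the single sampled token becomes the next input
next_token_id = 42
print([next_token_id])  # [42]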
start_decode = time.time_ns() start_decode = time.time_ns()
# Results
generations: List[Generation] = []
stopped = True
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
batch.requests, batch.requests,
batch.prompt_lengths,
batch.cache_lengths,
batch.input_lengths, batch.input_lengths,
batch.prefix_offsets, batch.prefix_offsets,
batch.read_offsets, batch.read_offsets,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.prefix_ids,
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
batch.top_n_tokens, batch.top_n_tokens,
current_prefilling_mask,
batch.prefilling_mask,
accepted_ids, accepted_ids,
batch_top_token_ids, batch_top_token_ids,
batch_top_token_logprobs, batch_top_token_logprobs,
) )
# Reset max_input_length
batch.max_input_length = 0
# For each member of the batch # For each member of the batch
index = 0 index = 0
for i, ( for i, (
request, request,
prompt_length,
cache_length,
input_length, input_length,
prefix_offset, prefix_offset,
read_offset, read_offset,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
prefix_ids,
do_sample, do_sample,
seed, seed,
top_n_tokens, top_n_tokens,
request_was_prefilling,
request_is_prefilling,
n_accepted_ids, n_accepted_ids,
top_token_ids, top_token_ids,
top_token_logprobs, top_token_logprobs,
) in enumerate(iterator): ) in enumerate(iterator):
# Append next token to all tokens # Compute logprobs first as, even though we might skip the token,
next_token_texts = [] # it can still be required to compute the logprobs
left = 0 # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
# this state to be stable
if n_accepted_ids > 1: if request.id % self.world_size == self.rank:
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
index += n_accepted_ids
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
# Prefill # Prefill
if prefill and request.prefill_logprobs: if request_was_prefilling and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i] out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1] out_end_index = batch.prefill_cu_outlens[i + 1]
if not request_is_prefilling:
# The request is done prefilling, meaning that we started generating new tokens
# The last logprob is a logprob for a generated token that was not part of the prompt
# We need to remove it
out_end_index -= 1
request_prefill_logprobs = prefill_logprobs[
out_start_index:out_end_index
]
# Logprobs generated by the model are for the next token
# So we need to translate the id tensor by 1
prefill_token_ids = all_input_ids[
cache_length + 1 : cache_length + input_length + 1
]
past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
if past_prefill_logprob_tokens is None:
# add nan for cached prompt tokens/first token
request_prefill_logprobs = [float("nan")] * (
cache_length + 1
) + request_prefill_logprobs
prefill_token_ids = (
all_input_ids[: cache_length + 1] + prefill_token_ids
)
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = (
[float("nan")] * (len(prefix_ids) + 1)
) + prefill_logprobs[out_start_index : out_end_index - 1]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode( prefill_texts = self.tokenizer.batch_decode(
prefix_ids + prefill_token_ids, prefill_token_ids,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
skip_special_tokens=False, skip_special_tokens=False,
) )
prefill_tokens = Tokens( prefill_logprob_tokens = Tokens(
prefix_ids + prefill_token_ids, prefill_token_ids,
request_prefill_logprobs, request_prefill_logprobs,
prefill_texts, prefill_texts,
is_special=[], is_special=[],
) )
if past_prefill_logprob_tokens is not None:
prefill_logprob_tokens = (
past_prefill_logprob_tokens + prefill_logprob_tokens
)
batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
else: else:
prefill_tokens = None batch.prefill_logprob_tokens[i] = None
if top_n_tokens > 0: # If the request is still prefilling, the tokens we decoded should be ignored
all_top_tokens = [] if request_is_prefilling:
for top_token_ids, top_token_logprobs in zip( # Make sure that we do not stop as even though this request did not create a token, it is still
top_token_ids, top_token_logprobs # processing
): stopped = False
toptoken_texts = self.tokenizer.batch_decode( new_input_length = next_chunk_lengths[i]
top_token_ids, else:
clean_up_tokenization_spaces=False, new_input_length = n_accepted_ids
skip_special_tokens=False, # Append next token to all tokens
next_token_texts = []
left = 0
if n_accepted_ids > 1:
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
# Shard generations
# All generations will be appended in the rust sharded client
if request.id % self.world_size == self.rank:
if stop:
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
) )
special_toptokens = [ generated_text = GeneratedText(
token_id in self.all_special_ids output_text,
for token_id in top_token_ids stopping_criteria.current_tokens,
] reason,
top_tokens = Tokens( seed if do_sample else None,
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
) )
all_top_tokens.append(top_tokens) else:
top_tokens = all_top_tokens generated_text = None
else:
top_tokens = None if top_n_tokens > 0:
all_top_tokens = []
generation = Generation( for top_token_ids, top_token_logprobs in zip(
request.id, top_token_ids, top_token_logprobs
prefill_tokens, ):
Tokens( toptoken_texts = self.tokenizer.batch_decode(
_next_token_ids, top_token_ids,
_next_token_logprobs, clean_up_tokenization_spaces=False,
next_token_texts, skip_special_tokens=False,
[nid in self.all_special_ids for nid in _next_token_ids], )
), special_toptokens = [
generated_text, token_id in self.all_special_ids
top_tokens, for token_id in top_token_ids
) ]
top_tokens = Tokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
all_top_tokens.append(top_tokens)
top_tokens = all_top_tokens
else:
top_tokens = None
generation = Generation(
request.id,
batch.prefill_logprob_tokens[i],
Tokens(
_next_token_ids,
_next_token_logprobs,
next_token_texts,
[nid in self.all_special_ids for nid in _next_token_ids],
),
generated_text,
top_tokens,
)
generations.append(generation) generations.append(generation)
# accept each new token for this specific request since we may # accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding # have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids: for next_token_id in _next_token_ids:
batch.next_token_chooser = ( batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id) batch.next_token_chooser.advance_grammar_single(
) i, next_token_id
)
)
# Update values # Update values
batch.input_lengths[i] = input_length + n_accepted_ids index += n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen: current_cache_length = cache_length + input_length
batch.max_seqlen = batch.input_lengths[i] batch.cache_lengths[i] = current_cache_length
current_input_length = new_input_length
batch.max_input_length = max(batch.max_input_length, current_input_length)
batch.input_lengths[i] = current_input_length
current_length = current_cache_length + current_input_length
batch.max_current_length = max(batch.max_current_length, current_length)
batch.prefix_offsets[i] = prefix_offset batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
...@@ -1880,9 +2226,13 @@ class FlashCausalLM(Model): ...@@ -1880,9 +2226,13 @@ class FlashCausalLM(Model):
decode_ns = time.time_ns() - start_decode decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns) return generations, None, (forward_ns, decode_ns)
batch.prefill_cu_outlens = None if prefill and finished_prefilling:
batch.prefill_head_indices = None # We do not need prefill tensors anymore
batch.prefill_next_token_indices = None batch.cu_seqlen_prefill = None
batch.prefill_cache_indices = None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
forward_ns = start_decode - start forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode decode_ns = time.time_ns() - start_decode
...@@ -1894,7 +2244,7 @@ class FlashCausalLM(Model): ...@@ -1894,7 +2244,7 @@ class FlashCausalLM(Model):
block_tables: torch.Tensor, block_tables: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
input_lengths_tensor: torch.Tensor, input_lengths_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor, cache_lengths_tensor: torch.Tensor,
state: Optional[Any] = None, state: Optional[Any] = None,
) -> ContextManager: ) -> ContextManager:
if ATTENTION != "flashinfer": if ATTENTION != "flashinfer":
...@@ -1905,8 +2255,6 @@ class FlashCausalLM(Model): ...@@ -1905,8 +2255,6 @@ class FlashCausalLM(Model):
use_prefill_with_paged_kv_state, use_prefill_with_paged_kv_state,
) )
# has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens)
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
return use_prefill_with_paged_kv_state( return use_prefill_with_paged_kv_state(
state=( state=(
...@@ -1915,11 +2263,11 @@ class FlashCausalLM(Model): ...@@ -1915,11 +2263,11 @@ class FlashCausalLM(Model):
# block_tables=block_tables_to_ragged( # block_tables=block_tables_to_ragged(
# block_tables=block_tables, # block_tables=block_tables,
# input_lengths=input_lengths, # input_lengths=input_lengths,
# prefix_lens=prefix_lens, # cache_lengths=cache_lengths,
# ), # ),
block_tables=block_tables, block_tables=block_tables,
cu_seqlens=cu_seqlen_prefill, cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor + prefix_lens_tensor, input_lengths=input_lengths_tensor + cache_lengths_tensor,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
...@@ -1931,7 +2279,7 @@ class FlashCausalLM(Model): ...@@ -1931,7 +2279,7 @@ class FlashCausalLM(Model):
assert input_lengths_tensor is not None assert input_lengths_tensor is not None
return use_decode_state( return use_decode_state(
state=state if state is not None else self.decode_state, state=state if state is not None else self.decode_state,
input_lengths=input_lengths_tensor + prefix_lens_tensor, input_lengths=input_lengths_tensor + cache_lengths_tensor,
block_tables=block_tables, block_tables=block_tables,
num_heads=self.num_heads, num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
...@@ -1943,19 +2291,19 @@ class FlashCausalLM(Model): ...@@ -1943,19 +2291,19 @@ class FlashCausalLM(Model):
def block_tables_to_ragged( def block_tables_to_ragged(
*, block_tables: torch.Tensor, input_lengths: List[int], prefix_lens: List[int] *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
) -> torch.Tensor: ) -> torch.Tensor:
"""Convert block table to ragged format compatible with FlashInfer.""" """Convert block table to ragged format compatible with FlashInfer."""
assert len(input_lengths) == len(prefix_lens) assert len(input_lengths) == len(cache_lengths)
total_len = sum(input_lengths) + sum(prefix_lens) total_len = sum(input_lengths) + sum(cache_lengths)
block_tables_ragged = torch.empty( block_tables_ragged = torch.empty(
total_len, dtype=torch.int32, device=block_tables.device total_len, dtype=torch.int32, device=block_tables.device
) )
offset = 0 offset = 0
for i, (input_length, prefix_len) in enumerate(zip(input_lengths, prefix_lens)): for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
seq_len = prefix_len + input_length seq_len = cache_length + input_length
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
offset += seq_len offset += seq_len
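For reference, a small usage example of the helper above (block ids and lengths are made up):

import torch

# Two requests, padded block tables of width 4
block_tables = torch.tensor(
    [[7, 8, 9, 0],
     [3, 4, 0, 0]],
    dtype=torch.int32,
)
input_lengths = [2, 1]   # new tokens per request
cache_lengths = [1, 1]   # tokens already cached per request

ragged = block_tables_to_ragged(
    block_tables=block_tables,
    input_lengths=input_lengths,
    cache_lengths=cache_lengths,
)
print(ragged)  # tensor([7, 8, 9, 3, 4], dtype=torch.int32)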
......
...@@ -5,9 +5,14 @@ from typing import Dict, Optional ...@@ -5,9 +5,14 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
"1",
"true",
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"} _expected = {"paged", "flashdecoding", "flashinfer"}
assert ( assert (
ATTENTION in _expected ATTENTION in _expected
...@@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: ...@@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer") raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1 assert TGI_WIGGLE_ROOM < 1
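A minimal sketch of how these flags could be provided (normally the launcher sets them); the values below are examples only and must be set before this module is imported, since they are read at import time:

import os

# Example values only; ATTENTION must be one of: paged, flashdecoding, flashinfer.
os.environ["ATTENTION"] = "flashinfer"
os.environ["PREFIX_CACHING"] = "1"               # "1"/"true" enables prefix caching
os.environ.setdefault("PREFILL_CHUNKING", "0")   # chunked prefill off by default
os.environ.setdefault("TGI_WIGGLE_ROOM", "0.90")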
......
...@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch): ...@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
current_tokens=len(self),
) )
@classmethod @classmethod
......
...@@ -116,6 +116,7 @@ class MambaBatch(Batch): ...@@ -116,6 +116,7 @@ class MambaBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
current_tokens=len(self),
) )
@classmethod @classmethod
......
from io import BytesIO
from PIL import Image
import torch import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import ( from transformers import (
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import ( from text_generation_server.models.flash_causal_lm import (
...@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): ...@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1 max=config.text_config.vocab_size - 1
) )
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None: if image_inputs is not None:
...@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch): ...@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
class MllamaCausalLM(VlmCausalLM): class MllamaCausalLM(VlmCausalLM):
def forward( def forward(
self, self,
batch: VlmCausalLMBatch, batch: MllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None, adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
...@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM): ...@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
...@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM): ...@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
input_lengths = ( input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1) ).view(-1)
prefix_lens_tensor = ( cache_lengths_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1) ).reshape(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
...@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM): ...@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None: if cu_seqlen_prefill is None and self.max_past() is not None:
...@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM): ...@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph # Try to find an associated cuda graph
bs = input_ids.shape[0] bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
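The lookup above selects the smallest captured CUDA graph whose padded batch size can hold the live batch. A minimal sketch of that selection, assuming `cuda_graphs` is a dict keyed by the batch size each graph was captured for; names are illustrative.

```python
def find_padded_graph(cuda_graphs: dict, bs: int):
    # Captured graphs are keyed by their (padded) batch size; pick the
    # smallest one that is at least as large as the live batch, if any.
    candidates = sorted(k for k in cuda_graphs.keys() if k >= bs)
    if candidates:
        return cuda_graphs[candidates[0]]
    return None  # fall back to a regular (non-graph) forward pass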
...@@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM): ...@@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM):
# Only run cuda graphs when there's no images. # Only run cuda graphs when there's no images.
or batch.cross_attention_states is not None or batch.cross_attention_states is not None
): ):
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING: if PREFIX_CACHING:
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
with self._forward_context( with self._forward_context(
block_tables=block_tables, block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths, input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor, cache_lengths_tensor=cache_lengths_tensor,
): ):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen( seqlen = Seqlen(
input_lengths=input_lengths, input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor, cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill, cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s, max_q=batch.max_input_length,
max_k=max_k, max_k=batch.max_current_length,
) )
if batch.pixel_values is not None: if batch.pixel_values is not None:
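After the rename, the query-side and key-side maxima come from two different batch fields: `max_input_length` bounds the new (uncached) tokens per request, while `max_current_length` bounds cache plus new tokens. A small sketch of that relationship, assuming each request tracks a `cache_length` (tokens already in the KV cache) and an `input_length` (tokens fed this step); the helper name is hypothetical.

```python
def seqlen_bounds(cache_lengths, input_lengths):
    # max_q: longest chunk of new tokens processed this step (query length).
    # max_k: longest total sequence visible to attention (cache + new tokens).
    max_q = max(input_lengths)
    max_k = max(c + i for c, i in zip(cache_lengths, input_lengths))
    return max_q, max_k

# e.g. with prefill chunking, a request may have 512 tokens cached and feed
# 256 new ones: it contributes 256 to max_q but 768 to max_k.
```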
...@@ -330,22 +337,34 @@ class MllamaCausalLM(VlmCausalLM): ...@@ -330,22 +337,34 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = block_tables_to_ragged( block_tables = block_tables_to_ragged(
block_tables=block_tables, block_tables=block_tables,
input_lengths=batch.input_lengths, input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens, cache_lengths=batch.cache_lengths,
) )
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else: else:
cuda_graph["block_tables"][ cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1] : block_tables.shape[0], : block_tables.shape[1]
] = block_tables ] = block_tables
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0) cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_() cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = ( cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
input_lengths + prefix_lens_tensor cuda_graph["cache_lengths"].zero_()
) cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
# Replay the graph with self._forward_context(
cuda_graph["graph"].replay() block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape # Slice output to the correct shape
speculative_logits = ( speculative_logits = (
......
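For the CUDA graph path above, live inputs are copied into the static buffers the graph was captured with, padded with zeros, and the replay now runs inside the same forward context as the eager path so the attention metadata (block tables, input and cache lengths) is set up consistently. A compressed sketch of the copy-then-replay pattern; the `"input_ids"` and `"logits"` buffer keys are assumed for illustration, not taken from the diff.

```python
def replay_decode_graph(cuda_graph, input_ids, slots, input_lengths, cache_lengths):
    bs = input_ids.shape[0]
    # Zero-fill, then overwrite the leading slice of each static buffer so
    # padded (unused) entries hold harmless values.
    cuda_graph["input_ids"].zero_()
    cuda_graph["input_ids"][:bs] = input_ids
    cuda_graph["slots"].fill_(0)
    cuda_graph["slots"][: slots.shape[0]] = slots
    cuda_graph["input_lengths"].zero_()
    cuda_graph["input_lengths"][:bs] = input_lengths
    cuda_graph["cache_lengths"].zero_()
    cuda_graph["cache_lengths"][:bs] = cache_lengths
    # Replaying re-executes the captured kernels against the updated buffers.
    cuda_graph["graph"].replay()
    # Slice the padded output back to the live batch size ("logits" key assumed).
    return cuda_graph["logits"][:bs]
```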
...@@ -5,8 +5,17 @@ from abc import ABC, abstractmethod ...@@ -5,8 +5,17 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict from collections import defaultdict
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from loguru import logger
from text_generation_server.models.globals import (
ATTENTION,
PREFIX_CACHING,
BLOCK_SIZE,
PREFILL_CHUNKING,
)
from text_generation_server.models.types import Batch, Generation from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights from text_generation_server.adapters.weights import LayerAdapterWeights
...@@ -31,6 +40,7 @@ class Model(ABC): ...@@ -31,6 +40,7 @@ class Model(ABC):
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
speculate: Optional[int] = None, speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID, adapter_id: str = BASE_MODEL_ADAPTER_ID,
support_chunking: bool = False,
): ):
self.model_id = model_id self.model_id = model_id
self.model = model.eval() self.model = model.eval()
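Models that can handle partially prefilled batches opt in by passing the new flag to the base constructor; the default stays `False`, so existing models are unaffected. A hypothetical subclass, just to show the intended wiring (module path and the remaining constructor arguments are assumed, and abstract members are omitted):

```python
from text_generation_server.models.model import Model  # path assumed

class MyChunkedModel(Model):
    def __init__(self, model_id, model, tokenizer, **kwargs):
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            # Opt in: this model's prefill can be processed in chunks.
            support_chunking=True,
            **kwargs,
        )
```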
...@@ -60,6 +70,29 @@ class Model(ABC): ...@@ -60,6 +70,29 @@ class Model(ABC):
speculate = get_speculate() speculate = get_speculate()
self.speculate = speculate self.speculate = speculate
support_chunking = support_chunking and PREFILL_CHUNKING
if speculate != 0 and support_chunking:
log_master(
logger.warning,
"Prefill chunking does not support speculation yet. "
"Prefill chunking will be turned off",
)
support_chunking = False
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
log_master(
logger.warning,
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
)
support_chunking = False
log_master(
logger.info, f"Using experimental prefill chunking = {support_chunking}"
)
self.support_chunking = support_chunking
set_support_chunking(support_chunking)
self.has_position_ids = ( self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None) inspect.signature(model.forward).parameters.get("position_ids", None)
is not None is not None
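The resolved flag is both stored on the model and published through `set_support_chunking`, so batching code elsewhere can query it without holding a model reference. A minimal sketch of what such a helper module could look like, assuming it is a simple module-level flag; `get_support_chunking` is shown as the hypothetical reader side.

```python
# Assumed shape of text_generation_server/utils/prefill_chunking.py:
# a process-wide flag set once at model load time.
from typing import Optional

_SUPPORT_CHUNKING: Optional[bool] = None

def set_support_chunking(value: bool) -> None:
    global _SUPPORT_CHUNKING
    _SUPPORT_CHUNKING = value

def get_support_chunking() -> bool:
    # Default to False until a model has declared its capability.
    return bool(_SUPPORT_CHUNKING)
```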
...@@ -78,6 +111,10 @@ class Model(ABC): ...@@ -78,6 +111,10 @@ class Model(ABC):
device_type=self.device.type, device_type=self.device.type,
window_size=self.sliding_window, window_size=self.sliding_window,
speculate=self.speculate, speculate=self.speculate,
support_chunking=self.support_chunking,
use_prefix_caching=PREFIX_CACHING,
attention_impl=ATTENTION,
block_size=BLOCK_SIZE,
) )
@property @property
......
...@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch): ...@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
request_ids=[r.id for r in self.requests], request_ids=[r.id for r in self.requests],
size=len(self), size=len(self),
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
current_tokens=len(self.decoder_input_ids),
) )
@classmethod @classmethod
......