Unverified Commit a6a0c97e authored by OlivierDehaene, committed by GitHub

feat: prefill chunking (#2600)



* wip

* rollback

* refactor to use prefix/postfix naming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot of the time).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 704a58c8
 import pytest
-import base64
 import asyncio
@@ -15,22 +14,8 @@ async def mllama(mllama_handle):
     return mllama_handle.client


-# TODO fix the server parsser to count inline image tokens correctly
-def get_chicken():
-    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
-def get_cow_beach():
-    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read())
-    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
-
-
 @pytest.mark.asyncio
 async def test_mllama_simpl(mllama, response_snapshot):
-    # chicken = get_chicken()
     response = await mllama.chat(
         max_tokens=10,
         temperature=0.0,
...
@@ -68,7 +68,7 @@ fn get_config(
 fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
     let compute_capability = gpu::get_cuda_capability();
-    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
+    let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
         if prefix_caching.is_none() {
@@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
             }
         }
     }
+    if attention == Some("paged".to_string()) && prefix_caching.is_none() {
+        tracing::info!("Disabling prefix caching on paged attention");
+        prefix_caching = Some("0".to_string());
+    }
     let attention = attention.unwrap_or("flashinfer".to_string());
     let prefix_caching = prefix_caching.unwrap_or("true".to_string());
@@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> {
     };
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
-    std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
+    std::env::set_var("PREFIX_CACHING", prefix_caching);
     std::env::set_var("ATTENTION", attention);
     let max_input_tokens = {
@@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> {
             "`max_input_tokens must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_tokens as u32 > max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
-            max_batch_prefill_tokens, max_input_tokens
-        )));
-    }
     if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
         tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> {
     }
     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if max_batch_prefill_tokens > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_batch_prefill_tokens, max_batch_total_tokens
-            )));
-        }
         if max_total_tokens as u32 > *max_batch_total_tokens {
             return Err(LauncherError::ArgumentValidation(format!(
                 "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
...
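The net effect of the launcher change above: explicitly requesting ATTENTION=paged now implies prefix caching off unless the user forced PREFIX_CACHING themselves, and the `max_batch_prefill_tokens >= max_input_tokens` validation is dropped because a long prefill can now be split into chunks. A rough Python restatement of the resolution rule (names are ad hoc, and the config-based heuristics of the real resolve_attention are omitted):

from typing import Optional, Tuple


def resolve_attention(
    attention_env: Optional[str], prefix_caching_env: Optional[str]
) -> Tuple[str, str]:
    # Simple resolution when the user specifies ATTENTION=paged:
    # disable prefix caching unless PREFIX_CACHING was set explicitly.
    if attention_env == "paged" and prefix_caching_env is None:
        prefix_caching_env = "0"
    attention = attention_env or "flashinfer"
    prefix_caching = prefix_caching_env or "true"
    return prefix_caching, attention


assert resolve_attention("paged", None) == ("0", "paged")
assert resolve_attention("paged", "1") == ("1", "paged")
assert resolve_attention(None, None) == ("true", "flashinfer")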
@@ -34,6 +34,10 @@ message InfoResponse {
     string device_type = 3;
     optional uint32 window_size = 4;
     uint32 speculate = 5;
+    bool support_chunking = 6;
+    bool use_prefix_caching = 7;
+    string attention_impl = 8;
+    uint32 block_size = 9;
 }

 /// Empty request
@@ -135,10 +139,14 @@ message Request {
     repeated uint32 slots = 10;
     /// LORA adapter index
     optional string adapter_id = 11;
-    /// Prefix length that can be retrieved from the KV cache.
-    uint32 prefix_len = 12;
+    /// Tokens that can be retrieved from the KV cache.
+    /// This value is set for the first prefill and never reset
+    uint32 cache_len = 12;
     /// Context truncation
     bool add_special_tokens = 13;
+    /// Chunk of tokens that must be computed for the first prefill
+    /// This value is set for the first prefill and never reset
+    optional uint32 chunk_len = 14;
 }

 message Batch {
@@ -163,6 +171,8 @@ message CachedBatch {
     uint32 size = 3;
     /// Maximum number of tokens this batch will grow to
     uint32 max_tokens = 4;
+    /// Number of tokens in the next forward
+    uint32 current_tokens = 5;
 }

 enum FinishReason {
@@ -220,6 +230,8 @@ message FilterBatchResponse {
 message PrefillRequest {
     /// Batch
     Batch batch = 1;
+    /// Optional cached batch
+    CachedBatch cached_batch = 2;
 }

 message PrefillResponse {
@@ -233,6 +245,8 @@ message PrefillResponse {
     uint64 decode_ns = 4;
     /// Total elapsed time in nanoseconds
     uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
 }

 message DecodeRequest {
...
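A minimal sketch of the new wire fields using the generated Python stubs (text_generation_server.pb.generate_pb2). The values are illustrative, and only fields visible in this hunk are set; real requests carry many more fields, and the router's actual scheduling logic is not shown here.

from text_generation_server.pb import generate_pb2

# A request whose prompt is partially cached: 512 tokens can be reused from
# the KV cache, and 256 new tokens are scheduled for this (first) prefill.
request = generate_pb2.Request(
    cache_len=512,
    chunk_len=256,
    add_special_tokens=True,
)

# The server reports how many tokens a batch will run in its next forward.
cached_batch = generate_pb2.CachedBatch(
    request_ids=[0],
    size=1,
    max_tokens=2048,
    current_tokens=256,
)

# Subsequent prefill calls can now reference an already cached batch.
prefill = generate_pb2.PrefillRequest(cached_batch=cached_batch)
print(prefill.cached_batch.current_tokens)  # 256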
@@ -18,45 +18,6 @@ use tracing::warn;
 use utoipa::ToSchema;
 use validation::Validation;

-#[derive(PartialEq)]
-pub enum Attention {
-    Paged,
-    FlashDecoding,
-    FlashInfer,
-}
-
-impl Attention {
-    pub fn block_size(&self) -> u32 {
-        match self {
-            Attention::FlashDecoding => 256,
-            Attention::FlashInfer => 1,
-            Attention::Paged => 16,
-        }
-    }
-}
-
-#[derive(Debug)]
-pub struct ParseError;
-
-impl std::fmt::Display for ParseError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "Cannot parse attention value")
-    }
-}
-
-impl std::error::Error for ParseError {}
-
-impl std::str::FromStr for Attention {
-    type Err = ParseError;
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "paged" => Ok(Attention::Paged),
-            "flashdecoding" => Ok(Attention::FlashDecoding),
-            "flashinfer" => Ok(Attention::FlashInfer),
-            _ => Err(ParseError),
-        }
-    }
-}
-
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
 pub struct HubModelInfo {
...
@@ -2,7 +2,7 @@ import pytest
 import os

 from text_generation_server.pb import generate_pb2

-os.environ["USE_PREFIX_CACHING"] = "1"
+os.environ["PREFIX_CACHING"] = "1"
 os.environ["ATTENTION"] = "flashinfer"
...
@@ -9,6 +9,9 @@ from typing import Callable, Any

 class ExceptionInterceptor(AsyncServerInterceptor):
+    def __init__(self, shutdown_callback):
+        self.shutdown_callback = shutdown_callback
+
     async def intercept(
         self,
         method: Callable,
@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):

             # Runtime Error cannot be recovered from
             if isinstance(err, RuntimeError):
-                exit(1)
+                self.shutdown_callback()

             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
...
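A minimal sketch of how the new shutdown_callback might be wired when building the gRPC server, assuming the interceptor module path from this diff; the actual wiring in server.py may differ (for instance it could signal an event rather than stopping the server directly):

import asyncio

from grpc import aio

from text_generation_server.interceptor import ExceptionInterceptor


async def serve() -> None:
    # The callback replaces the old hard exit(1): on an unrecoverable
    # RuntimeError the interceptor now asks the server to shut down cleanly.
    server = aio.server(
        interceptors=[
            ExceptionInterceptor(
                lambda: asyncio.get_running_loop().create_task(server.stop(grace=None))
            )
        ]
    )
    server.add_insecure_port("unix:///tmp/tgi-test.sock")
    await server.start()
    await server.wait_for_termination()


if __name__ == "__main__":
    asyncio.run(serve())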
 from dataclasses import dataclass
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 import torch
 from typing import Optional

-if ATTENTION in {"flashinfer", "flashdecoding"}:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
-
-        def __init__(
-            self,
-            input_lengths,
-            prefix_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
-
-else:
-
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            self.input_lengths = torch.clamp(self.input_lengths, max=max)
-            return self
+
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cache_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+    max_q: int
+    max_k: int
+
+    def __init__(
+        self,
+        input_lengths,
+        cache_lengths,
+        cu_seqlen_q=None,
+        max_q=None,
+        max_k=None,
+    ):
+        self.input_lengths = input_lengths
+        self.cache_lengths = cache_lengths
+        device = self.input_lengths.device
+        shape = self.input_lengths.shape
+        if cu_seqlen_q is None:
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            max_q = 1
+        else:
+            assert max_q is not None
+            assert max_k is not None
+        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+
+        # cuda graphs don't like this and this is necessary to clamp within mistral
+        # Although FA2 might not want the clamping
+        # cu_seqlen_k[0] = 0
+        total = self.input_lengths + self.cache_lengths
+        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
+
+        self.cu_seqlen_q = cu_seqlen_q
+        self.cu_seqlen_k = cu_seqlen_k
+        self.max_q = max_q
+        self.max_k = max_k
+
+    def clamp(self, max):
+        # Flash decoding doesn't need to clamp
+        return self
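To see what the unified Seqlen above computes during a chunked prefill, here is a standalone sketch in plain PyTorch (no TGI imports). The numbers are made up; the real class also covers the decode case, where cu_seqlen_q is built with torch.arange instead of being derived from the chunk sizes.

import torch

# Hypothetical batch of 3 sequences in a chunked prefill:
# tokens already in the KV cache vs. new tokens computed in this chunk.
cache_lengths = torch.tensor([0, 512, 1024], dtype=torch.int32)
input_lengths = torch.tensor([256, 256, 128], dtype=torch.int32)

# Keys span cached + new tokens, mirroring Seqlen.__init__ above.
total = input_lengths + cache_lengths
cu_seqlen_k = torch.zeros(total.shape[0] + 1, dtype=torch.int32)
torch.cumsum(total, -1, out=cu_seqlen_k[1:])

# Queries only span the new tokens of this chunk.
cu_seqlen_q = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_q[1:])

print(cu_seqlen_q.tolist())  # [0, 256, 512, 640]
print(cu_seqlen_k.tolist())  # [0, 256, 1024, 2176]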
@@ -123,7 +123,7 @@ def paged_attention(
     else:
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
-        input_lengths = seqlen.input_lengths
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
         from vllm._C import ops

         out = torch.empty_like(query)
@@ -244,117 +244,232 @@ if ATTENTION == "flashinfer":
         window_left=window_size_left,
     )

-elif V2:
-
-    def attention(
-        q,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale,
-        window_size_left=-1,
-        causal=True,
-        softcap=0.0,
-    ):
-        out = torch.empty_like(q)
-        if window_size_left <= 0 and window_size_left != -1:
-            raise ValueError("`window_size_left` must be > 0 or -1")
-        return flash_attn_2_cuda.varlen_fwd(
-            q,
-            key_cache,
-            value_cache,
-            out,
-            seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_k,
-            None,
-            None,
-            block_tables,
-            None,
-            seqlen.max_q,
-            seqlen.max_k,
-            0.0,
-            softmax_scale,
-            False,
-            causal,
-            window_size_left,
-            0,
-            softcap,
-            False,
-            None,
-        )[0]
+elif ATTENTION == "flashdecoding":
+    if V2:
+
+        def attention(
+            q,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale,
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]

-else:
-
-    def attention(
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale: float,
-        window_size_left: int = -1,
-        causal: bool = True,
-        softcap=None,
-    ):
-        if window_size_left != -1:
-            raise NotImplementedError(
-                "window_size_left is only available with flash attn v2"
-            )
-        if softcap is not None:
-            raise NotImplementedError("softcap is only available with flash attn v2")
-
-        # Flash attention v1 requires q, k and v to have the same number of heads
-        if k.shape[1] != q.shape[1]:
-            # MQA expand
-            if k.shape[1] == 1:
-                k = k.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = k.shape
-                k = (
-                    k.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-        if v.shape[1] != q.shape[1]:
-            # MQA expand
-            if v.shape[1] == 1:
-                v = v.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = v.shape
-                v = (
-                    v.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-        out = torch.empty_like(q)
-        flash_attn_cuda.fwd(
-            q,
-            k,
-            v,
-            out,
-            seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_q,
-            seqlen.max_q,
-            seqlen.max_k,
-            0.0,
-            softmax_scale,
-            False,
-            causal,
-            False,
-            0,
-            None,
-        )
-        return out
+    else:
+
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
+            )
+            return out
+
+elif ATTENTION == "paged":
+    if V2:
+
+        def attention(
+            q,
+            key_cache: torch.Tensor,
+            value_cache: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale,
+            window_size_left=-1,
+            causal=True,
+            softcap=0.0,
+        ):
+            out = torch.empty_like(q)
+            if window_size_left <= 0 and window_size_left != -1:
+                raise ValueError("`window_size_left` must be > 0 or -1")
+            return flash_attn_2_cuda.varlen_fwd(
+                q,
+                key_cache,
+                value_cache,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_k,
+                None,
+                None,
+                None,  # block_tables,
+                None,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                window_size_left,
+                0,
+                softcap,
+                False,
+                None,
+            )[0]
+
+    else:
+
+        def attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            seqlen: Seqlen,
+            block_tables: torch.Tensor,
+            softmax_scale: float,
+            window_size_left: int = -1,
+            causal: bool = True,
+            softcap=None,
+        ):
+            if window_size_left != -1:
+                raise NotImplementedError(
+                    "window_size_left is only available with flash attn v2"
+                )
+            if softcap is not None:
+                raise NotImplementedError(
+                    "softcap is only available with flash attn v2"
+                )
+
+            # Flash attention v1 requires q, k and v to have the same number of heads
+            if k.shape[1] != q.shape[1]:
+                # MQA expand
+                if k.shape[1] == 1:
+                    k = k.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = k.shape
+                    k = (
+                        k.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            if v.shape[1] != q.shape[1]:
+                # MQA expand
+                if v.shape[1] == 1:
+                    v = v.expand(-1, q.shape[1], -1)
+                # Grouped attention reshape
+                else:
+                    original_shape = v.shape
+                    v = (
+                        v.unsqueeze(2)
+                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                        .reshape(original_shape[0], -1, original_shape[2])
+                    )
+            out = torch.empty_like(q)
+            flash_attn_cuda.fwd(
+                q,
+                k,
+                v,
+                out,
+                seqlen.cu_seqlen_q,
+                seqlen.cu_seqlen_q,
+                seqlen.max_q,
+                seqlen.max_k,
+                0.0,
+                softmax_scale,
+                False,
+                causal,
+                False,
+                0,
+                None,
+            )
+            return out
+
+else:
+    raise RuntimeError(f"Unknwon attention {ATTENTION}")

 # Prefill in the cache with every kind of attention, unless we
 # have a configuration that requires flash-attention v1, which
 # does not support block tables.
-PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
+PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)

 __all__ = [
     "PREFILL_IN_KV_CACHE",
...
@@ -699,7 +699,6 @@ def check_args(

 class _attention(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
...
@@ -66,6 +66,7 @@ def paged_attention(
     softcap: Optional[float] = None,
 ):
     out = torch.empty_like(query)
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
@@ -74,7 +75,7 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        seqlen.input_lengths,
+        input_lengths,
         BLOCK_SIZE,
         max_s,
         None,
...
@@ -104,7 +104,7 @@ def paged_attention(
         _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = seqlen.input_lengths
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths

     out = torch.empty_like(query)
...
@@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
+            current_tokens=len(self.input_ids),
         )

     @classmethod
...
@@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
         aspect_ratio_ids: torch.Tensor,
         attention_mask: torch.Tensor,
     ) -> torch.Tensor:
-        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
-            pixel_values.shape
-        )
+        (
+            batch_size,
+            num_concurrent_media,
+            num_tiles,
+            num_channels,
+            height,
+            width,
+        ) = pixel_values.shape
         pixel_values = pixel_values.reshape(
             batch_size * num_concurrent_media * num_tiles, num_channels, height, width
...
@@ -5,9 +5,14 @@ from typing import Dict, Optional
 from text_generation_server.utils.log import log_master

-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
+ATTENTION = os.environ["ATTENTION"]
+# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
+PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
+    "1",
+    "true",
+}
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-ATTENTION = os.getenv("ATTENTION")
 _expected = {"paged", "flashdecoding", "flashinfer"}
 assert (
     ATTENTION in _expected
@@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1
...
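Because globals.py now reads ATTENTION and PREFIX_CACHING via os.environ[...] with no defaults, both must be set before the module is imported, exactly as the server test conftest above does. A hedged example of the environment a process would need; the values are illustrative:

import os

# Required: globals.py indexes os.environ directly, so a missing value raises KeyError.
os.environ["ATTENTION"] = "flashinfer"  # one of: paged, flashdecoding, flashinfer
os.environ["PREFIX_CACHING"] = "1"      # renamed from USE_PREFIX_CACHING in this change
# Optional: prefill chunking stays off unless explicitly enabled.
os.environ.setdefault("PREFILL_CHUNKING", "1")

from text_generation_server.models import globals as tgi_globals  # noqa: E402

print(tgi_globals.ATTENTION, tgi_globals.PREFIX_CACHING, tgi_globals.PREFILL_CHUNKING)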
@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
+            current_tokens=len(self),
         )

     @classmethod
...
@@ -116,6 +116,7 @@ class MambaBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
+            current_tokens=len(self),
         )

     @classmethod
...
+from io import BytesIO
+from PIL import Image
 import torch
+import numpy as np
 from typing import Iterable, Optional, Tuple, List, Dict
 from text_generation_server.pb.generate_pb2 import Request
-from io import BytesIO
-from PIL import Image
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
     PreTrainedTokenizerBase,
 )
 from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.flash_causal_lm import (
@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
         batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
             max=config.text_config.vocab_size - 1
         )
+        if isinstance(batch.input_ids, list):
+            if len(batch) > 1:
+                input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
+            else:
+                input_ids = batch.input_ids[0]
+            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
         batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)

         if image_inputs is not None:
@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
 class MllamaCausalLM(VlmCausalLM):
     def forward(
         self,
-        batch: VlmCausalLMBatch,
+        batch: MllamaCausalLMBatch,
         adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

             speculative_ids = batch.speculative_ids
@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Add Copy the block tables for all members
@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
-            max_s = batch.max_seqlen
+            cache_lengths_tensor = batch.cache_lengths_tensor
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

         if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM):
             # Only run cuda graphs when there's no images.
             or batch.cross_attention_states is not None
         ):
-            input_lengths = input_lengths + prefix_lens_tensor
             if PREFIX_CACHING:
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    cache_lengths=batch.cache_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
                 input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                cache_lengths_tensor=cache_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
-                    prefix_lengths=prefix_lens_tensor,
+                    cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=max_s,
-                    max_k=max_k,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
                 )

                 if batch.pixel_values is not None:
@@ -330,22 +337,34 @@ class MllamaCausalLM(VlmCausalLM):
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                cache_lengths=batch.cache_lengths,
             )
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else:
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables

+        # XXX: This is working only because block 0 is reserved for the healthcheck
+        # so it doesn't matter if we override it with bogus values.
         cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["cache_lengths"].zero_()
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor

-        # Replay the graph
-        cuda_graph["graph"].replay()
+        with self._forward_context(
+            block_tables=cuda_graph["block_tables"],
+            cu_seqlen_prefill=None,
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            cache_lengths_tensor=cuda_graph["cache_lengths"],
+            state=cuda_graph["state"],
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()

         # Slice output to the correct shape
         speculative_logits = (
...
@@ -5,8 +5,17 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type, Dict
 from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
+from loguru import logger
+
+from text_generation_server.models.globals import (
+    ATTENTION,
+    PREFIX_CACHING,
+    BLOCK_SIZE,
+    PREFILL_CHUNKING,
+)
 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.log import log_master
+from text_generation_server.utils.prefill_chunking import set_support_chunking
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
@@ -31,6 +40,7 @@ class Model(ABC):
         sliding_window: Optional[int] = None,
         speculate: Optional[int] = None,
         adapter_id: str = BASE_MODEL_ADAPTER_ID,
+        support_chunking: bool = False,
     ):
         self.model_id = model_id
         self.model = model.eval()
@@ -60,6 +70,29 @@ class Model(ABC):
             speculate = get_speculate()
         self.speculate = speculate

+        support_chunking = support_chunking and PREFILL_CHUNKING
+
+        if speculate != 0 and support_chunking:
+            log_master(
+                logger.warning,
+                "Prefill chunking does not support speculation yet. "
+                "Prefill chunking will be turned off",
+            )
+            support_chunking = False
+        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+            log_master(
+                logger.warning,
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+            )
+            support_chunking = False
+
+        log_master(
+            logger.info, f"Using experimental prefill chunking = {support_chunking}"
+        )
+
+        self.support_chunking = support_chunking
+        set_support_chunking(support_chunking)
+
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
@@ -78,6 +111,10 @@ class Model(ABC):
             device_type=self.device.type,
             window_size=self.sliding_window,
             speculate=self.speculate,
+            support_chunking=self.support_chunking,
+            use_prefix_caching=PREFIX_CACHING,
+            attention_impl=ATTENTION,
+            block_size=BLOCK_SIZE,
         )

     @property
...
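The chunking gate above boils down to a small decision rule. The following condensed restatement (an illustrative helper, not part of the codebase) makes the precedence explicit and doubles as a usage example:

def resolve_support_chunking(
    support_chunking: bool,
    prefill_chunking_env: bool,
    speculate: int,
    attention: str,
) -> bool:
    """Condensed restatement of the gating in Model.__init__ above."""
    support_chunking = support_chunking and prefill_chunking_env
    if speculate != 0 and support_chunking:
        # Speculative decoding and prefill chunking are mutually exclusive for now.
        support_chunking = False
    if attention not in ("flashinfer", "flashdecoding") and support_chunking:
        # Chunking needs an attention backend that can mix cached and new tokens.
        support_chunking = False
    return support_chunking


assert resolve_support_chunking(True, True, 0, "flashinfer") is True
assert resolve_support_chunking(True, True, 2, "flashinfer") is False
assert resolve_support_chunking(True, True, 0, "paged") is False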
@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
+            current_tokens=len(self.decoder_input_ids),
         )

     @classmethod
...