Unverified Commit a6a0c97e authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat: prefill chunking (#2600)



* wip

* rollback

* refactor to use prefix/postfix namming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot
of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------
Co-authored-by: default avatarNicolas Patry <patry.nicolas@protonmail.com>
parent 704a58c8
import pytest
import base64
import asyncio
......@@ -15,22 +14,8 @@ async def mllama(mllama_handle):
return mllama_handle.client
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,
......
......@@ -68,7 +68,7 @@ fn get_config(
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
......@@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
}
}
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching on paged attention");
prefix_caching = Some("0".to_string());
}
let attention = attention.unwrap_or("flashinfer".to_string());
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
......@@ -1678,7 +1682,7 @@ fn main() -> Result<(), LauncherError> {
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = {
......@@ -1729,12 +1733,6 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
......@@ -1788,12 +1786,6 @@ fn main() -> Result<(), LauncherError> {
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
......
......@@ -34,6 +34,10 @@ message InfoResponse {
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
bool support_chunking = 6;
bool use_prefix_caching = 7;
string attention_impl = 8;
uint32 block_size = 9;
}
/// Empty request
......@@ -135,10 +139,14 @@ message Request {
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Tokens that can be retrieved from the KV cache.
/// This value is set for the first prefill and never reset
uint32 cache_len = 12;
/// Context truncation
bool add_special_tokens = 13;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 14;
}
message Batch {
......@@ -163,6 +171,8 @@ message CachedBatch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
}
enum FinishReason {
......@@ -220,6 +230,8 @@ message FilterBatchResponse {
message PrefillRequest {
/// Batch
Batch batch = 1;
/// Optional cached batch
CachedBatch cached_batch = 2;
}
message PrefillResponse {
......@@ -233,6 +245,8 @@ message PrefillResponse {
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message DecodeRequest {
......
......@@ -18,45 +18,6 @@ use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
#[derive(PartialEq)]
pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value")
}
}
impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError),
}
}
}
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
......
......@@ -2,7 +2,7 @@ import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
......
......@@ -9,6 +9,9 @@ from typing import Callable, Any
class ExceptionInterceptor(AsyncServerInterceptor):
def __init__(self, shutdown_callback):
self.shutdown_callback = shutdown_callback
async def intercept(
self,
method: Callable,
......@@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
# Runtime Error cannot be recovered from
if isinstance(err, RuntimeError):
exit(1)
self.shutdown_callback()
if torch.cuda.is_available():
torch.cuda.empty_cache()
......
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max):
if SYSTEM == "rocm":
return self
self.input_lengths = torch.clamp(self.input_lengths, max=max)
return self
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(
self,
input_lengths,
cache_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
......@@ -123,7 +123,7 @@ def paged_attention(
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
from vllm._C import ops
out = torch.empty_like(query)
......@@ -244,117 +244,232 @@ if ATTENTION == "flashinfer":
window_left=window_size_left,
)
elif V2:
elif ATTENTION == "flashdecoding":
if V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
def attention(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
elif ATTENTION == "paged":
if V2:
def attention(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
None, # block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
else:
raise RuntimeError(f"Unknwon attention {ATTENTION}")
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
__all__ = [
"PREFILL_IN_KV_CACHE",
......
......@@ -699,7 +699,6 @@ def check_args(
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
......
......@@ -66,6 +66,7 @@ def paged_attention(
softcap: Optional[float] = None,
):
out = torch.empty_like(query)
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
......@@ -74,7 +75,7 @@ def paged_attention(
kv_head_mapping,
softmax_scale,
block_tables,
seqlen.input_lengths,
input_lengths,
BLOCK_SIZE,
max_s,
None,
......
......@@ -104,7 +104,7 @@ def paged_attention(
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = seqlen.input_lengths
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
out = torch.empty_like(query)
......
......@@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self.input_ids),
)
@classmethod
......
......@@ -493,9 +493,14 @@ class MllamaVisionModel(nn.Module):
aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
pixel_values.shape
)
(
batch_size,
num_concurrent_media,
num_tiles,
num_channels,
height,
width,
) = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
......
......@@ -5,9 +5,14 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
"1",
"true",
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
ATTENTION in _expected
......@@ -18,7 +23,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
......
......@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod
......
......@@ -116,6 +116,7 @@ class MambaBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod
......
from io import BytesIO
from PIL import Image
import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
PreTrainedTokenizerBase,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
......@@ -167,6 +170,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1
)
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None:
......@@ -190,7 +200,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
class MllamaCausalLM(VlmCausalLM):
def forward(
self,
batch: VlmCausalLMBatch,
batch: MllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
......@@ -202,7 +212,7 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
......@@ -221,8 +231,8 @@ class MllamaCausalLM(VlmCausalLM):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Add Copy the block tables for all members
......@@ -244,8 +254,8 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor
max_s = batch.max_seqlen
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
......@@ -254,7 +264,6 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
......@@ -269,26 +278,24 @@ class MllamaCausalLM(VlmCausalLM):
# Only run cuda graphs when there's no images.
or batch.cross_attention_states is not None
):
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
if batch.pixel_values is not None:
......@@ -330,22 +337,34 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
# Replay the graph
cuda_graph["graph"].replay()
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
......
......@@ -5,8 +5,17 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger
from text_generation_server.models.globals import (
ATTENTION,
PREFIX_CACHING,
BLOCK_SIZE,
PREFILL_CHUNKING,
)
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
......@@ -31,6 +40,7 @@ class Model(ABC):
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID,
support_chunking: bool = False,
):
self.model_id = model_id
self.model = model.eval()
......@@ -60,6 +70,29 @@ class Model(ABC):
speculate = get_speculate()
self.speculate = speculate
support_chunking = support_chunking and PREFILL_CHUNKING
if speculate != 0 and support_chunking:
log_master(
logger.warning,
"Prefill chunking does not support speculation yet. "
"Prefill chunking will be turned off",
)
support_chunking = False
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
log_master(
logger.warning,
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
)
support_chunking = False
log_master(
logger.info, f"Using experimental prefill chunking = {support_chunking}"
)
self.support_chunking = support_chunking
set_support_chunking(support_chunking)
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
......@@ -78,6 +111,10 @@ class Model(ABC):
device_type=self.device.type,
window_size=self.sliding_window,
speculate=self.speculate,
support_chunking=self.support_chunking,
use_prefix_caching=PREFIX_CACHING,
attention_impl=ATTENTION,
block_size=BLOCK_SIZE,
)
@property
......
......@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self.decoder_input_ids),
)
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment