"csrc/cpu/graclus.cpp" did not exist on "0a559a4a892809997ca9804c923530bb48d4452d"
Unverified commit a6a0c97e authored by OlivierDehaene, committed by GitHub

feat: prefill chunking (#2600)



* wip

* rollback

* refactor to use prefix/postfix naming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot of the times).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 704a58c8
......@@ -74,6 +74,14 @@ class Tokens:
def __len__(self):
return len(self.token_ids)
def __add__(self, other: "Tokens") -> "Tokens":
return Tokens(
self.token_ids + other.token_ids,
self.logprobs + other.logprobs,
self.texts + other.texts,
self.is_special + other.is_special,
)
@dataclass
class Generation:
......
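With prefill chunking, the tokens for a single request can be produced in several pieces, which is why Tokens becomes additive. Below is a minimal sketch of how chunk-wise results could be stitched together; the dataclass here is a stand-in that mirrors the four fields used by __add__ above, not the real class from text_generation_server.

from dataclasses import dataclass
from typing import List

@dataclass
class Tokens:
    # Stand-in mirroring the fields combined by __add__ above.
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __add__(self, other: "Tokens") -> "Tokens":
        return Tokens(
            self.token_ids + other.token_ids,
            self.logprobs + other.logprobs,
            self.texts + other.texts,
            self.is_special + other.is_special,
        )

# Accumulate prefill tokens chunk by chunk (illustrative values only).
chunk_a = Tokens([1, 2], [-0.1, -0.2], ["Hello", " world"], [False, False])
chunk_b = Tokens([3], [-0.3], ["!"], [False])
merged = chunk_a + chunk_b
assert merged.token_ids == [1, 2, 3]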
......@@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
**kwargs,
)
......@@ -295,7 +297,7 @@ class VlmCausalLM(FlashCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
......@@ -314,8 +316,8 @@ class VlmCausalLM(FlashCausalLM):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Add Copy the block tables for all members
......@@ -337,8 +339,8 @@ class VlmCausalLM(FlashCausalLM):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor
max_s = batch.max_seqlen
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
......@@ -347,7 +349,6 @@ class VlmCausalLM(FlashCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
......@@ -357,26 +358,24 @@ class VlmCausalLM(FlashCausalLM):
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
......@@ -411,22 +410,34 @@ class VlmCausalLM(FlashCausalLM):
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
cache_lengths=batch.cache_lengths,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
# XXX: This is working only because block 0 is reserved for the healthcheck
# so it doesn't matter if we override it with bogus values.
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
# Replay the graph
cuda_graph["graph"].replay()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
with self._forward_context(
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths_tensor=cuda_graph["input_lengths"],
cache_lengths_tensor=cuda_graph["cache_lengths"],
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
......
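The renames above follow the convention this change introduces: per request, cache_length is what already sits in the KV cache, input_length is the new tokens fed in this step, and current_length is their sum (hence max_q=batch.max_input_length and max_k=batch.max_current_length in Seqlen). A rough sketch of that bookkeeping for a chunked prefill, with made-up numbers and simplified names:

# Illustrative bookkeeping only; names follow the convention above,
# not the exact batch implementation.
prompt_len = 30
chunk_size = 12

cache_length = 0
while cache_length < prompt_len:
    input_length = min(chunk_size, prompt_len - cache_length)  # tokens prefilled this step
    current_length = cache_length + input_length               # tokens in the KV cache afterwards
    # Per-batch maxima of these two quantities are what Seqlen receives
    # as max_q and max_k in the diff above.
    cache_length = current_length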
......@@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
......@@ -46,9 +47,12 @@ class SignalHandler:
signal.signal(signal.SIGINT, self.exit_gracefully)
signal.signal(signal.SIGTERM, self.exit_gracefully)
def set_keep_processing(self, value: bool):
self.KEEP_PROCESSING = value
def exit_gracefully(self, signum, frame):
print(f"Exiting gracefully: Signal {signum}")
self.KEEP_PROCESSING = False
self.set_keep_processing(False)
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
......@@ -96,6 +100,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
set_max_prefill_tokens(request.max_prefill_tokens)
if self.quantize in {"exl2", "gptq"}:
try:
# When using GPTQ, Exllama kernels need some global kernels
......@@ -150,6 +156,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
concat_ns = None
if self.model.support_chunking:
if request.HasField("cached_batch"):
cached_batch = self.cache.pop(request.cached_batch.id)
if cached_batch is None:
raise ValueError(
f"Batch ID {request.cached_batch.id} not found in cache."
)
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate([cached_batch, batch])
concat_ns = time.time_ns() - start_concat
generations, next_batch, timings = self.model.generate_token(batch)
self.cache.set(next_batch)
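As far as this diff shows, the intent is that a prompt can be prefilled over several Prefill calls: the first call carries only a new batch, and each follow-up call passes the previous partial batch back as cached_batch so it is popped from the cache and concatenated with the next chunk before generate_token runs. A rough caller-side sketch of that loop; the request helper and shapes here are simplified stand-ins, not the actual protobuf API.

# Hypothetical driver, for illustration only.
def prefill_in_chunks(stub, chunks):
    cached_batch = None
    for batch_chunk in chunks:
        # make_prefill_request is a stand-in helper, not part of the codebase.
        request = make_prefill_request(batch_chunk, cached_batch=cached_batch)
        response = stub.Prefill(request)
        cached_batch = response.batch  # handed back on the next call for concatenation
    return cached_batch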
......@@ -159,6 +177,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
concat_ns=concat_ns,
)
async def Decode(self, request, context):
......@@ -252,10 +271,12 @@ def serve(
logger.exception("Error when initializing model")
raise
signal_handler = SignalHandler()
set_adapter_to_index(adapter_to_index)
server = aio.server(
interceptors=[
ExceptionInterceptor(),
ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
UDSOpenTelemetryAioServerInterceptor(),
],
options=[
......@@ -276,7 +297,6 @@ def serve(
await server.start()
logger.info("Server started at {}".format(local_url))
signal_handler = SignalHandler()
while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5)
......
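The interceptor now receives a callback so that a fatal error in the gRPC service can flip KEEP_PROCESSING and let the while signal_handler.KEEP_PROCESSING loop above exit. The pattern, stripped of the gRPC specifics (this is not the actual ExceptionInterceptor implementation):

async def run_with_exit_callback(coro, on_fatal_error):
    # Run a handler; if it raises, notify the callback so the serve loop can stop.
    try:
        return await coro
    except Exception:
        on_fatal_error()  # e.g. lambda: signal_handler.set_keep_processing(False)
        raise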
......@@ -120,15 +120,18 @@ def _load_and_merge(
if adapter.id == BASE_MODEL_ADAPTER_ID:
raise ValueError("Base model adapter cannot be merged.")
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
load_module_map(
model_id,
adapter.revision,
adapter.id,
adapter.path,
weight_names,
trust_remote_code,
)
(
module_map,
adapter_config,
adapter_weight_names,
adapter_tokenizer,
) = load_module_map(
model_id,
adapter.revision,
adapter.id,
adapter.path,
weight_names,
trust_remote_code,
)
adapters_to_merge.append((module_map, adapter_config))
......
from typing import Optional
SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None
def set_support_chunking(support_chunking: bool):
global SUPPORT_CHUNKING
SUPPORT_CHUNKING = support_chunking
def get_support_chunking() -> bool:
global SUPPORT_CHUNKING
return SUPPORT_CHUNKING
def set_max_prefill_tokens(max_prefill_tokens: int):
global MAX_PREFILL_TOKENS
MAX_PREFILL_TOKENS = max_prefill_tokens
def get_max_prefill_tokens() -> int:
global MAX_PREFILL_TOKENS
return MAX_PREFILL_TOKENS
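These module-level globals are written once at startup: set_support_chunking when the model is constructed and set_max_prefill_tokens during Warmup (see above), then read wherever the prefill budget is needed. A small usage sketch; the chunk-splitting helper is illustrative, not the actual scheduler.

from text_generation_server.utils.prefill_chunking import (
    set_max_prefill_tokens,
    get_max_prefill_tokens,
)

set_max_prefill_tokens(4096)  # normally done in Warmup from request.max_prefill_tokens

def split_into_chunks(prompt_length: int):
    # Hypothetical helper: cut a prompt into prefill chunks no larger than the budget.
    budget = get_max_prefill_tokens()
    return [min(budget, prompt_length - start) for start in range(0, prompt_length, budget)]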
......@@ -7,6 +7,7 @@ from typing import List, Tuple, Union
import torch
# FIXME: this should be optimized
def find_segments(
adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
......
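For context on the FIXME: find_segments collapses a per-token adapter index sequence into contiguous segments. A straightforward, unoptimized reference of what it computes, assuming it returns the segment boundary offsets plus the adapter index of each segment (the name below is chosen to avoid clashing with the real function):

from typing import List, Tuple, Union
import torch

def find_segments_reference(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    # e.g. [0, 0, 1, 1, 1, 0] -> boundaries [0, 2, 5, 6], segment indices [0, 1, 0]
    if isinstance(adapter_indices, torch.Tensor):
        adapter_indices = adapter_indices.tolist()
    segments, segment_indices = [0], []
    for i, idx in enumerate(adapter_indices):
        if i == 0 or idx != adapter_indices[i - 1]:
            if i != 0:
                segments.append(i)
            segment_indices.append(idx)
    segments.append(len(adapter_indices))
    return segments, segment_indices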