Unverified commit a6a0c97e, authored by OlivierDehaene and committed by GitHub

feat: prefill chunking (#2600)



* wip

* rollback

* refactor to use prefix/postfix naming + fix all_input_ids_tensor

* maybe patching vlms?

* fix filter and concat

* wip, no filter, no concat

* current

* add prepare_for_prefill

* working

* load tested

* re-create slots

* re-create slots

* fix slot_filtering_indices

* feedback loop

* remove log

* fix benchmarker

* fix vlm and seq2seq

* rename to cache and input lengths

* fix prefill logprobs

* fix launcher

* fix logprobs?

* idk at this point

* max input length

* omfg

* remove debugging lines

* fix tests

* fix mllama

* fix cargo tests

* remove support chunking for paged

* Fixing non blocked attentions

* Fixing dtype + AMD, Ipex targets.

* lint fix.

* rename

* Fix prefix_caching variable, remove defaults in server (confusing a lot of the time).

* Add simple resolution when user specifies ATTENTION=paged.

* Put back non default simple tests.

* Fix env name

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 704a58c8
@@ -74,6 +74,14 @@ class Tokens:
     def __len__(self):
         return len(self.token_ids)

+    def __add__(self, other: "Tokens") -> "Tokens":
+        return Tokens(
+            self.token_ids + other.token_ids,
+            self.logprobs + other.logprobs,
+            self.texts + other.texts,
+            self.is_special + other.is_special,
+        )
+

 @dataclass
 class Generation:
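The new `Tokens.__add__` merges two token batches field by field, which is what allows partial results produced by successive prefill chunks to be stitched back into a single stream (e.g. for prefill logprobs). A self-contained sketch of the behaviour; the `Tokens` stand-in below mirrors only the fields used in the hunk above and is not the real TGI import:

from dataclasses import dataclass
from typing import List


@dataclass
class Tokens:
    # Stand-in mirroring the fields used by __add__ in the diff above.
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    is_special: List[bool]

    def __add__(self, other: "Tokens") -> "Tokens":
        # Field-wise list concatenation, as in the diff.
        return Tokens(
            self.token_ids + other.token_ids,
            self.logprobs + other.logprobs,
            self.texts + other.texts,
            self.is_special + other.is_special,
        )


# Two prefill chunks of the same request merge into one token stream.
chunk_a = Tokens([1, 2], [-0.1, -0.2], ["Hello", " world"], [False, False])
chunk_b = Tokens([3], [-0.3], ["!"], [False])
merged = chunk_a + chunk_b
assert merged.token_ids == [1, 2, 3]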
@@ -271,6 +271,8 @@ class VlmCausalLM(FlashCausalLM):
             model_id=model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
+            # FIXME: VLM do not work with context chunking yet
+            support_chunking=False,
             **kwargs,
         )
@@ -295,7 +297,7 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
             speculative_ids = batch.speculative_ids
@@ -314,8 +316,8 @@ class VlmCausalLM(FlashCausalLM):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
-            prefix_lens_tensor = (
-                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+            cache_lengths_tensor = (
+                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
             ).reshape(-1)

             # Add Copy the block tables for all members
@@ -337,8 +339,8 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
-            prefix_lens_tensor = batch.prefix_lens_tensor
-            max_s = batch.max_seqlen
+            cache_lengths_tensor = batch.cache_lengths_tensor
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices

         if cu_seqlen_prefill is None and self.max_past() is not None:

@@ -347,7 +349,6 @@ class VlmCausalLM(FlashCausalLM):
             # This makes sure the max_s for the decode pass is correct.
             max_s = min(self.max_past(), max_s)

-        bs = input_ids.shape[0]
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -357,26 +358,24 @@ class VlmCausalLM(FlashCausalLM):
         else:
             cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
-            if PREFIX_CACHING:
+            if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
                     input_lengths=batch.input_lengths,
-                    prefix_lens=batch.prefix_lens,
+                    cache_lengths=batch.cache_lengths,
                 )
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
                 input_lengths_tensor=input_lengths,
-                prefix_lens_tensor=prefix_lens_tensor,
+                cache_lengths_tensor=cache_lengths_tensor,
             ):
-                max_k = (input_lengths + prefix_lens_tensor).max().item()
                 seqlen = Seqlen(
                     input_lengths=input_lengths,
-                    prefix_lengths=prefix_lens_tensor,
+                    cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=max_s,
-                    max_k=max_k,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
                 )
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
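The renames reflect the bookkeeping that chunked prefill needs: per request, the cache length counts tokens whose KV entries already exist, the input length counts the tokens fed in this pass, and the current length is their sum, which is presumably why `max_q` now comes from `batch.max_input_length` and `max_k` from `batch.max_current_length`. A small illustrative sketch of that arithmetic (names are hypothetical, not TGI's API):

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class RequestLengths:
    cache_length: int  # tokens already present in the KV cache
    input_length: int  # new tokens processed in this (possibly chunked) pass

    @property
    def current_length(self) -> int:
        return self.cache_length + self.input_length


def max_q_and_k(batch: List[RequestLengths]) -> Tuple[int, int]:
    # max_q bounds the query length (largest chunk being prefilled),
    # max_k bounds the key/value length (largest total sequence so far).
    max_q = max(r.input_length for r in batch)
    max_k = max(r.current_length for r in batch)
    return max_q, max_k


# One request with 1024 tokens already cached and a 512-token chunk to prefill,
# plus a fresh 7-token request:
assert max_q_and_k([RequestLengths(1024, 512), RequestLengths(0, 7)]) == (512, 1536)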
@@ -411,22 +410,34 @@ class VlmCausalLM(FlashCausalLM):
             block_tables = block_tables_to_ragged(
                 block_tables=block_tables,
                 input_lengths=batch.input_lengths,
-                prefix_lens=batch.prefix_lens,
+                cache_lengths=batch.cache_lengths,
             )
             cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
         else:
             cuda_graph["block_tables"][
                 : block_tables.shape[0], : block_tables.shape[1]
             ] = block_tables
-        cuda_graph["slots"].fill_(-1)
+
+        # XXX: This is working only because block 0 is reserved for the healthcheck
+        # so it doesn't matter if we override it with bogus values.
+        cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + prefix_lens_tensor
-        )
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["cache_lengths"].zero_()
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor

-        # Replay the graph
-        cuda_graph["graph"].replay()
+        with self._forward_context(
+            block_tables=cuda_graph["block_tables"],
+            cu_seqlen_prefill=None,
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            cache_lengths_tensor=cuda_graph["cache_lengths"],
+            state=cuda_graph["state"],
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()

         # Slice output to the correct shape
         speculative_logits = (
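The decode path above pads the live batch into the pre-allocated CUDA-graph buffers (filling `slots` with the reserved block-0 value, zeroing `input_lengths` and the new `cache_lengths`) and then replays the captured graph inside the forward context. A generic, hedged sketch of that capture/copy/replay pattern in plain PyTorch, not TGI's actual graph setup:

import torch


def capture(model: torch.nn.Module, max_bs: int, hidden: int):
    # Static buffers whose addresses are baked into the captured graph.
    static_in = torch.zeros(max_bs, hidden, device="cuda")
    static_out = torch.zeros(max_bs, hidden, device="cuda")

    # Warm up on a side stream before capture, as recommended for CUDA graphs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(model(static_in))
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(model(static_in))
    return graph, static_in, static_out


if torch.cuda.is_available():
    model = torch.nn.Linear(64, 64).cuda()
    graph, static_in, static_out = capture(model, max_bs=32, hidden=64)

    batch = torch.randn(5, 64, device="cuda")
    static_in.zero_()                      # like zeroing input_lengths / cache_lengths
    static_in[: batch.shape[0]] = batch    # pad the real batch into the static buffer
    graph.replay()                         # replay the captured forward pass
    logits = static_out[: batch.shape[0]]  # slice output to the correct shape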
@@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters
 from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -46,9 +47,12 @@ class SignalHandler:
         signal.signal(signal.SIGINT, self.exit_gracefully)
         signal.signal(signal.SIGTERM, self.exit_gracefully)

+    def set_keep_processing(self, value: bool):
+        self.KEEP_PROCESSING = value
+
     def exit_gracefully(self, signum, frame):
         print(f"Exiting gracefully: Signal {signum}")
-        self.KEEP_PROCESSING = False
+        self.set_keep_processing(False)


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -96,6 +100,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
+        set_max_prefill_tokens(request.max_prefill_tokens)
+
         if self.quantize in {"exl2", "gptq"}:
             try:
                 # When using GPTQ, Exllama kernels need some global kernels
@@ -150,6 +156,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             request.batch, self.model.tokenizer, self.model.dtype, self.model.device
         )

+        concat_ns = None
+        if self.model.support_chunking:
+            if request.HasField("cached_batch"):
+                cached_batch = self.cache.pop(request.cached_batch.id)
+                if cached_batch is None:
+                    raise ValueError(
+                        f"Batch ID {request.cached_batch.id} not found in cache."
+                    )
+                start_concat = time.time_ns()
+                batch = self.model.batch_type.concatenate([cached_batch, batch])
+                concat_ns = time.time_ns() - start_concat
+
         generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

@@ -159,6 +177,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             forward_ns=timings[0],
             decode_ns=timings[1],
             total_ns=time.time_ns() - start,
+            concat_ns=concat_ns,
         )

     async def Decode(self, request, context):
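With chunking enabled, `Prefill` can now receive the previously cached batch alongside the new chunk: the server pops it from the cache, concatenates it with the incoming batch before `generate_token`, and reports the merge time as `concat_ns`. A hedged, self-contained restatement of that control flow with stand-in objects (none of these classes are TGI's):

import time
from typing import Dict, Optional


class FakeBatch:
    def __init__(self, ids):
        self.ids = list(ids)

    @classmethod
    def concatenate(cls, batches):
        # TGI merges KV-cache metadata here; this stand-in only merges ids.
        return cls([i for b in batches for i in b.ids])


def prefill(supports_chunking: bool, cache: Dict[int, FakeBatch],
            cached_batch_id: Optional[int], batch: FakeBatch):
    concat_ns = None
    if supports_chunking and cached_batch_id is not None:
        cached = cache.pop(cached_batch_id, None)
        if cached is None:
            raise ValueError(f"Batch ID {cached_batch_id} not found in cache.")
        start = time.time_ns()
        batch = FakeBatch.concatenate([cached, batch])
        concat_ns = time.time_ns() - start
    # ... generate_token(batch) would run here ...
    return batch, concat_ns


merged, concat_ns = prefill(True, {7: FakeBatch([1, 2])}, 7, FakeBatch([3]))
assert merged.ids == [1, 2, 3]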
@@ -252,10 +271,12 @@ def serve(
             logger.exception("Error when initializing model")
             raise

+        signal_handler = SignalHandler()
+
         set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
-                ExceptionInterceptor(),
+                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),
                 UDSOpenTelemetryAioServerInterceptor(),
             ],
             options=[

@@ -276,7 +297,6 @@ def serve(
         await server.start()
         logger.info("Server started at {}".format(local_url))

-        signal_handler = SignalHandler()
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
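Creating the `SignalHandler` before the gRPC server is what lets the `ExceptionInterceptor` callback flip the same flag the serve loop polls, so an unrecoverable error now shuts the server down the same way SIGTERM does. A minimal sketch of that shared-flag shutdown pattern outside gRPC (illustrative only):

import asyncio
import signal


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def set_keep_processing(self, value: bool):
        self.KEEP_PROCESSING = value

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.set_keep_processing(False)


async def serve(handler: SignalHandler):
    async def fail_later():
        await asyncio.sleep(1.0)
        # What the interceptor's callback does on a fatal error:
        handler.set_keep_processing(False)

    asyncio.create_task(fail_later())
    while handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)
    print("Server stopped")


# asyncio.run(serve(SignalHandler()))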
@@ -120,15 +120,18 @@ def _load_and_merge(
         if adapter.id == BASE_MODEL_ADAPTER_ID:
             raise ValueError("Base model adapter cannot be merged.")

-        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (
-            load_module_map(
-                model_id,
-                adapter.revision,
-                adapter.id,
-                adapter.path,
-                weight_names,
-                trust_remote_code,
-            )
+        (
+            module_map,
+            adapter_config,
+            adapter_weight_names,
+            adapter_tokenizer,
+        ) = load_module_map(
+            model_id,
+            adapter.revision,
+            adapter.id,
+            adapter.path,
+            weight_names,
+            trust_remote_code,
         )

         adapters_to_merge.append((module_map, adapter_config))
from typing import Optional

SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None


def set_support_chunking(support_chunking: bool):
    global SUPPORT_CHUNKING
    SUPPORT_CHUNKING = support_chunking


def get_support_chunking() -> bool:
    global SUPPORT_CHUNKING
    return SUPPORT_CHUNKING


def set_max_prefill_tokens(max_prefill_tokens: int):
    global MAX_PREFILL_TOKENS
    MAX_PREFILL_TOKENS = max_prefill_tokens


def get_max_prefill_tokens() -> int:
    global MAX_PREFILL_TOKENS
    return MAX_PREFILL_TOKENS
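The new prefill_chunking helpers are a pair of process-level globals with setters and getters, so the warmup path (which learns `max_prefill_tokens` from the router) and model code (which knows whether it supports chunking) can share these values without threading them through every constructor. A hedged usage sketch, assuming the `text_generation_server` package is importable; only the `set_max_prefill_tokens` call in `Warmup` appears in this diff, the other call sites are illustrative assumptions:

from text_generation_server.utils.prefill_chunking import (
    get_max_prefill_tokens,
    get_support_chunking,
    set_max_prefill_tokens,
    set_support_chunking,
)

# During warmup, the server records the router's prefill-token budget.
set_max_prefill_tokens(4096)

# A model that implements chunked prefill could advertise it once at load time.
set_support_chunking(True)

# Later, scheduling code can size prefill chunks against the shared budget.
if get_support_chunking():
    budget = get_max_prefill_tokens()
    chunk = min(budget, 512)  # 512 is an arbitrary example chunk size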
@@ -7,6 +7,7 @@ from typing import List, Tuple, Union

 import torch


+# FIXME: this should be optimized
 def find_segments(
     adapter_indices: Union[torch.Tensor, List[int]]
 ) -> Tuple[List[int], List[int]]: