Unverified Commit 8aece3bd authored by OlivierDehaene, committed by GitHub

feat: move allocation logic to rust (#1835)

Close #2007
parent 9ffe1f1e
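A minimal sketch (not part of this diff) of the warmup flow the Python side converges on after this change, assuming simplified stand-in names (`model`, `free_memory`, `total_cache_size`) for the real server objects: the Python CacheManager is gone, FlashCausalLM owns `self.kv_cache`, and block/slot allocation now arrives from the Rust router on each protobuf request.
# Hypothetical, simplified warmup sketch; not the actual server code.
BLOCK_SIZE = 16

def warmup(model, batch, free_memory: int, total_cache_size: int) -> int:
    # Allocate exactly what the warmup batch needs, then probe peak memory.
    model.init_kv_cache(
        batch.num_blocks, model.num_layers, model.num_kv_heads,
        model.head_size, model.dtype, model.device,
    )
    model.generate_token(batch)

    # Size the real cache from the remaining free memory (5% wiggle room),
    # re-adding the warmup blocks since they were part of the peak.
    num_blocks = int((free_memory * 0.95) // total_cache_size) + batch.num_blocks
    model.init_kv_cache(
        num_blocks, model.num_layers, model.num_kv_heads,
        model.head_size, model.dtype, model.device,
    )
    return num_blocks * BLOCK_SIZE  # max total tokens the cache can hold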
......@@ -25,11 +25,6 @@ from text_generation_server.models.types import (
Generation,
GeneratedText,
)
from text_generation_server.models.cache_manager import (
get_cache_manager,
set_cache_manager,
BLOCK_SIZE,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
......@@ -44,6 +39,21 @@ from text_generation_server.utils.import_utils import (
tracer = trace.get_tracer(__name__)
BLOCK_SIZE: int = 16
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW
SLIDING_WINDOW = sliding_window
def get_sliding_windows() -> int:
global SLIDING_WINDOW
return SLIDING_WINDOW
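A small usage sketch for the relocated helpers above: the old flash_mistral version stored both the window and a precomputed block count, but the block count is derivable from BLOCK_SIZE, so only the window is kept (the 4096 below is illustrative).
import math

set_sliding_window(4096)
sliding_window = get_sliding_windows()
# The old (sliding_window, sliding_window_blocks) pair can be recovered on demand:
sliding_window_blocks = math.ceil(sliding_window / BLOCK_SIZE)  # 256 with BLOCK_SIZE = 16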
@dataclass
class FlashCausalLMBatch(Batch):
......@@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
speculative_ids: torch.Tensor
speculative_ids: Optional[torch.Tensor]
# Flash Attention values
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices are used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor]
# Paged Attention values
......@@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch):
start_slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
# List of tuple of ints representing the number of blocks and slots needed by each sequence
needed_blocks_slots: Optional[List[Tuple[int, int]]]
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
block_tables: Optional[List[List[int]]]
block_tables: List[List[int]]
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: Optional[torch.Tensor]
block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: Optional[torch.Tensor]
slots: torch.Tensor
max_seqlen: int
......@@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch
blocks: int
num_blocks: int
# Maximum number of blocks
max_blocks: int
......@@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch):
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.blocks * BLOCK_SIZE,
max_tokens=self.num_blocks * BLOCK_SIZE,
)
@classmethod
......@@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
return batch_tokenized_inputs
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@classmethod
def from_tokenized(
cls,
......@@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
sliding_window = get_sliding_windows()
position_ids = []
speculative_ids = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []
slot_indices = []
prefill_cache_indices = []
input_lengths = []
prefix_offsets = []
......@@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0
prefill_out_cumulative_length = 0
blocks = 0
num_blocks = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
block_tables = []
slots = []
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
......@@ -225,9 +227,25 @@ class FlashCausalLMBatch(Batch):
speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length
total_tokens = input_length + max_new_tokens - 1 + speculative_length
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
# blocks and slots can be empty (for example in warmup)
if not r.blocks:
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks)
]
request_slots = [
s
for b in request_blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_blocks = r.blocks
request_slots = r.slots
block_tables.append(request_blocks)
slots.extend(request_slots[:total_tokens])
num_blocks += len(request_blocks)
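A standalone sketch of the fallback path above, assuming an empty `r.blocks` as happens during warmup: blocks are allocated contiguously starting at the running block count, and the slot list is truncated to `total_tokens`.
import math

BLOCK_SIZE = 16

def fallback_allocation(total_tokens: int, num_blocks: int):
    """Mirror of the warmup-only branch: contiguous blocks starting at num_blocks."""
    needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
    request_blocks = list(range(num_blocks, num_blocks + needed_blocks))
    request_slots = [
        s
        for b in request_blocks
        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
    ]
    return request_blocks, request_slots[:total_tokens]

blocks, slots = fallback_allocation(total_tokens=40, num_blocks=0)
assert blocks == [0, 1, 2]           # ceil(40 / 16) = 3 blocks
assert slots == list(range(40))      # only the first 40 of the 48 slots are kept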
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
......@@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch):
)
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
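A quick check of what the arange above selects, with an assumed toy sliding window: only the last `sliding_window` prefill positions of each request are written into the paged cache.
import torch

cumulative_length, input_length, sliding_window = 0, 6, 4
request_prefill_cache_indices = torch.arange(
    cumulative_length + max(0, input_length - sliding_window),
    cumulative_length + input_length,
    dtype=torch.int64,
)
assert request_prefill_cache_indices.tolist() == [2, 3, 4, 5]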
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
......@@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch):
cumulative_length += input_length
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_blocks = max(max_blocks, len(request_blocks))
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
......@@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch):
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
......@@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens, device=device, dtype=torch.int64
)
slots = torch.tensor(slots, dtype=torch.int64, device=device)
block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device)
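A self-contained sketch of the padding above with two illustrative requests; each row holds one request's block table and unused entries stay 0.
import torch

block_tables = [[0, 1, 2], [3]]        # ragged per-request block lists
max_blocks = 3
block_tables_tensor = torch.zeros((len(block_tables), max_blocks), dtype=torch.int32)
for i, request_blocks in enumerate(block_tables):
    block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
# block_tables_tensor:
# tensor([[0, 1, 2],
#         [3, 0, 0]], dtype=torch.int32)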
return cls(
batch_id=pb.id,
requests=pb.requests,
......@@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
prefill_cache_indices=prefill_cache_indices,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
block_tables=None,
block_tables_tensor=None,
slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
......@@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=None,
)
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
if len(request_ids) == 0:
......@@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
blocks = 0
num_blocks = 0
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
......@@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch):
)
request_block_table = self.block_tables[idx]
blocks += len(request_block_table)
num_blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
......@@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch):
max_blocks = max(max_blocks, len(request_block_table))
block_indices_to_free = []
# Iterate on all requests
for i, r in enumerate(self.requests):
# Filter requests that are not part of the new batch
if r.id not in requests_idx_mapping.keys():
block_indices_to_free.extend(self.block_tables[i])
# Free blocks
get_cache_manager().free(block_indices_to_free)
# Needed to avoid dropping blocks when the batches will go out of scope
self.block_tables = None
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
......@@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
......@@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
......@@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch):
requests = []
requests_idx_mapping = {}
blocks = 0
num_blocks = 0
total_batch_size = 0
total_slots = 0
max_blocks = 0
......@@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch):
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
blocks += b.blocks
num_blocks += b.num_blocks
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
......@@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch):
else None
)
# Needed to avoid dropping blocks when the batches will go out of scope
for b in batches:
b.block_tables = None
del b
return cls(
batch_id=batches[0].batch_id,
requests=requests,
......@@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
......@@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
def __del__(self):
if self.block_tables is not None and self.block_tables:
# Free blocks
get_cache_manager().free(
list(itertools.chain.from_iterable(self.block_tables))
)
def __len__(self):
return len(self.requests)
......@@ -702,6 +732,7 @@ class FlashCausalLM(Model):
self.head_size = head_size
self.cuda_graphs = {}
self.kv_cache = []
super(FlashCausalLM, self).__init__(
model=model,
......@@ -718,6 +749,43 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
def max_past(self) -> int:
return getattr(self.model, "max_past", None)
def init_kv_cache(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.kv_cache = []
empty_cache()
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "xpu":
x = 1
else:
x = BLOCK_SIZE // element_size
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
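The per-layer shapes this produces, assuming float16 (element size 2) on a CUDA device so that x = BLOCK_SIZE // 2 = 8; the dimensions below are illustrative.
num_blocks, num_heads, head_size, BLOCK_SIZE = 1024, 8, 128, 16
element_size = 2                        # float16
x = BLOCK_SIZE // element_size          # 8
key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)
assert key_shape == (1024, 8, 16, 16, 8)
assert value_shape == (1024, 8, 128, 16)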
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
......@@ -728,12 +796,11 @@ class FlashCausalLM(Model):
.repeat(bs)
.reshape((bs, max_bt))
)
kv_cache = get_cache_manager().kv_cache
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": kv_cache,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths,
......@@ -747,11 +814,12 @@ class FlashCausalLM(Model):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
torch.cuda.synchronize()
......@@ -761,11 +829,12 @@ class FlashCausalLM(Model):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
......@@ -777,17 +846,16 @@ class FlashCausalLM(Model):
empty_cache()
try:
cache_manager = set_cache_manager(
batch.blocks,
self.init_kv_cache(
batch.num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.sliding_window is not None,
self.dtype,
self.device,
)
max_bt = batch.max_blocks
max_s = max_bt * get_cache_manager().block_size
max_s = max_bt * BLOCK_SIZE
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False)
......@@ -811,19 +879,17 @@ class FlashCausalLM(Model):
num_blocks = (
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ cache_manager.num_blocks
# Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ batch.num_blocks
)
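Illustrative arithmetic for the block count above, assuming a Llama-7B-like configuration (32 layers, 32 KV heads, head size 128) in float16 and roughly 20 GiB free after the warmup forward pass; none of these numbers come from the diff itself.
# Per-block cache size: key + value tensors across all layers.
num_layers, num_kv_heads, head_size, BLOCK_SIZE, dtype_size = 32, 32, 128, 16, 2
total_cache_size = 2 * num_layers * num_kv_heads * head_size * BLOCK_SIZE * dtype_size
assert total_cache_size == 8 * 1024 * 1024           # 8 MiB per block

free_memory = 20 * 1024**3                           # ~20 GiB free after warmup
warmup_blocks = 16                                   # batch.num_blocks, illustrative
num_blocks = int((free_memory * 0.95) // total_cache_size) + warmup_blocks
# ≈ 2432 + 16 blocks, i.e. roughly 39k cacheable tokens at BLOCK_SIZE = 16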
del batch
del cache_manager
set_cache_manager(
self.init_kv_cache(
num_blocks,
self.num_layers,
self.num_kv_heads,
self.head_size,
self.sliding_window is not None,
self.dtype,
self.device,
)
......@@ -889,7 +955,6 @@ class FlashCausalLM(Model):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
kv_cache = get_cache_manager().kv_cache
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
......@@ -901,12 +966,13 @@ class FlashCausalLM(Model):
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=get_cache_manager().kv_cache,
kv_cache=self.kv_cache,
block_tables=None,
input_lengths=input_lengths,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
prefill_cache_indices=None,
)
def forward(
......@@ -917,7 +983,7 @@ class FlashCausalLM(Model):
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
......@@ -956,13 +1022,19 @@ class FlashCausalLM(Model):
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
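A one-line illustration of the clamp above with an assumed max_past of 4096: decode over a 10 000-token sequence still attends over at most the last 4096 cached positions.
max_past, max_s = 4096, 10_000      # illustrative values
max_s = min(max_past, max_s)        # 4096: the circular KV buffer never holds more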
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
......@@ -972,7 +1044,7 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
return self.model.forward(
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
......@@ -981,8 +1053,12 @@ class FlashCausalLM(Model):
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
......@@ -1015,24 +1091,7 @@ class FlashCausalLM(Model):
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
if batch.needed_blocks_slots:
# Allocate blocks to this batch
block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
batch.needed_blocks_slots,
batch.blocks,
batch.max_blocks,
batch.input_ids.device,
)
batch.needed_blocks_slots = None
batch.block_tables = block_tables
batch.block_tables_tensor = block_tables_tensor
batch.slots = slots
try:
out, speculative_logits = self.forward(batch)
except Exception as e:
del batch
raise e
out, speculative_logits = self.forward(batch)
if prefill:
next_token_logits = (
......@@ -1327,7 +1386,6 @@ class FlashCausalLM(Model):
batch.all_input_ids[i] = all_input_ids
if stopped:
del batch
# No need to return a batch if we know that all requests stopped
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
......
import math
import torch
import torch.distributed
import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Type
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
get_cache_manager,
)
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
MistralConfig,
)
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
HeterogeneousNextTokenChooser,
StoppingCriteria,
)
tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
from text_generation_server.utils.import_utils import SYSTEM
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
SLIDING_WINDOW = sliding_window
SLIDING_WINDOW_BLOCKS = sliding_window_blocks
def get_sliding_windows() -> Tuple[int, int]:
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
# Prefill cache indices are used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor] = None
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@classmethod
def from_tokenized(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
batch_tokenized_inputs,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
sliding_window, sliding_window_blocks = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []
slot_indices = []
prefill_cache_indices = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
requests_idx_mapping = {}
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
prefill_out_cumulative_length = 0
blocks = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
if (
tokenized_input[0] == tokenizer.bos_token_id
and tokenized_input[1] == tokenizer.bos_token_id
):
tokenized_input = tokenized_input[1:]
input_length = len(tokenized_input)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
# Position ids
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention
# Remove one as the first token does not have a past
speculative_length = get_speculate()
total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
if sliding_window_blocks is not None:
needed_blocks = min(needed_blocks, sliding_window_blocks)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1], dtype=torch.int32
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
)
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
block_tables=None,
block_tables_tensor=None,
slots=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
speculative_ids=None,
)
tracer = trace.get_tracer(__name__)
class BaseFlashMistral(FlashCausalLM):
......@@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM):
# Set context windows
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
......@@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM):
model.model.head_size,
)
def max_past(self) -> int:
return self.model.max_past
@property
def batch_type(self) -> Type[FlashMistralBatch]:
return FlashMistralBatch
def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
kv_cache = get_cache_manager().kv_cache
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=get_cache_manager().kv_cache,
block_tables=None,
input_lengths=input_lengths,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
prefill_cache_indices=None,
)
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)
.repeat(bs)
.reshape((bs, max_bt))
)
kv_cache = get_cache_manager().kv_cache
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths,
}
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def forward(
self, batch: FlashMistralBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
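The bucketing above as a small helper, for reference (the helper name is made up): CUDA graphs are captured for padded batch sizes, so the live batch is rounded up to the nearest captured bucket.
def padded_batch_size(bs: int) -> int:
    """Hypothetical helper mirroring the branches above."""
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8
    return bs

assert [padded_batch_size(b) for b in (1, 2, 3, 5, 9, 17)] == [1, 2, 4, 8, 16, 24]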
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
class FlashMistral(BaseFlashMistral):
def __init__(
......
......@@ -7,7 +7,6 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
......@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
# Set context windows
if config.sliding_window is not None:
set_sliding_window(
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
......
......@@ -6,7 +6,6 @@ from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
......@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
# Set context windows
if config.sliding_window is not None:
set_sliding_window(
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
set_sliding_window(config.sliding_window)
torch.distributed.barrier(group=self.process_group)
......
......@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
FlashMistralBatch,
)
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import (
get_cache_manager,
)
tracer = trace.get_tracer(__name__)
......@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image
class VlmCausalLMBatch(FlashMistralBatch):
class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
......@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
......@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
......