Unverified Commit 3b56d766 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat: add mistral model (#1071)

parent 259a2300
......@@ -68,6 +68,7 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama)
- [Code Llama](https://huggingface.co/codellama)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
Other architectures are supported on a best effort basis using:
......
......@@ -140,6 +140,8 @@ class Parameters:
watermark: bool
# Get decoder input token logprobs and ids
decoder_input_details: bool
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
# Decoder input tokens
class InputToken:
......@@ -189,6 +191,8 @@ class BestOfSequence:
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# `generate` details
......@@ -203,6 +207,8 @@ class Details:
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
......@@ -229,6 +235,8 @@ class StreamDetails:
class StreamResponse:
# Generated token
token: Token
# Most likely tokens
top_tokens: Optional[List[Token]]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
......
This diff is collapsed.
[tool.poetry]
name = "text-generation"
version = "0.6.0"
version = "0.6.1"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
......
......@@ -482,7 +482,6 @@ class AsyncClient:
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
......
......@@ -40,7 +40,7 @@ class Parameters(BaseModel):
# Get decoder input token logprobs and ids
decoder_input_details: bool = False
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
top_n_tokens: Optional[int] = None
@validator("best_of")
def valid_best_of(cls, field_value, values):
......@@ -188,7 +188,7 @@ class BestOfSequence(BaseModel):
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
top_tokens: Optional[List[List[Token]]] = None
# `generate` details
......@@ -204,7 +204,7 @@ class Details(BaseModel):
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
top_tokens: Optional[List[List[Token]]] = None
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]] = None
......@@ -232,7 +232,7 @@ class StreamResponse(BaseModel):
# Generated token
token: Token
# Most likely tokens
top_tokens: Optional[List[Token]]
top_tokens: Optional[List[Token]] = None
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str] = None
......
......@@ -34,10 +34,17 @@ Options:
[env: NUM_SHARD=]
--quantize <QUANTIZE>
Whether you want the model to be quantized. This will use `bitsandbytes` for quantization on the fly, or `gptq`. 4bit quantization is available through `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options
Whether you want the model to be quantized
[env: QUANTIZE=]
[possible values: bitsandbytes, bitsandbytes-nf4, bitsandbytes-fp4, gptq, awq]
Possible values:
- awq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models whereever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels whereever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
--dtype <DTYPE>
The dtype to be forced upon the model. This option cannot be used with `--quantize`
......
......@@ -18,6 +18,8 @@ The following models are optimized and can be served with TGI, which uses custom
- [Falcon 40B](https://huggingface.co/tiiuae/falcon-40b)
- [MPT](https://huggingface.co/mosaicml/mpt-30b)
- [Llama V2](https://huggingface.co/meta-llama)
- [Code Llama](https://huggingface.co/codellama)
- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -12.9140625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.7578125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 28747,
"logprob": -0.54785156,
"special": false,
"text": ":"
},
{
"id": 3169,
"logprob": -1.4091797,
"special": false,
"text": " Let"
},
{
"id": 307,
"logprob": -3.0273438,
"special": false,
"text": " n"
},
{
"id": 327,
"logprob": -0.94433594,
"special": false,
"text": " ="
},
{
"id": 28705,
"logprob": -0.81347656,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.2958984,
"special": false,
"text": "1"
},
{
"id": 28734,
"logprob": -2.0644531,
"special": false,
"text": "0"
},
{
"id": 387,
"logprob": -1.9580078,
"special": false,
"text": " -"
},
{
"id": 28705,
"logprob": -0.5073242,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.1816406,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": ": Let n = 10 - 1"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -12.9140625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.7578125,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 28747,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 3169,
"logprob": -0.1307373,
"special": false,
"text": " Let"
},
{
"id": 332,
"logprob": -2.3359375,
"special": false,
"text": " u"
},
{
"id": 347,
"logprob": 0.0,
"special": false,
"text": " be"
},
{
"id": 325,
"logprob": -1.0234375,
"special": false,
"text": " ("
},
{
"id": 28734,
"logprob": -2.0292969,
"special": false,
"text": "0"
},
{
"id": 648,
"logprob": -1.0439453,
"special": false,
"text": " +"
},
{
"id": 28705,
"logprob": -0.24499512,
"special": false,
"text": " "
},
{
"id": 28770,
"logprob": -0.5073242,
"special": false,
"text": "3"
},
{
"id": 387,
"logprob": -1.5507812,
"special": false,
"text": " -"
}
],
"top_tokens": null
},
"generated_text": "Test request: Let u be (0 + 3 -"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -12.9140625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.7578125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 28747,
"logprob": -0.55078125,
"special": false,
"text": ":"
},
{
"id": 3169,
"logprob": -1.4140625,
"special": false,
"text": " Let"
},
{
"id": 307,
"logprob": -3.0273438,
"special": false,
"text": " n"
},
{
"id": 327,
"logprob": -0.94140625,
"special": false,
"text": " ="
},
{
"id": 28705,
"logprob": -0.8173828,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.2978516,
"special": false,
"text": "1"
},
{
"id": 28734,
"logprob": -2.0664062,
"special": false,
"text": "0"
},
{
"id": 387,
"logprob": -1.9560547,
"special": false,
"text": " -"
},
{
"id": 28705,
"logprob": -0.5078125,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.1787109,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": ": Let n = 10 - 1"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -12.9140625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.7578125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 28747,
"logprob": -0.54785156,
"special": false,
"text": ":"
},
{
"id": 3169,
"logprob": -1.4111328,
"special": false,
"text": " Let"
},
{
"id": 307,
"logprob": -3.0292969,
"special": false,
"text": " n"
},
{
"id": 327,
"logprob": -0.94433594,
"special": false,
"text": " ="
},
{
"id": 28705,
"logprob": -0.8178711,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.2939453,
"special": false,
"text": "1"
},
{
"id": 28734,
"logprob": -2.0644531,
"special": false,
"text": "0"
},
{
"id": 387,
"logprob": -1.9550781,
"special": false,
"text": " -"
},
{
"id": 28705,
"logprob": -0.5078125,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.1796875,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": ": Let n = 10 - 1"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -12.9140625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.7578125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 28747,
"logprob": -0.55078125,
"special": false,
"text": ":"
},
{
"id": 3169,
"logprob": -1.4140625,
"special": false,
"text": " Let"
},
{
"id": 307,
"logprob": -3.0273438,
"special": false,
"text": " n"
},
{
"id": 327,
"logprob": -0.94140625,
"special": false,
"text": " ="
},
{
"id": 28705,
"logprob": -0.8173828,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.2978516,
"special": false,
"text": "1"
},
{
"id": 28734,
"logprob": -2.0664062,
"special": false,
"text": "0"
},
{
"id": 387,
"logprob": -1.9560547,
"special": false,
"text": " -"
},
{
"id": 28705,
"logprob": -0.5078125,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.1787109,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": ": Let n = 10 - 1"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -12.9140625,
"text": "Test"
},
{
"id": 2159,
"logprob": -10.7578125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 28747,
"logprob": -0.55078125,
"special": false,
"text": ":"
},
{
"id": 3169,
"logprob": -1.4140625,
"special": false,
"text": " Let"
},
{
"id": 307,
"logprob": -3.0273438,
"special": false,
"text": " n"
},
{
"id": 327,
"logprob": -0.94140625,
"special": false,
"text": " ="
},
{
"id": 28705,
"logprob": -0.8173828,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.2978516,
"special": false,
"text": "1"
},
{
"id": 28734,
"logprob": -2.0664062,
"special": false,
"text": "0"
},
{
"id": 387,
"logprob": -1.9560547,
"special": false,
"text": " -"
},
{
"id": 28705,
"logprob": -0.5078125,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -1.1787109,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": ": Let n = 10 - 1"
}
]
import pytest
@pytest.fixture(scope="module")
def flash_mistral_handle(launcher):
with launcher("mistralai/Mistral-7B-Instruct-v0.1") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_mistral(flash_mistral_handle):
await flash_mistral_handle.health(300)
return flash_mistral_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral(flash_mistral, response_snapshot):
response = await flash_mistral.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
response = await flash_mistral.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
responses = await generate_load(
flash_mistral, "Test request", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
......@@ -31,6 +31,7 @@ message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
}
/// Empty request
......
......@@ -50,10 +50,11 @@ impl Infer {
max_waiting_tokens: usize,
max_concurrent_requests: usize,
requires_padding: bool,
window_size: Option<u32>,
generation_health: Arc<AtomicBool>,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16);
let queue = Queue::new(requires_padding, 16, window_size);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
......
......@@ -2,6 +2,7 @@ use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::ValidGenerateRequest;
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::{Batch, Request};
use tokio::sync::oneshot;
......@@ -33,12 +34,17 @@ pub(crate) struct Queue {
}
impl Queue {
pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self {
pub(crate) fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
// Create channel
let (queue_sender, queue_receiver) = flume::unbounded();
// Launch background queue task
tokio::spawn(queue_task(requires_padding, block_size, queue_receiver));
tokio::spawn(queue_task(
requires_padding,
block_size,
window_size,
queue_receiver,
));
Self { queue_sender }
}
......@@ -84,9 +90,10 @@ impl Queue {
async fn queue_task(
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
receiver: flume::Receiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size);
let mut state = State::new(requires_padding, block_size, window_size);
while let Ok(cmd) = receiver.recv_async().await {
match cmd {
......@@ -126,16 +133,20 @@ struct State {
/// Paged Attention block size
block_size: u32,
/// Sliding window
window_size: Option<u32>,
}
impl State {
fn new(requires_padding: bool, block_size: u32) -> Self {
fn new(requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
requires_padding,
block_size,
window_size,
}
}
......@@ -204,11 +215,17 @@ impl State {
if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
} else {
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
// pad to block size
decode_tokens +=
((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1)
/ self.block_size)
* self.block_size;
((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
}
if prefill_tokens > prefill_token_budget
......@@ -342,7 +359,7 @@ mod tests {
#[test]
fn test_append() {
let mut state = State::new(false, 1);
let mut state = State::new(false, 1, None);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
......@@ -358,7 +375,7 @@ mod tests {
#[test]
fn test_next_batch_empty() {
let mut state = State::new(false, 1);
let mut state = State::new(false, 1, None);
assert!(state.next_batch(None, 1, 1).is_none());
assert!(state.next_batch(Some(1), 1, 1).is_none());
......@@ -366,7 +383,7 @@ mod tests {
#[test]
fn test_next_batch_min_size() {
let mut state = State::new(false, 1);
let mut state = State::new(false, 1, None);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -398,7 +415,7 @@ mod tests {
#[test]
fn test_next_batch_token_budget() {
let mut state = State::new(false, 1);
let mut state = State::new(false, 1, None);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
......@@ -431,14 +448,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1);
let queue = Queue::new(false, 1, None);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1);
let queue = Queue::new(false, 1, None);
assert!(queue.next_batch(None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), 1, 1).await.is_none());
......@@ -446,7 +463,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1);
let queue = Queue::new(false, 1, None);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -479,7 +496,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1);
let queue = Queue::new(false, 1, None);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
......@@ -504,7 +521,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1);
let queue = Queue::new(false, 1, None);
let (entry, _) = default_entry();
queue.append(entry);
......
......@@ -595,6 +595,7 @@ pub async fn run(
max_waiting_tokens,
max_concurrent_requests,
shard_info.requires_padding,
shard_info.window_size,
generation_health,
);
......
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
flash-attention-v2:
# Clone flash attention
......
vllm_commit := e86af624d059969b0fb07b075b1d338bf10c3365
vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
vllm:
# Clone vllm
......
......@@ -67,6 +67,16 @@ if FLASH_ATTENTION:
__all__.append(FlashLlama)
__all__.append(IDEFICSSharded)
MISTRAL = True
try:
from text_generation_server.models.flash_mistral import FlashMistral
except ImportError as e:
logger.warning(f"Could not import Mistral model: {e}")
MISTRAL = False
if MISTRAL:
__all__.append(FlashMistral)
def get_model(
model_id: str,
......@@ -237,7 +247,18 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == "opt":
if model_type == "mistral":
if MISTRAL:
return FlashMistral(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Mistral model requires flash attention v2")
if model_type == "opt":
return OPTSharded(
model_id,
revision,
......@@ -246,7 +267,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == "t5":
if model_type == "t5":
return T5Sharded(
model_id,
revision,
......@@ -254,7 +275,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "idefics":
if model_type == "idefics":
if FLASH_ATTENTION:
return IDEFICSSharded(
model_id,
......
import math
import torch
from typing import Optional, List, Tuple
BLOCK_SIZE: int = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
class CacheManager:
def __init__(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
repeat_slots: bool,
dtype: torch.dtype,
device: torch.device,
):
self.block_size = BLOCK_SIZE
self.num_blocks = num_blocks
self.repeat_slots = repeat_slots
element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, self.block_size, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, self.block_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int32
).view(num_blocks, self.block_size)
def allocate(
self,
needed_blocks_slots: List[Tuple[int, int]],
blocks: int,
max_blocks: int,
device: torch.device,
):
# Get free blocks indices by finding values in mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero()
assert (
len(free_block_indices) >= blocks
), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
# Slice by the number of required blocks
block_indices = free_block_indices[:blocks]
block_indices = block_indices.flatten()
# Padded block tables
block_tables_tensor = torch.zeros(
(len(needed_blocks_slots), max_blocks), dtype=torch.int32
)
# Allocate paged attention blocks
cumulative_blocks = 0
slots = []
block_tables = []
for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
# Get allocated blocks for this sequence
allocated_blocks = block_indices[
cumulative_blocks : cumulative_blocks + needed_blocks
]
# Get slots for the allocated blocks
all_slots = self.slots[allocated_blocks].flatten()
# Repeat slots in the case of context sliding window
if needed_slots > len(all_slots) and self.repeat_slots:
repeats = math.ceil(needed_slots / len(all_slots))
all_slots = all_slots.repeat(repeats)
allocated_slots = all_slots[:needed_slots]
slots.append(allocated_slots)
block_tables.append(allocated_blocks.tolist())
block_tables_tensor[i, :needed_blocks] = allocated_blocks
cumulative_blocks += needed_blocks
block_tables = block_tables
block_tables_tensor = block_tables_tensor.to(device)
slots = torch.concat(slots).to(device)
# Allocate the required number of blocks by setting the mask to 0
self.free_block_mask[block_indices] = 0
return block_tables, block_tables_tensor, slots
def free(self, block_indices: Optional[List[int]]):
if block_indices is not None and block_indices:
# Reset mask
self.free_block_mask[block_indices] = 1
def set_cache_manager(
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
repeat_slots: bool,
dtype: torch.dtype,
device: torch.device,
) -> CacheManager:
global CACHE_MANAGER
if CACHE_MANAGER is not None:
del CACHE_MANAGER
torch.cuda.empty_cache()
CACHE_MANAGER = CacheManager(
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
)
return CACHE_MANAGER
def get_cache_manager() -> CacheManager:
global CACHE_MANAGER
if CACHE_MANAGER is None:
raise RuntimeError("cache manager was not initialized")
return CACHE_MANAGER
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment