/// Batching and inference logic
use crate::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;

/// Batcher
#[derive(Clone)]
pub struct Batcher {
    /// Request database
    db: Db,
    /// Shared state
    shared: Arc<Shared>,
}

/// Batcher shared state
struct Shared {
    /// Batching background Tokio task notifier
    batching_task: Notify,
}

impl Batcher {
    pub(crate) fn new(
        client: ShardedClient,
        max_batch_size: usize,
        max_waiting_time: Duration,
    ) -> Self {
        // Batcher shared state
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            max_batch_size,
            max_waiting_time,
            client,
            db.clone(),
            shared.clone(),
        ));

        Self { db, shared }
    }

    /// Add a new request to the database and return a future that will generate the text
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<String, InferError> {
        // One shot channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Try to append the request to the database
        self.db.append(Entry {
            request,
            response_tx,
            input_length,
            time: Instant::now(),
        });

        // Notify the background task that we have a new entry in the database that needs
        // to be batched
        self.shared.batching_task.notify_waiters();

        // Await on the response from the background task
        // We can safely unwrap as the background task will never drop the sender
        match response_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(err) => Err(InferError::GenerationError(err.to_string())),
        }
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, db, shared))]
async fn batching_task(
    max_batch_size: usize,
    max_waiting_time: Duration,
    client: ShardedClient,
    db: Db,
    shared: Arc<Shared>,
) {
    // Minimum batch size after which we try to add more requests
    let limit_min_batch_size = (max_batch_size / 2) as u32;

    // Infinite loop
    loop {
        // Wait for a notification from the Batcher struct
        shared.batching_task.notified().await;

        // Get the next batch from the DB
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the DB
        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

                // If the current batch is too small, we try to add more requests to it
                if batch_size <= limit_min_batch_size {
                    // Get the next batch from the DB that meets our minimum size criteria
                    if let Some((new_request_ids, new_batch)) =
                        db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
                    {
                        // Generate one token for this new batch to have the attention past in cache
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;

                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                    // If we don't have enough requests to meet the minimum size criteria, we
                    // try to get the next batch from the DB whose requests have been waiting
                    // longer than max_waiting_time
                    else if let Some((new_request_ids, new_batch)) =
                        db.next_batch(None, max_batch_size, Some(max_waiting_time))
                    {
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;

                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }

                cached_batch =
                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
            }
        }
    }
}

/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
    db: &Db,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            send_error(err, request_ids, db);
            None
        }
    }
}

/// Send errors to the Batcher for all `request_ids`
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        // We can `expect` here as the request id should always be in the DB
        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    finished.into_iter().for_each(|output| {
        // We can `expect` here as the request id should always be in the DB
        let entry = db
            .remove(&output.request.unwrap().id)
            .expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(output.output)).unwrap_or(());
    });
}

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
}

/// Convert to Axum supported format
impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
        }
    }
}
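
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of this module): a minimal Axum handler
// showing how `Batcher::infer` is meant to be called. Only `Batcher`,
// `InferError` and the `From<InferError> for (StatusCode, String)` conversion
// come from this file; the `Extension` layer, the `validated_length` value and
// the shape of the `GenerateRequest` payload are assumptions made purely for
// illustration.
//
// use axum::extract::Extension;
// use axum::Json;
//
// async fn generate(
//     Extension(batcher): Extension<Batcher>,
//     Json(request): Json<GenerateRequest>,
// ) -> Result<String, (StatusCode, String)> {
//     // In the real router the input length would come from tokenizing and
//     // validating the request; a placeholder value is used here.
//     let validated_length = 1;
//     // `?` converts `InferError` into `(StatusCode, String)` through the
//     // `From` impl defined at the bottom of this file.
//     let output = batcher.infer(validated_length, request).await?;
//     Ok(output)
// }
// ---------------------------------------------------------------------------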