batcher.rs
//! Batching and inference logic
use crate::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;

/// Batcher
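/// Cloneable handle: every clone shares the same request database and the same notifier
/// used to wake up the background batching task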
#[derive(Clone)]
pub struct Batcher {
    /// Request database
    db: Db,
    /// Shared state
    shared: Arc<Shared>,
}

/// Batcher shared state
struct Shared {
    /// Batching background Tokio task notifier
    batching_task: Notify,
}

impl Batcher {
    pub(crate) fn new(
        client: ShardedClient,
        max_batch_size: usize,
        max_waiting_time: Duration,
    ) -> Self {
        // Batcher shared state
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            max_batch_size,
            max_waiting_time,
            client,
            db.clone(),
            shared.clone(),
        ));

        Self { db, shared }
    }

    /// Add a new request to the database and return a future that will generate the text
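    ///
    /// A minimal usage sketch (hypothetical caller: `batcher`, `input_length` and `request`
    /// are assumed to already exist, e.g. inside an Axum handler whose error type is
    /// `(StatusCode, String)` so that `?` can use the `From<InferError>` impl below):
    ///
    /// ```ignore
    /// let generated: String = batcher.infer(input_length, request).await?;
    /// ```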
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<String, InferError> {
        // One-shot channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Try to append the request to the database
        self.db.append(Entry {
            request,
            response_tx,
            input_length,
            time: Instant::now(),
        });

        // Notify the background task that we have a new entry in the database that needs
        // to be batched
        self.shared.batching_task.notify_waiters();

        // Await on the response from the background task
        // We can safely unwrap as the background task will never drop the sender
        match response_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(err) => Err(InferError::GenerationError(err.to_string())),
        }
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
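///
/// One iteration: wait for a wake-up from `Batcher`, build a first batch from the DB and send
/// it to `client.generate`, then keep calling `client.generate_with_cache` until every request
/// in the batch has met its stopping criteria, opportunistically merging newly queued requests
/// into the running batch along the way (see the size and waiting-time checks below)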
#[instrument(skip(client, db, shared))]
async fn batching_task(
    max_batch_size: usize,
    max_waiting_time: Duration,
    client: ShardedClient,
    db: Db,
    shared: Arc<Shared>,
) {
    // Batch size at or below which we try to add more requests to the running batch
    let limit_min_batch_size = (max_batch_size / 2) as u32;
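    // e.g. with max_batch_size = 32 this is 16: once the running batch contains 16 requests
    // or fewer, we try to pull more waiting requests from the DB into it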

    // Infinite loop
    loop {
        // Wait for a notification from the Batcher struct
        shared.batching_task.notified().await;

        // Get the next batch from the DB
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the DB
        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

                // If the current batch is too small, we try to add more requests to it
                if batch_size <= limit_min_batch_size {
                    // Get the next batch from the DB that meets our minimum size criterion
                    if let Some((new_request_ids, new_batch)) =
                        db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
                    {
                        // Generate one token for this new batch so that its attention past is already in the cache
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                    // If there are not enough waiting requests to meet the minimum size criterion,
                    // we instead take the requests that have been waiting for longer than
                    // max_waiting_time
                    else if let Some((new_request_ids, new_batch)) =
                        db.next_batch(None, max_batch_size, Some(max_waiting_time))
                    {
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }

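                // Run the next generation step on every batch we track; finished requests are
                // answered inside wrap_future and the still-running ones come back as a new
                // cached batch (None once every request has finished)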
                cached_batch =
                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
            }
        }
    }
}

/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
    db: &Db,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            send_error(err, request_ids, db);
            None
        }
    }
}

/// Send errors to the Batcher for all `request_ids`
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        // We can `expect` here as the request id should always be in the DB
        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

/// Send each generated text in `finished` back to the Batcher
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    finished.into_iter().for_each(|output| {
        // We can `expect` here as the request id should always be in the DB
        let entry = db
            .remove(&output.request.unwrap().id)
            .expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(output.output)).unwrap_or(());
    });
}

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
}

/// Convert to Axum supported format
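/// This is what allows an Axum handler returning `Result<_, (StatusCode, String)>` to apply
/// `?` directly to the result of `Batcher::infer`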
impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
        }
    }
}