//! Batching and inference logic
use crate::{Db, Entry};
use crate::{ErrorResponse, GenerateRequest};
use axum::http::StatusCode;
use axum::Json;
use std::future::Future;
use std::sync::Arc;
use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
use tracing::instrument;

/// Batcher
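///
/// Client-facing handle that owns the request [`Db`] and wakes the background
/// batching task whenever new work arrives. A minimal usage sketch (the
/// `sharded_client`, `input_length` and `request` values, and the concrete
/// limits, are illustrative assumptions, not part of this file):
///
/// ```ignore
/// // At most 32 requests per batch; stop waiting for a "big enough"
/// // top-up batch after 20 decoding steps.
/// let batcher = Batcher::new(sharded_client, 32, 20);
///
/// // Queue a request and wait for its generated text.
/// let response = batcher.infer(input_length, request).await?;
/// println!("{} ({} tokens)", response.output, response.tokens);
/// ```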
#[derive(Clone)]
pub struct Batcher {
    /// Request database
    db: Db,
    /// Shared state
    shared: Arc<Shared>,
}

/// Batcher shared state
struct Shared {
    /// Batching background Tokio task notifier
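    /// (`notify_waiters` wakes the task only if it is currently parked in
    /// `notified().await`; missed notifications are acceptable here because
    /// the task drains the whole `Db` on every wake-up)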
    batching_task: Notify,
}

impl Batcher {
    pub(crate) fn new(
        client: ShardedClient,
        max_batch_size: usize,
        max_waiting_tokens: usize,
    ) -> Self {
        // Batcher shared state
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            client,
            max_batch_size,
            max_waiting_tokens,
            db.clone(),
            shared.clone(),
        ));

        Self { db, shared }
    }

    /// Add a new request to the database and return a future that will generate the text
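    ///
    /// The flow: append an [`Entry`] to the [`Db`], wake the batching task,
    /// then await the oneshot channel until the background task sends back
    /// either the generated text or a client error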
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
        // One shot channel to communicate with the background batching task
        let (response_tx, response_rx) = oneshot::channel();

        // Try to append the request to the database
        self.db.append(Entry {
            request,
            response_tx,
            input_length,
            time: Instant::now(),
            batch_time: None,
        });

        // Notify the background task that we have a new entry in the database that needs
        // to be batched
        self.shared.batching_task.notify_waiters();

        // Await the response from the background task
        // We can safely unwrap as the background task will never drop the sender
        match response_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(err) => Err(InferError::GenerationError(err.to_string())),
        }
    }
}

/// Batching logic, launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
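///
/// New requests can join a running batch between decoding steps: the new
/// batch is prefilled with a standalone `generate` call, then merged with the
/// running batch through `generate_with_cache`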
#[instrument(skip(client, db, shared))]
async fn batching_task(
    mut client: ShardedClient,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    db: Db,
    shared: Arc<Shared>,
) {
    // Batch-size threshold at or below which we try to add more requests
    // to the running batch
    let limit_min_batch_size = (max_batch_size / 2) as u32;
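    // (with `max_batch_size = 32`, for instance, `limit_min_batch_size = 16`,
    // so we try to extend the running batch once it drops to 16 or fewer requests)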

    // Infinite loop
    loop {
        // Wait for a notification from the Batcher struct
        shared.batching_task.notified().await;

        // Get the next batch from the DB
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the DB
        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
            let mut waiting_tokens = 1;

            // We loop until we do not receive any cached batch from the inference server
            // (i.e. until all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

                // If the current batch is too small, we try to add more requests to it
                if batch_size <= limit_min_batch_size {
                    let min_size = match waiting_tokens {
                        // If we haven't onboarded any new requests for at least
                        // `max_waiting_tokens` decoding steps, accept a new batch of any size
                        _ if waiting_tokens >= max_waiting_tokens => None,
                        // Minimum size criteria
                        _ => Some(limit_min_batch_size as usize),
                    };

                    // Try to get a new batch
                    if let Some((new_request_ids, new_batch)) =
                        db.next_batch(min_size, max_batch_size)
                    {
                        // Generate one token for this new batch so that its
                        // attention past is already in cache
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
                        // Reset waiting counter
                        waiting_tokens = 1;
                        // Extend current batch with the new batch
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }

                cached_batch =
                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
                waiting_tokens += 1;
            }
        }
    }
}

/// Wrap a future inside a match statement to handle errors and send the response to the Batcher
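///
/// On success, finished generations are dispatched to their waiting callers
/// and the (optional) still-running batch is returned for the next decoding
/// step; on error, every request in the batch receives the error and the
/// whole batch is discarded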
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
    db: &Db,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            send_error(err, request_ids, db);
            None
        }
    }
}

/// Send errors to the Batcher for all `request_ids`
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        // We can `expect` here as the request id should always be in the DB
        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

/// Send `generated_text` to the Batcher for all `finished`
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    finished.into_iter().for_each(|output| {
        // We can `expect` here as the request id should always be in the DB
        // the server only reports finished generations, which always carry
        // their originating request, so the unwrap is safe
        let entry = db
            .remove(&output.request.unwrap().id)
            .expect("ID not found in db. This is a bug.");
        let response = InferResponse {
            output: output.output,
            tokens: output.tokens,
            queued: entry.time,
            // unwrap is safe: `batch_time` is set as soon as the entry joins a batch
            start: entry.batch_time.unwrap(),
            end: Instant::now(),
        };
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(response)).unwrap_or(());
    });
}

/// Inference response returned to the caller, with timing metadata
/// (queue time is `start - queued`; decoding time is `end - start`)
#[derive(Debug)]
pub(crate) struct InferResponse {
    /// Generated text
    pub(crate) output: String,
    /// Number of generated tokens
    pub(crate) tokens: u32,
    /// Instant at which the request was queued
    pub(crate) queued: Instant,
    /// Instant at which the request joined a batch
    pub(crate) start: Instant,
    /// Instant at which generation finished
    pub(crate) end: Instant,
}

/// Errors surfaced to the caller by [`Batcher::infer`]
#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
}

/// Convert to an Axum-supported format
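/// (a generation failure maps to `424 Failed Dependency`: the request itself
/// was valid, but the downstream generation call failed)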
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (
                StatusCode::FAILED_DEPENDENCY,
                Json(ErrorResponse {
                    error: err.to_string(),
                }),
            ),
        }
    }
}