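//! Batching logic: buffers incoming generation requests in a shared database and
//! streams them to the sharded model client in dynamically sized batches.
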
use crate::server::GenerateRequest;
use crate::{Db, Entry};
use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};

const MAX_LENGTH: usize = 128;

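/// Errors returned to callers of [`Batcher::infer`], converted to HTTP status codes
/// by the `From` impl below.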
#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded,
}

impl From<InferError> for (StatusCode, String) {
    fn from(err: InferError) -> Self {
        match err {
            InferError::GenerationError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()),
            InferError::Overloaded => (StatusCode::TOO_MANY_REQUESTS, err.to_string()),
        }
    }
}

#[derive(Clone)]
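/// Cloneable handle used by the HTTP handlers to queue generation requests and wake
/// the background batching task.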
pub struct Batcher {
    db: Db,
    shared: Arc<Shared>,
}

struct Shared {
    batching_task: Notify,
}

impl Batcher {
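    /// Creates a new `Batcher` and spawns the background batching task.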
    pub(crate) fn new(client: ShardedClient, max_batch_size: usize) -> Self {
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        tokio::spawn(batching_task(max_batch_size, client, db.clone(), shared.clone()));

        Self { db, shared }
    }

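    /// Queues a request and waits for the generated text.
    ///
    /// Fails fast with `InferError::Overloaded` when more than `MAX_LENGTH` requests
    /// are already waiting in the database.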
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<String, InferError> {
        if self.db.len() > MAX_LENGTH {
            return Err(InferError::Overloaded);
        }
        let (request_tx, request_rx) = oneshot::channel();
        self.db.append(Entry {
            request,
            response_tx: request_tx,
            input_length,
        });
        self.shared.batching_task.notify_waiters();
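        // Wait for the batching task to send the result back; the sender half lives in
        // the database entry and is only dropped after a response has been sent, so
        // this unwrap does not panic unless the batching task itself panics.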
        match request_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(err) => Err(InferError::GenerationError(err.to_string())),
        }
    }
}

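/// Background task that batches waiting requests and drives generation.
///
/// Every time it is notified, it pulls the next batch from the database and sends it
/// to the sharded client. It keeps iterating for as long as the client returns a
/// cached batch (some requests are still generating); whenever the running batch
/// shrinks to half of `max_batch_size` or less, newly queued requests are
/// concatenated to it. Batches above that threshold use the `*_until_finished`
/// client methods.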
async fn batching_task(
    max_batch_size: usize,
    client: ShardedClient,
    db: Db,
    shared: Arc<Shared>,
) {
    let limit_min_batch_size = (max_batch_size / 2) as u32;

    loop {
        shared.batching_task.notified().await;

        if let Some(batch) = db.next_batch(max_batch_size) {
            let request_ids = batch.requests.iter().map(|req| req.id).collect();
            let mut cached_batch = match batch.size {
                size if size > limit_min_batch_size => {
                    wrap_future(client.generate_until_finished(batch), request_ids, &db).await
                }
                _ => wrap_future(client.generate(batch), request_ids, &db).await,
            };

            while let Some(batch) = cached_batch {
                let mut current_batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

                // If the current batch shrank to half of the maximum size or less, try
                // to concatenate newly queued requests to it.
                if current_batch_size <= limit_min_batch_size {
                    if let Some(new_batch) =
                        db.next_batch_minimum_size(limit_min_batch_size as usize, max_batch_size)
                    {
                        let new_batch_request_ids =
                            new_batch.requests.iter().map(|req| req.id).collect();
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
                                .await;
                        if let Some(new_cached_batch) = new_cached_batch {
                            current_batch_size += new_cached_batch.size;
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
                }

                // Run the next generation step on the (possibly concatenated) cached batches.
                cached_batch = match current_batch_size {
                    size if size > limit_min_batch_size => {
                        wrap_future(
                            client.generate_until_finished_with_cache(batches),
                            request_ids,
                            &db,
                        )
                        .await
                    }
                    _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
                };
            }
        }
    }
}

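/// Awaits a generation future, sends the finished outputs to the waiting clients and
/// returns the next cached batch, if any. On error, the error is propagated to every
/// request of the batch instead.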
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
    db: &Db,
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
        }
        Err(err) => {
            send_error(err, request_ids, db);
            None
        }
    }
}

/// Removes the failed entries from the database and forwards the error to every
/// waiting client.
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

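/// Removes the finished entries from the database and sends each generated text to
/// the client waiting on it.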
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
    finished.into_iter().for_each(|output| {
        let entry = db
            .remove(&output.request.unwrap().id)
            .expect("ID not found in db. This is a bug.");
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Ok(output.output)).unwrap_or(());
    });
}