use crate::server::GenerateRequest;
use crate::Db;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
use tokio::sync::{oneshot, Notify};

/// Admission threshold used by `Batcher::infer`: a new request is rejected
/// with `InferError` once `self.db.len()` reports more than this value.
const MAX_LENGTH: usize = 128;

/// Error returned by `Batcher::infer` when a request cannot be served.
///
/// It carries no payload: callers only learn that inference failed, not why.
/// `Debug` is derived so the type can be logged with `{:?}` and so that
/// `Result<_, InferError>` supports `unwrap`/`expect`.
#[derive(Debug)]
pub struct InferError {}

#[derive(Clone)]
Olivier Dehaene's avatar
Olivier Dehaene committed
13
pub(crate) struct Batcher {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
14
15
16
17
18
19
20
21
    db: Db,
    shared: Arc<Shared>,
}

/// State shared between `Batcher` handles and the background `batching_task`.
struct Shared {
    /// Signalled by `Batcher::infer` after a request has been appended to the
    /// `Db`; awaited by `batching_task` at the top of its loop.
    batching_task: Notify,
}

Olivier Dehaene's avatar
Olivier Dehaene committed
22
impl Batcher {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
23
24
25
26
27
28
29
30
31
32
33
    pub(crate) fn new(client: ShardedClient) -> Self {
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        tokio::spawn(batching_task(client, db.clone(), shared.clone()));

        Self { db, shared }
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
34
35
36
37
38
    pub(crate) async fn infer(
        &self,
        input_length: usize,
        request: GenerateRequest,
    ) -> Result<String, InferError> {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
39
40
41
42
        if self.db.len() > MAX_LENGTH {
            return Err(InferError {});
        }
        let (request_tx, request_rx) = oneshot::channel();
Olivier Dehaene's avatar
Olivier Dehaene committed
43
        self.db.append(input_length, request, request_tx);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
44
45
46
        self.shared.batching_task.notify_waiters();
        match request_rx.await.unwrap() {
            Ok(output) => Ok(output),
Olivier Dehaene's avatar
Olivier Dehaene committed
47
            Err(_) => Err(InferError {}),
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
48
49
50
51
52
53
54
55
56
        }
    }
}

async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
    loop {
        shared.batching_task.notified().await;

        if let Some(batch) = db.next_batch(32) {
Olivier Dehaene's avatar
Olivier Dehaene committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
            let request_ids = batch.requests.iter().map(|req| req.id).collect();
            let mut cached_batch = match batch.size {
                size if size > 16 => {
                    wrap_future(client.generate_until_finished(batch), request_ids, &db).await
                }
                _ => wrap_future(client.generate(batch), request_ids, &db).await,
            };

            while let Some(batch) = cached_batch {
                let batch_size = batch.size;
                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                let mut batches = vec![batch];

                if batch_size <= 16 {
                    if let Some(new_batch) = db.next_batch_minimum_size(16, 48) {
                        let new_batch_request_ids =
                            new_batch.requests.iter().map(|req| req.id).collect();
                        let new_cached_batch =
                            wrap_future(client.generate(new_batch), new_batch_request_ids, &db)
                                .await;
                        if let Some(new_cached_batch) = new_cached_batch {
                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
                            batches.push(new_cached_batch);
                        }
                    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
82
                }
Olivier Dehaene's avatar
Olivier Dehaene committed
83
84
85

                cached_batch = match batch_size {
                    size if size > 16 => {
Olivier Dehaene's avatar
Olivier Dehaene committed
86
87
88
89
90
91
                        wrap_future(
                            client.generate_until_finished_with_cache(batches),
                            request_ids,
                            &db,
                        )
                        .await
Olivier Dehaene's avatar
Olivier Dehaene committed
92
93
94
                    }
                    _ => wrap_future(client.generate_with_cache(batches), request_ids, &db).await,
                };
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
95
96
97
98
99
            }
        }
    }
}

Olivier Dehaene's avatar
Olivier Dehaene committed
100
101
102
async fn wrap_future(
    future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
    request_ids: Vec<u64>,
Olivier Dehaene's avatar
Olivier Dehaene committed
103
    db: &Db,
Olivier Dehaene's avatar
Olivier Dehaene committed
104
105
106
107
108
) -> Option<Batch> {
    match future.await {
        Ok((generated_texts, next_batch)) => {
            send_generated(generated_texts, db);
            next_batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
109
110
111
        }
        Err(err) => {
            println!("{:?}", err);
Olivier Dehaene's avatar
Olivier Dehaene committed
112
            send_error(err, request_ids, db);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
113
114
115
116
117
118
119
120
121
122
123
124
            None
        }
    }
}

/// Fails every request in `request_ids` with a clone of `error`.
///
/// Each id is removed from `db` and the error is sent on the response channel
/// stored with it. A failed send is ignored.
///
/// # Panics
///
/// Panics if `db.remove` yields no entry for one of the ids.
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    for request_id in request_ids {
        let (_, reply_tx) = db.remove(&request_id).unwrap();
        // Nothing useful can be done if the send fails, so drop the result.
        let _ = reply_tx.send(Err(error.clone()));
    }
}

Olivier Dehaene's avatar
Olivier Dehaene committed
125
fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
126
    finished.into_iter().for_each(|output| {
Olivier Dehaene's avatar
Olivier Dehaene committed
127
        let (_, response_tx) = db.remove(&output.request.unwrap().id).unwrap();
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
128
129
130
        response_tx.send(Ok(output.output)).unwrap_or(());
    });
}