// batcher.rs
use crate::Db;
use bloom_inference_client::{
    Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient,
};
use std::sync::Arc;
use crate::server::GenerateRequest;
use tokio::sync::{Notify, oneshot};

/// Maximum number of requests allowed to wait in the queue before new
/// requests are rejected
const MAX_LENGTH: usize = 128;

/// Error returned to the client when the queue is full or the model shards
/// fail during generation
pub struct InferError {}

/// Client-facing handle that enqueues requests and signals the background
/// batching task
#[derive(Clone)]
pub(crate) struct Batcher {
    db: Db,
    shared: Arc<Shared>,
}

/// State shared between `Batcher` handles and the background batching task
struct Shared {
    /// Wakes the batching task when new requests enter the database
    batching_task: Notify,
}

impl Batcher {
    /// Creates the shared request database and spawns the background batching
    /// task
    pub(crate) fn new(client: ShardedClient) -> Self {
        let db = Db::new();
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        tokio::spawn(batching_task(client, db.clone(), shared.clone()));

        Self { db, shared }
    }

    pub(crate) async fn infer(&self, request: GenerateRequest) -> Result<String, InferError> {
        // Reject the request if the queue already holds more than MAX_LENGTH entries
        if self.db.len() > MAX_LENGTH {
            return Err(InferError {});
        }
        // Register the request in the shared database and wake the batching
        // task; the oneshot channel carries the result back
        let (request_tx, request_rx) = oneshot::channel();
        self.db.append(request, request_tx);
        self.shared.batching_task.notify_waiters();
        // The batching task always answers; a dropped sender would mean it
        // panicked, so unwrapping here is intentional
        match request_rx.await.unwrap() {
            Ok(output) => Ok(output),
            Err(_) => Err(InferError {}),
        }
    }
}
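
// A minimal caller-side sketch, assuming an already-connected `ShardedClient`
// and a `GenerateRequest` built by the HTTP layer (`sharded_client` and
// `generate_request` below are illustrative names, provided elsewhere in the
// crate):
//
//     let batcher = Batcher::new(sharded_client);
//     match batcher.infer(generate_request).await {
//         Ok(output) => println!("{}", output),
//         Err(_) => eprintln!("queue full or generation failed"),
//     }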

/// Background task that assembles batches from the database, runs them on the
/// model shards, and keeps decoding on the cached state until every request
/// in the batch has finished
async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
    loop {
        // Wait for infer() to signal that new requests are queued
        shared.batching_task.notified().await;

        // Drain up to 32 requests into a first batch
        if let Some(batch) = db.next_batch(32) {
            let mut cache_entry = infer_batch(batch, &client, &db).await;

            // As long as the shards return a cache entry, at least one
            // request is still generating: run another decoding step
            loop {
                if let Some(entry) = cache_entry {
                    let mut batch_cached_ids = vec![entry.id];
                    let mut total_batch_size = entry.request_ids.len();
                    let mut max_sequence_length = entry.sequence_length;
                    let mut request_ids = entry.request_ids;

                    // Disabled: opportunistically concatenate a second batch
                    // when the running batch gets small
                    // if total_batch_size <= 16 {
                    //     if let Some(batch) = db.next_batch_minimum_size(16, 48) {
                    //         let other_cache_entry = infer_batch(batch, &client, &db).await;
                    //
                    //         if let Some(entry) = other_cache_entry {
                    //             batch_cached_ids.push(entry.id);
                    //             total_batch_size += entry.request_ids.len();
                    //             max_sequence_length =
                    //                 max_sequence_length.max(entry.sequence_length);
                    //             request_ids.extend(entry.request_ids.into_iter());
                    //         }
                    //     }
                    // }

                    let batch_cached = BatchCached {
                        id: entry.id,
                        batch_cached_ids,
                        total_batch_size: total_batch_size as u32,
                        max_sequence_length,
                        request_ids,
                    };
                    // Run the next decoding step on everything still cached
                    cache_entry = infer_batch_cached(batch_cached, &client, &db).await;
                } else {
                    break;
                }
            }
        }
    }
}
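
// Note: `Notify::notify_waiters` does not store a wakeup, so requests that
// arrive while a batch is decoding are only picked up once a later `infer()`
// call fires the notification again.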

/// Runs one decoding step on a cached batch, dispatching finished generations
/// and returning the new cache entry, if any
async fn infer_batch_cached(
    batch: BatchCached,
    client: &ShardedClient,
    db: &Db,
) -> Option<CacheEntry> {
    match client.generate_with_cache(batch.clone()).await {
        Ok((finished, cache_entry)) => {
            // Send finished generations back to their waiting clients
            send_finished(finished, db);
            cache_entry
        }
        Err(err) => {
            // On shard error, fail every request that was part of this batch
            eprintln!("{:?}", err);
            send_error(err, batch.request_ids, db);
            None
        }
    }
}

/// Runs the first forward pass on a freshly assembled batch
async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
    match client.generate(batch.clone()).await {
        Ok((finished, cache_entry)) => {
            send_finished(finished, db);
            cache_entry
        }
        Err(err) => {
            eprintln!("{:?}", err);
            send_error(
                err,
                batch.requests.into_iter().map(|req| req.id).collect(),
                db,
            );
            None
        }
    }
}

/// Removes the failed requests from the database and forwards the error to
/// every waiting client; a dropped receiver is silently ignored
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
    request_ids.into_iter().for_each(|id| {
        let (_, response_tx) = db.remove(&id).unwrap();
        response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

/// Removes the finished requests from the database and sends each generated
/// output to its waiting client; a dropped receiver is silently ignored
fn send_finished(finished: Vec<FinishedGeneration>, db: &Db) {
    finished.into_iter().for_each(|output| {
        let (_, response_tx) = db.remove(&output.id).unwrap();
        response_tx.send(Ok(output.output)).unwrap_or(());
    });
}
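
// For reference, the shape of the `Db` API this module relies on, inferred
// from the calls above (the real definitions live in crate::Db; parameter
// names are guesses):
//
//     fn len(&self) -> usize;
//     fn append(&self, request: GenerateRequest,
//               sender: oneshot::Sender<Result<String, ClientError>>);
//     fn next_batch(&self, max_size: usize) -> Option<Batch>;
//     fn remove(&self, id: &u64)
//         -> Option<(GenerateRequest, oneshot::Sender<Result<String, ClientError>>)>;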