/// Multi shard Client
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};

use crate::v3::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};

#[derive(Debug, Clone)]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
17
/// Text Generation Inference gRPC multi client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
18
pub struct ShardedClient {
19
    clients: Vec<Client>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
20
21
22
}

impl ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
23
    fn new(clients: Vec<Client>) -> Self {
24
        Self { clients }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
27
28
    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
Olivier Dehaene's avatar
Olivier Dehaene committed
29
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
        // Get all uris/unix sockets from the master client
31
        let uris = master_client.service_discovery().await?;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
32
        let futures = uris.into_iter().map(Client::connect_uds);
Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
35
36
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
    /// Returns a client connected to the given uri
Olivier Dehaene's avatar
Olivier Dehaene committed
38
39
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
40
41
42
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
43
44
45
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
46
47
48
        Self::from_master_client(master_client).await
    }

49
50
51
52
53
54
55
56
    /// Get the model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<ShardInfo> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
OlivierDehaene's avatar
OlivierDehaene committed
57
        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
58
59
    }

60
61
62
63
64
65
66
67
68
69
70
    /// GRPC health check
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.health())
            .collect();
        join_all(futures).await.pop().unwrap()
    }

71
    /// Clear the past generations cache
72
    #[instrument(skip(self))]
73
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
74
75
76
        let futures: Vec<_> = self
            .clients
            .iter_mut()
77
            .map(|client| client.clear_cache(batch_id))
78
79
80
81
            .collect();
        join_all(futures).await.into_iter().collect()
    }

82
83
84
85
86
    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
87
88
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
89
90
91
        let futures: Vec<_> = self
            .clients
            .iter_mut()
92
            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
93
94
95
96
97
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

98
99
100
101
102
103
104
105
    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip(self))]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
OlivierDehaene's avatar
OlivierDehaene committed
106
        max_total_tokens: u32,
107
        max_batch_size: Option<usize>,
108
    ) -> Result<Option<u32>> {
109
110
111
        let futures: Vec<_> = self
            .clients
            .iter_mut()
OlivierDehaene's avatar
OlivierDehaene committed
112
            .map(|client| {
113
114
115
116
117
118
                Box::pin(client.warmup(
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                    max_batch_size,
                ))
OlivierDehaene's avatar
OlivierDehaene committed
119
            })
120
            .collect();
121
122
123
124
125
126
        // Take the minimum value
        let results = join_all(futures)
            .await
            .into_iter()
            .collect::<Result<Vec<Option<u32>>>>()?;
        Ok(results.into_iter().flatten().min())
127
128
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
129
130
    /// Generate one token for each request in the given batch
    ///
131
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
132
    /// and the next cached batch
133
    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
134
135
136
    pub async fn prefill(
        &mut self,
        batch: Batch,
137
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
138
139
140
        let futures: Vec<_> = self
            .clients
            .iter_mut()
141
            .map(|client| Box::pin(client.prefill(batch.clone())))
142
            .collect();
143
        #[allow(clippy::type_complexity)]
144
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
145
            join_all(futures).await.into_iter().collect();
146
147
148
149
150
151
152
153
154
155
156
157
158
159
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
160
161
    }

162
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
163
    ///
164
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
165
    /// and the next cached batch
166
    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
167
    pub async fn decode(
168
        &mut self,
169
        batches: Vec<CachedBatch>,
170
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
171
172
173
        let futures: Vec<_> = self
            .clients
            .iter_mut()
174
            .map(|client| Box::pin(client.decode(batches.clone())))
175
            .collect();
176
        #[allow(clippy::type_complexity)]
177
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
178
            join_all(futures).await.into_iter().collect();
179
        let mut results = results?;
180

181
182
        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;
183

184
185
186
187
188
189
190
191
192
        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Olivier Dehaene committed
193
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
194
}
impl From<InfoResponse> for ShardInfo {
    fn from(value: InfoResponse) -> Self {
        Self {
            requires_padding: value.requires_padding,
            dtype: value.dtype,
            device_type: value.device_type,
            window_size: value.window_size,
            speculate: value.speculate,
        }
    }
}

#[async_trait]
impl Health for ShardedClient {
    /// Liveness probe: succeeds if every shard answers the gRPC health check.
    async fn device_health(&self) -> Result<()> {
        self.clone().health().await?;
        Ok(())
    }

    /// Readiness probe: runs a real prefill so the model path is exercised.
    async fn model_health(&self) -> Result<()> {
        // Dummy batch of 1 token and 1 generated token
        let probe = Request {
            id: u64::MAX,
            inputs: "liveness".to_string(),
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text("liveness".into()).into()],
            }),
            truncate: 10,
            add_special_tokens: true,
            prefill_logprobs: false,
            // Greedy sampling with every penalty disabled.
            parameters: Some(NextTokenChooserParameters {
                temperature: 1.0,
                top_k: 0,
                top_p: 1.0,
                typical_p: 1.0,
                do_sample: false,
                seed: 0,
                repetition_penalty: 1.0,
                frequency_penalty: 0.0,
                watermark: false,
                grammar: String::new(),
                grammar_type: GrammarType::None as i32,
            }),
            // Stop after a single generated token.
            stopping_parameters: Some(StoppingCriteriaParameters {
                max_new_tokens: 1,
                stop_sequences: vec![],
                ignore_eos_token: false,
            }),
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
            adapter_id: None,
        };
        let probe_batch = Batch {
            id: u64::MAX,
            requests: vec![probe],
            size: 1,
            max_tokens: 2,
            max_blocks: 1,
        };
        self.clone().prefill(probe_batch).await?;
        Ok(())
    }
}