//! Multi shard Client
use crate::v3::{Chunk, InfoResponse, Input};
use crate::{v3, Health, ShardInfo};
use crate::{ClientError, Result};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
use v3::client::{DecodeTimings, PrefillTimings};
use v3::{
    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
    NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};

#[derive(Debug, Clone)]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
17
/// Text Generation Inference gRPC multi client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
18
pub struct ShardedClient {
19
    clients: Vec<Client>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
20
21
22
}

impl ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
23
    fn new(clients: Vec<Client>) -> Self {
24
        Self { clients }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
27
28
    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
Olivier Dehaene's avatar
Olivier Dehaene committed
29
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
        // Get all uris/unix sockets from the master client
31
        let uris = master_client.service_discovery().await?;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
32
        let futures = uris.into_iter().map(Client::connect_uds);
Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
35
36
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
37
    /// Returns a client connected to the given uri
Olivier Dehaene's avatar
Olivier Dehaene committed
38
39
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
40
41
42
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
43
44
45
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
46
47
48
        Self::from_master_client(master_client).await
    }

49
50
51
52
53
54
55
56
    /// Get the model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<ShardInfo> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
OlivierDehaene's avatar
OlivierDehaene committed
57
        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
58
59
    }

60
61
62
63
64
65
66
67
68
69
70
    /// GRPC health check
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.health())
            .collect();
        join_all(futures).await.pop().unwrap()
    }

71
    /// Clear the past generations cache
72
    #[instrument(skip(self))]
73
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
74
75
76
        let futures: Vec<_> = self
            .clients
            .iter_mut()
77
            .map(|client| client.clear_cache(batch_id))
78
79
80
81
            .collect();
        join_all(futures).await.into_iter().collect()
    }

82
83
84
85
86
    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
87
88
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
89
90
91
        let futures: Vec<_> = self
            .clients
            .iter_mut()
92
            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
93
94
95
96
97
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

98
99
100
101
102
103
    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip(self))]
    pub async fn warmup(
        &mut self,
104
        max_input_length: Option<u32>,
105
        max_prefill_tokens: u32,
106
        max_total_tokens: Option<u32>,
107
        max_batch_size: Option<usize>,
108
    ) -> Result<(Option<u32>, u32, u32)> {
109
110
111
        let futures: Vec<_> = self
            .clients
            .iter_mut()
OlivierDehaene's avatar
OlivierDehaene committed
112
            .map(|client| {
113
114
115
116
117
118
                Box::pin(client.warmup(
                    max_input_length,
                    max_prefill_tokens,
                    max_total_tokens,
                    max_batch_size,
                ))
OlivierDehaene's avatar
OlivierDehaene committed
119
            })
120
            .collect();
121
122
123
124
        // Take the minimum value
        let results = join_all(futures)
            .await
            .into_iter()
125
126
127
128
129
130
131
132
133
134
            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;

        // Take the minimum value
        // Different shards hold different parts of vocab, might yield
        // different available block size.
        let min = results
            .iter()
            .min()
            .expect("Expect at least 1 warmup result");
        Ok(*min)
135
136
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
137
138
    /// Generate one token for each request in the given batch
    ///
139
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
140
    /// and the next cached batch
141
    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
142
143
144
    pub async fn prefill(
        &mut self,
        batch: Batch,
145
        cached_batch: Option<CachedBatch>,
146
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
147
148
149
        let futures: Vec<_> = self
            .clients
            .iter_mut()
150
            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
151
            .collect();
152
        #[allow(clippy::type_complexity)]
153
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
154
            join_all(futures).await.into_iter().collect();
155
156
157
158
159
160
161
162
163
164
165
166
167
168
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
169
170
    }

171
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
172
    ///
173
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
174
    /// and the next cached batch
175
    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
176
    pub async fn decode(
177
        &mut self,
178
        batches: Vec<CachedBatch>,
179
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
180
181
182
        let futures: Vec<_> = self
            .clients
            .iter_mut()
183
            .map(|client| Box::pin(client.decode(batches.clone())))
184
            .collect();
185
        #[allow(clippy::type_complexity)]
186
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
187
            join_all(futures).await.into_iter().collect();
188
        let mut results = results?;
189

190
191
        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;
192

193
194
195
196
197
198
199
200
201
        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Olivier Dehaene committed
202
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
203
}
impl From<InfoResponse> for ShardInfo {
    fn from(value: InfoResponse) -> Self {
        Self {
            requires_padding: value.requires_padding,
            dtype: value.dtype,
            device_type: value.device_type,
            window_size: value.window_size,
            speculate: value.speculate,
        }
    }
}

#[async_trait]
impl Health for ShardedClient {
    /// The device is considered healthy as soon as the gRPC health
    /// endpoint of every shard answers successfully.
    async fn device_health(&self) -> Result<()> {
        self.clone().health().await?;
        Ok(())
    }

    /// Run a tiny end-to-end prefill to prove the model can still generate.
    async fn model_health(&self) -> Result<()> {
        // Dummy batch of 1 token and 1 generated token
        let parameters = NextTokenChooserParameters {
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
            typical_p: 1.0,
            do_sample: false,
            seed: 0,
            repetition_penalty: 1.0,
            frequency_penalty: 0.0,
            watermark: false,
            grammar: String::new(),
            grammar_type: GrammarType::None as i32,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens: 1,
            stop_sequences: vec![],
            ignore_eos_token: false,
        };
        let liveness_request = Request {
            id: u64::MAX,
            inputs: "liveness".to_string(),
            input_chunks: Some(Input {
                chunks: vec![Chunk::Text("liveness".into()).into()],
            }),
            truncate: 10,
            add_special_tokens: true,
            prefill_logprobs: false,
            parameters: Some(parameters),
            stopping_parameters: Some(stopping_parameters),
            top_n_tokens: 0,
            // Block 0 is reserved for health checks
            blocks: vec![0],
            slots: (0..16).collect(),
            cache_len: 0,
            chunk_len: None,
            adapter_id: None,
        };
        let batch = Batch {
            id: u64::MAX,
            requests: vec![liveness_request],
            size: 1,
            max_tokens: 2,
            max_blocks: 1,
        };
        self.clone().prefill(batch, None).await?;
        Ok(())
    }
}