grpc_client.rs 10.5 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Single shard Client
Nicolas Patry's avatar
Nicolas Patry committed
2
3
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
4
5
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
6
use grpc_metadata::InjectTelemetryContext;
OlivierDehaene's avatar
OlivierDehaene committed
7
8
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
9
use std::cmp::min;
10
use std::time::Duration;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
11
use tonic::transport::{Channel, Uri};
12
use tracing::instrument;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
13

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
14
/// Text Generation Inference gRPC client
///
/// Thin wrapper around the generated v3 `TextGenerationServiceClient`,
/// talking to a single shard over one tonic `Channel`.
#[derive(Debug, Clone)]
pub struct Client {
    // Generated tonic stub over the underlying `Channel`.
    stub: TextGenerationServiceClient<Channel>,
}

impl Client {
    /// Returns a client connected to the given url
    ///
    /// # Errors
    /// Propagates any transport error raised while establishing the channel.
    #[allow(dead_code)]
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a client connected to the given unix socket
    ///
    /// The `http://[::]:50051` endpoint is a placeholder: tonic requires a
    /// well-formed URI, but the custom connector below ignores it and dials
    /// the unix socket at `path` instead.
    pub async fn connect_uds(path: String) -> Result<Self> {
        // `from_shared` only fails on a malformed URI; this literal is valid,
        // so the unwrap cannot fire.
        let channel = Channel::from_shared("http://[::]:50051".to_string())
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await?;

        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
    }

    /// Returns a list of uris or unix sockets of all shards
    ///
    /// Any gRPC failure is mapped to a `ClientError::Connection`, since a
    /// missing `ServiceDiscovery` endpoint indicates a pre-v3 server.
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await.map_err(|_| {
            ClientError::Connection("Server does not support v3 interface".to_string())
        })?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            // Remove unix socket prefix so callers get a bare filesystem path
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

    /// Get model info
    ///
    /// Returns the shard's `InfoResponse` (model metadata reported by the server).
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let request = tonic::Request::new(InfoRequest {}).inject_context();
        let response = self.stub.info(request).await?.into_inner();
        Ok(response)
    }

    /// Get model health
    ///
    /// A successful round-trip means the shard is reachable and serving.
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

    /// Clear the past generations cache
    ///
    /// `batch_id`: clear only that batch's cache entry; `None` clears everything.
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
        self.stub.clear_cache(request).await?;
        Ok(())
    }

    /// Filter a cached batch
    ///
    /// Keeps only `request_ids` inside cached batch `batch_id`; returns the
    /// filtered batch, or `None` if the server dropped it entirely.
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let request = tonic::Request::new(FilterBatchRequest {
            batch_id,
            request_ids,
        })
        .inject_context();
        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
        Ok(filtered_batch.batch)
    }

    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    ///
    /// Builds synthetic requests until `max_prefill_tokens` worth of input is
    /// queued (or `max_batch_size` requests exist), then sends one `Warmup` RPC
    /// so the server can size its memory pools.
    #[instrument(skip_all)]
    pub async fn warmup(
        &mut self,
        max_input_tokens: Option<u32>,
        max_prefill_tokens: u32,
        max_total_tokens: Option<u32>,
        max_batch_size: Option<usize>,
    ) -> Result<(Option<u32>, u32, u32)> {
        let mut n_tokens = 0;
        let mut requests = Vec::new();
        // Create requests
        while n_tokens < max_prefill_tokens {
            // Each request carries the remaining budget, capped at
            // max_input_tokens when one is configured.
            let mut truncate = max_prefill_tokens - n_tokens;
            if let Some(max_input_tokens) = max_input_tokens {
                truncate = min(max_input_tokens, truncate);
            }

            let mut input_chunks = Vec::new();
            input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
            if n_tokens == 0 {
                input_chunks.push(
                    Chunk::Image(Image {
                        // Safe unwrap, because we control the data.
                        data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
                        // NOTE(review): ";base64" inside the mimetype is unusual —
                        // presumably what the server expects; confirm server-side parsing.
                        mimetype: "image/jpeg;base64".to_string(),
                    })
                    .into(),
                );
            }

            // Send stringly-typed inputs for compatibility for backends that haven't
            // been updated to support chunks.
            let mut inputs = String::new();
            inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
            if n_tokens == 0 {
                // 1 request is enough to test vision heads.
                // Sending images on other queries messes up easily with truncation.
                inputs.push_str(&format!(
                    "![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
                ));
            }

            // NOTE(review): this u32 subtraction underflows (panics in debug
            // builds) if max_total_tokens < truncate — callers presumably
            // guarantee max_total_tokens >= max_input_tokens; confirm.
            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
                max_total_tokens - truncate
            } else {
                1
            };

            requests.push(Request {
                id: 0,
                inputs,
                add_special_tokens: true,
                input_chunks: Some(Input {
                    chunks: input_chunks,
                }),
                // We truncate the input on the server side to be sure that it has the correct size
                truncate,
                // Blocks and slots will be set on the server side if we use paged attention
                blocks: vec![],
                slots: vec![],
                cache_len: 0,
                chunk_len: None,
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.2,
                    frequency_penalty: 0.1,
                    watermark: true,
                    grammar: String::new(),
                    grammar_type: GrammarType::None as i32,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens,
                    stop_sequences: vec![],
                    ignore_eos_token: true,
                }),
                prefill_logprobs: true,
                top_n_tokens: 20,
                adapter_id: None,
            });
            n_tokens += truncate;

            // Check max_batch_size
            if Some(requests.len()) == max_batch_size {
                break;
            }
        }

        let batch = Batch {
            id: 0,
            size: requests.len() as u32,
            requests,
            // NOTE(review): 0 is used when no max_input_tokens is configured —
            // confirm the server treats 0 as "unknown" here.
            max_tokens: max_input_tokens.unwrap_or(0),
            max_blocks: 0,
        };

        let request = tonic::Request::new(WarmupRequest {
            batch: Some(batch),
            max_input_tokens,
            max_prefill_tokens,
            max_total_tokens,
        })
        .inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
        // (max_supported_total_tokens, max_input_tokens, max_total_tokens)
        // as resolved by the server after warmup.
        Ok((
            response.max_supported_total_tokens,
            response.max_input_tokens,
            response.max_total_tokens,
        ))
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns Generation for each request in batch
    /// and the next cached batch
    ///
    /// Also returns the server-reported `PrefillTimings` for this call.
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
        cached_batch: Option<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
        let request = tonic::Request::new(PrefillRequest {
            batch: Some(batch),
            cached_batch,
        })
        .inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            PrefillTimings::new(
                response.concat_ns,
                response.forward_ns,
                response.decode_ns,
                response.total_ns,
            ),
        ))
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns Generation for each request in batches
    /// and the next cached batch
    ///
    /// Also returns the server-reported `DecodeTimings` for this call.
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
        Ok((
            response.generations,
            response.batch,
            DecodeTimings::new(
                response.concat_ns,
                response.forward_ns,
                response.decode_ns,
                response.total_ns,
            ),
        ))
    }
}

/// Server-side timings of a prefill call, converted to `Duration`s.
pub struct PrefillTimings {
    /// Converted from `concat_ns`; absent when the server reported no value.
    pub concat: Option<Duration>,
    /// Converted from `forward_ns`.
    pub forward: Duration,
    /// Converted from `decode_ns`.
    pub decode: Duration,
    /// Converted from `total_ns`.
    pub total: Duration,
}

impl PrefillTimings {
    /// Build timings from the raw nanosecond counters in a prefill response.
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        // Single conversion function, applied uniformly to every counter.
        let ns = Duration::from_nanos;
        Self {
            concat: concat_ns.map(ns),
            forward: ns(forward_ns),
            decode: ns(decode_ns),
            total: ns(total_ns),
        }
    }
}

/// Server-side timings of a decode call, converted to `Duration`s.
pub struct DecodeTimings {
    /// Converted from `concat_ns`; absent when the server reported no value.
    pub concat: Option<Duration>,
    /// Converted from `forward_ns`.
    pub forward: Duration,
    /// Converted from `decode_ns`.
    pub decode: Duration,
    /// Converted from `total_ns`.
    pub total: Duration,
}

impl DecodeTimings {
    /// Build timings from the raw nanosecond counters in a decode response.
    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
        // Single conversion function, applied uniformly to every counter.
        let ns = Duration::from_nanos;
        Self {
            concat: concat_ns.map(ns),
            forward: ns(forward_ns),
            decode: ns(decode_ns),
            total: ns(total_ns),
        }
    }
}