/// Single shard Client
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
///
/// Thin, cloneable handle over the tonic-generated stub for a single
/// model shard; cloning shares the underlying `Channel`.
#[derive(Debug, Clone)]
pub struct Client {
    // Generated tonic service client; all RPCs below go through this stub.
    stub: TextGenerationServiceClient<Channel>,
}

impl Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
17
18
19
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
20

Olivier Dehaene's avatar
Olivier Dehaene committed
21
22
23
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
24
25
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
26
27
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
Olivier Dehaene's avatar
Olivier Dehaene committed
28
        let channel = Channel::from_shared("http://[::]:50051".to_string())
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
29
30
31
32
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
Olivier Dehaene's avatar
Olivier Dehaene committed
33
            .await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
34

Olivier Dehaene's avatar
Olivier Dehaene committed
35
36
37
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
38
39
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
    /// Returns a list of uris or unix sockets of all shards
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41
42
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
43
44
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
45
46
47
48
        let urls = response
            .into_inner()
            .urls
            .into_iter()
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
49
            // Remove unix socket prefix
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
50
51
52
53
54
55
56
57
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

58
59
60
61
62
63
64
65
    /// Get model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let request = tonic::Request::new(InfoRequest {}).inject_context();
        let response = self.stub.info(request).await?.into_inner();
        Ok(response)
    }

66
67
68
69
70
71
72
73
    /// Get model health
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let request = tonic::Request::new(HealthRequest {}).inject_context();
        let response = self.stub.health(request).await?.into_inner();
        Ok(response)
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
74
    /// Clear the past generations cache
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
75
    #[instrument(skip(self))]
76
77
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
78
        self.stub.clear_cache(request).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
79
80
81
        Ok(())
    }

82
83
84
85
86
    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
87
88
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
89
90
        let request = tonic::Request::new(FilterBatchRequest {
            batch_id,
91
            request_ids,
92
93
94
95
96
97
        })
        .inject_context();
        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
        Ok(filtered_batch.batch)
    }

98
99
100
    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
101
    #[instrument(skip_all)]
102
103
104
105
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
OlivierDehaene's avatar
OlivierDehaene committed
106
        max_total_tokens: u32,
107
    ) -> Result<Option<u32>> {
108
109
        let mut n_tokens = 0;
        let mut requests = Vec::new();
OlivierDehaene's avatar
OlivierDehaene committed
110
        let mut truncate = 0;
111
112
        // Create requests
        while n_tokens < max_prefill_tokens {
OlivierDehaene's avatar
OlivierDehaene committed
113
            truncate = min(max_input_length, max_prefill_tokens - n_tokens);
114
115
116
117
            requests.push(Request {
                id: 0,
                // We truncate the input on the server side to be sure that it has the correct size
                inputs: "_test ".to_string().repeat(max_input_length as usize),
OlivierDehaene's avatar
OlivierDehaene committed
118
                truncate: truncate,
119
120
121
122
123
124
125
126
127
128
129
130
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.2,
                    watermark: true,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
OlivierDehaene's avatar
OlivierDehaene committed
131
                    max_new_tokens: max_total_tokens - truncate,
132
                    stop_sequences: vec![],
OlivierDehaene's avatar
OlivierDehaene committed
133
                    ignore_eos_token: true,
134
135
                }),
                prefill_logprobs: true,
Nicolas Patry's avatar
Nicolas Patry committed
136
                top_n_tokens: 20,
137
138
139
140
141
142
143
144
145
146
147
            });
            n_tokens += max_input_length;
        }

        let batch = Batch {
            id: 0,
            size: requests.len() as u32,
            requests,
            max_tokens: 0,
        };

148
149
150
        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
        Ok(response.max_supported_total_tokens)
151
152
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
153
154
    /// Generate one token for each request in the given batch
    ///
155
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
156
    /// and the next cached batch
157
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
158
159
160
161
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
162
163
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
164
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
165
166
    }

167
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
168
    ///
169
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
170
    /// and the next cached batch
171
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
172
    pub async fn decode(
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
173
        &mut self,
174
175
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
176
177
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
178
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Olivier Dehaene committed
179
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
180
}