client.rs 6.22 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Single shard Client
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
3
4
use crate::pb::generate::v1::*;
use crate::Result;
5
use grpc_metadata::InjectTelemetryContext;
6
use std::cmp::min;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
7
use tonic::transport::{Channel, Uri};
8
use tracing::instrument;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
9

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
10
/// Text Generation Inference gRPC client
11
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
12
pub struct Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
13
    stub: TextGenerationServiceClient<Channel>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
14
15
16
}

impl Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
17
18
19
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
20

Olivier Dehaene's avatar
Olivier Dehaene committed
21
22
23
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
24
25
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
26
27
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
Olivier Dehaene's avatar
Olivier Dehaene committed
28
        let channel = Channel::from_shared("http://[::]:50051".to_string())
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
29
30
31
32
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
Olivier Dehaene's avatar
Olivier Dehaene committed
33
            .await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
34

Olivier Dehaene's avatar
Olivier Dehaene committed
35
36
37
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
38
39
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
    /// Returns a list of uris or unix sockets of all shards
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
41
42
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
43
44
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
45
46
47
48
        let urls = response
            .into_inner()
            .urls
            .into_iter()
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
49
            // Remove unix socket prefix
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
50
51
52
53
54
55
56
57
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

58
59
60
61
62
63
64
65
    /// Sends an `InfoRequest` to the shard and returns its `InfoResponse`
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let info_request = tonic::Request::new(InfoRequest {}).inject_context();
        Ok(self.stub.info(info_request).await?.into_inner())
    }

66
67
68
69
70
71
72
73
    /// Sends a `HealthRequest` to the shard and returns its `HealthResponse`
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let health_request = tonic::Request::new(HealthRequest {}).inject_context();
        Ok(self.stub.health(health_request).await?.into_inner())
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
74
    /// Clear the past generations cache
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
75
    #[instrument(skip(self))]
76
77
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
78
        self.stub.clear_cache(request).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
79
80
81
        Ok(())
    }

82
83
84
85
86
    /// Filters the cached batch `batch_id` with the given `request_ids`
    ///
    /// Returns the `batch` field of the shard's response.
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let payload = FilterBatchRequest {
            batch_id,
            request_ids,
        };
        let request = tonic::Request::new(payload).inject_context();

        let reply = self.stub.filter_batch(request).await?.into_inner();
        Ok(reply.batch)
    }

98
99
100
    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
101
    #[instrument(skip_all)]
102
103
104
105
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
106
    ) -> Result<Option<u32>> {
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
        let mut n_tokens = 0;
        let mut requests = Vec::new();

        // Create requests
        while n_tokens < max_prefill_tokens {
            requests.push(Request {
                id: 0,
                // We truncate the input on the server side to be sure that it has the correct size
                inputs: "_test ".to_string().repeat(max_input_length as usize),
                truncate: min(max_input_length, max_prefill_tokens - n_tokens),
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
                    top_k: 10,
                    top_p: 0.9,
                    typical_p: 0.9,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 1.2,
                    watermark: true,
                }),
                stopping_parameters: Some(StoppingCriteriaParameters {
                    max_new_tokens: 2,
                    stop_sequences: vec![],
                    ignore_eos_token: false,
                }),
                prefill_logprobs: true,
Nicolas Patry's avatar
Nicolas Patry committed
134
                top_n_tokens: 20,
135
136
137
138
139
140
141
142
143
144
145
            });
            n_tokens += max_input_length;
        }

        let batch = Batch {
            id: 0,
            size: requests.len() as u32,
            requests,
            max_tokens: 0,
        };

146
147
148
        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.warmup(request).await?.into_inner();
        Ok(response.max_supported_total_tokens)
149
150
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
151
152
    /// Generate one token for each request in the given batch
    ///
153
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
154
    /// and the next cached batch
155
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
156
157
158
159
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
160
161
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
162
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
163
164
    }

165
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
166
    ///
167
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
168
    /// and the next cached batch
169
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
170
    pub async fn decode(
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
171
        &mut self,
172
173
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
174
175
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
176
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Olivier Dehaene committed
177
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
178
}