client.rs 4.27 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Single shard Client
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
3
4
use crate::pb::generate::v1::*;
use crate::Result;
5
use grpc_metadata::InjectTelemetryContext;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
6
use tonic::transport::{Channel, Uri};
7
use tracing::instrument;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
8

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
9
/// Text Generation Inference gRPC client
10
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
11
pub struct Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
12
    stub: TextGenerationServiceClient<Channel>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
13
14
15
}

impl Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
16
17
18
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
19

Olivier Dehaene's avatar
Olivier Dehaene committed
20
21
22
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
23
24
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
25
26
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
Olivier Dehaene's avatar
Olivier Dehaene committed
27
        let channel = Channel::from_shared("http://[::]:50051".to_string())
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
28
29
30
31
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
Olivier Dehaene's avatar
Olivier Dehaene committed
32
            .await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
33

Olivier Dehaene's avatar
Olivier Dehaene committed
34
35
36
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
37
38
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
39
    /// Returns a list of uris or unix sockets of all shards
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
40
41
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
42
43
        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
        let response = self.stub.service_discovery(request).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
44
45
46
47
        let urls = response
            .into_inner()
            .urls
            .into_iter()
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
48
            // Remove unix socket prefix
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
49
50
51
52
53
54
55
56
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

57
58
59
60
61
62
63
64
    /// Fetches the model information reported by the shard.
    ///
    /// # Errors
    ///
    /// Propagates the error if the `info` RPC fails.
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let req = tonic::Request::new(InfoRequest {}).inject_context();
        Ok(self.stub.info(req).await?.into_inner())
    }

65
66
67
68
69
70
71
72
    /// Fetches the health status reported by the shard.
    ///
    /// # Errors
    ///
    /// Propagates the error if the `health` RPC fails.
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let req = tonic::Request::new(HealthRequest {}).inject_context();
        Ok(self.stub.health(req).await?.into_inner())
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
73
    /// Clear the past generations cache
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
74
    #[instrument(skip(self))]
75
76
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
77
        self.stub.clear_cache(request).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
78
79
80
        Ok(())
    }

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    /// Filters the cached batch `batch_id` down to `keep_requests`.
    ///
    /// Returns the `batch` field of the shard's response, which may be `None`.
    ///
    /// # Errors
    ///
    /// Propagates the error if the `filter_batch` RPC fails.
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        keep_requests: Vec<Request>,
    ) -> Result<Option<Batch>> {
        let payload = FilterBatchRequest {
            batch_id,
            keep_requests,
        };
        let req = tonic::Request::new(payload).inject_context();
        let reply = self.stub.filter_batch(req).await?.into_inner();
        Ok(reply.batch)
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
97
98
    /// Generate one token for each request in the given batch
    ///
99
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
100
    /// and the next cached batch
101
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
102
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
103
104
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
        let response = self.stub.prefill(request).await?.into_inner();
105
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
106
107
    }

108
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
109
    ///
110
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
111
    /// and the next cached batch
112
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
113
    pub async fn decode(
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
114
        &mut self,
Olivier Dehaene's avatar
Olivier Dehaene committed
115
        batches: Vec<Batch>,
116
    ) -> Result<(Vec<Generation>, Option<Batch>)> {
117
118
        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
        let response = self.stub.decode(request).await?.into_inner();
119
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Olivier Dehaene committed
120
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
121
}