client.rs 3.34 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Single shard Client
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
3
4
5
6
7
use crate::pb::generate::v1::*;
use crate::Result;
use tonic::transport::{Channel, Uri};
use tracing::*;

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
8
/// Text Generation Inference gRPC client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
9
10
#[derive(Clone)]
pub struct Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
11
    stub: TextGenerationServiceClient<Channel>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
12
13
14
}

impl Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
15
16
17
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
18

Olivier Dehaene's avatar
Olivier Dehaene committed
19
20
21
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
22
23
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
24
25
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
Olivier Dehaene's avatar
Olivier Dehaene committed
26
        let channel = Channel::from_shared("http://[::]:50051".to_string())
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
27
28
29
30
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
Olivier Dehaene's avatar
Olivier Dehaene committed
31
            .await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
32

Olivier Dehaene's avatar
Olivier Dehaene committed
33
34
35
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
36
37
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
38
    /// Returns a list of uris or unix sockets of all shards
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
39
40
    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
Olivier Dehaene's avatar
Olivier Dehaene committed
41
        let request = tonic::Request::new(ServiceDiscoveryRequest {});
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
42
43
44
45
46
47
48
49
50
        let response = self
            .stub
            .service_discovery(request)
            .instrument(info_span!("service_discovery"))
            .await?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
51
            // Remove unix socket prefix
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
52
53
54
55
56
57
58
59
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
60
    /// Clear the past generations cache
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
61
62
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self) -> Result<()> {
Olivier Dehaene's avatar
Olivier Dehaene committed
63
        let request = tonic::Request::new(ClearCacheRequest {});
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
64
65
66
67
68
69
70
        self.stub
            .clear_cache(request)
            .instrument(info_span!("clear_cache"))
            .await?;
        Ok(())
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
71
72
    /// Generate one token for each request in the given batch
    ///
73
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
74
    /// and the next cached batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
75
    #[instrument(skip(self))]
76
77
    pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
78
79
        let response = self
            .stub
80
81
            .prefill(request)
            .instrument(info_span!("prefill"))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
82
83
            .await?
            .into_inner();
84
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
85
86
    }

87
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
88
    ///
89
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
90
    /// and the next cached batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
91
    #[instrument(skip(self))]
92
    pub async fn decode(
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
93
        &mut self,
Olivier Dehaene's avatar
Olivier Dehaene committed
94
        batches: Vec<Batch>,
95
96
    ) -> Result<(Vec<Generation>, Option<Batch>)> {
        let request = tonic::Request::new(DecodeRequest { batches });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
97
98
        let response = self
            .stub
99
100
            .decode(request)
            .instrument(info_span!("decode"))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
101
102
            .await?
            .into_inner();
103
        Ok((response.generations, response.batch))
Olivier Dehaene's avatar
Olivier Dehaene committed
104
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
105
}