client.rs 4.37 KB
Newer Older
Olivier Dehaene's avatar
Olivier Dehaene committed
1
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
3
4
5
6
7
8
9
10
11
use crate::pb::generate::v1::*;
use crate::Result;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tower::timeout::Timeout;
use tracing::*;

/// BLOOM Inference gRPC client
#[derive(Clone)]
pub struct Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
12
    stub: TextGenerationServiceClient<Timeout<Channel>>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
13
14
15
16
17
18
19
20
21
22
23
24
}

impl Client {
    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
        let channel = Channel::builder(uri)
            .connect()
            .await
            .expect("Transport error");
        let timeout_channel = Timeout::new(channel, timeout);

        Self {
Olivier Dehaene's avatar
Olivier Dehaene committed
25
            stub: TextGenerationServiceClient::new(timeout_channel),
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
26
27
28
29
30
        }
    }

    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
Olivier Dehaene's avatar
Olivier Dehaene committed
31
        let channel = Channel::from_shared("http://[::]:50051".to_string())
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
32
33
34
35
36
37
38
39
40
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
            .await
            .expect("Transport error");
        let timeout_channel = Timeout::new(channel, timeout);

        Self {
Olivier Dehaene's avatar
Olivier Dehaene committed
41
            stub: TextGenerationServiceClient::new(timeout_channel),
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
42
43
44
45
46
        }
    }

    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
Olivier Dehaene's avatar
Olivier Dehaene committed
47
        let request = tonic::Request::new(ServiceDiscoveryRequest {});
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
        let response = self
            .stub
            .service_discovery(request)
            .instrument(info_span!("service_discovery"))
            .await?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self) -> Result<()> {
Olivier Dehaene's avatar
Olivier Dehaene committed
67
        let request = tonic::Request::new(ClearCacheRequest {});
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
68
69
70
71
72
73
74
75
        self.stub
            .clear_cache(request)
            .instrument(info_span!("clear_cache"))
            .await?;
        Ok(())
    }

    #[instrument(skip(self))]
Olivier Dehaene's avatar
Olivier Dehaene committed
76
77
    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
78
79
80
81
82
83
        let response = self
            .stub
            .generate(request)
            .instrument(info_span!("generate"))
            .await?
            .into_inner();
Olivier Dehaene's avatar
Olivier Dehaene committed
84
        Ok((response.generated_texts, response.batch))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
85
86
87
88
89
    }

    #[instrument(skip(self))]
    pub async fn generate_with_cache(
        &mut self,
Olivier Dehaene's avatar
Olivier Dehaene committed
90
91
92
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let request = tonic::Request::new(GenerateWithCacheRequest { batches });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
93
94
95
96
97
98
        let response = self
            .stub
            .generate_with_cache(request)
            .instrument(info_span!("generate_with_cache"))
            .await?
            .into_inner();
Olivier Dehaene's avatar
Olivier Dehaene committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
        Ok((response.generated_texts, response.batch))
    }

    #[instrument(skip(self))]
    pub async fn generate_until_finished(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let request = tonic::Request::new(GenerateUntilFinishedRequest { batch: Some(batch) });
        let response = self
            .stub
            .generate_until_finished(request)
            .instrument(info_span!("generate_until_finished"))
            .await?
            .into_inner();
        Ok((response.generated_texts, response.batch))
    }

    #[instrument(skip(self))]
    pub async fn generate_until_finished_with_cache(
        &mut self,
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
        let response = self
            .stub
            .generate_until_finished_with_cache(request)
            .instrument(info_span!("generate_until_finished_with_cache"))
            .await?
            .into_inner();
        Ok((response.generated_texts, response.batch))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
130
131
    }
}