client.rs 3.98 KB
Newer Older
Olivier Dehaene's avatar
Olivier Dehaene committed
1
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
3
4
5
6
7
8
9
use crate::pb::generate::v1::*;
use crate::Result;
use tonic::transport::{Channel, Uri};
use tracing::*;

/// BLOOM Inference gRPC client
#[derive(Clone)]
pub struct Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
10
    stub: TextGenerationServiceClient<Channel>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
11
12
13
}

impl Client {
Olivier Dehaene's avatar
Olivier Dehaene committed
14
15
16
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let channel = Channel::builder(uri).connect().await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
17

Olivier Dehaene's avatar
Olivier Dehaene committed
18
19
20
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
21
22
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
23
24
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
Olivier Dehaene's avatar
Olivier Dehaene committed
25
        let channel = Channel::from_shared("http://[::]:50051".to_string())
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
26
27
28
29
            .unwrap()
            .connect_with_connector(tower::service_fn(move |_: Uri| {
                tokio::net::UnixStream::connect(path.clone())
            }))
Olivier Dehaene's avatar
Olivier Dehaene committed
30
            .await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
31

Olivier Dehaene's avatar
Olivier Dehaene committed
32
33
34
        Ok(Self {
            stub: TextGenerationServiceClient::new(channel),
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
35
36
37
38
    }

    #[instrument(skip(self))]
    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
Olivier Dehaene's avatar
Olivier Dehaene committed
39
        let request = tonic::Request::new(ServiceDiscoveryRequest {});
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
        let response = self
            .stub
            .service_discovery(request)
            .instrument(info_span!("service_discovery"))
            .await?;
        let urls = response
            .into_inner()
            .urls
            .into_iter()
            .map(|url| match url.strip_prefix("unix://") {
                None => url,
                Some(stripped_url) => stripped_url.to_string(),
            })
            .collect();
        Ok(urls)
    }

    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self) -> Result<()> {
Olivier Dehaene's avatar
Olivier Dehaene committed
59
        let request = tonic::Request::new(ClearCacheRequest {});
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
60
61
62
63
64
65
66
67
        self.stub
            .clear_cache(request)
            .instrument(info_span!("clear_cache"))
            .await?;
        Ok(())
    }

    #[instrument(skip(self))]
Olivier Dehaene's avatar
Olivier Dehaene committed
68
69
    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let request = tonic::Request::new(GenerateRequest { batch: Some(batch) });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
70
71
72
73
74
75
        let response = self
            .stub
            .generate(request)
            .instrument(info_span!("generate"))
            .await?
            .into_inner();
Olivier Dehaene's avatar
Olivier Dehaene committed
76
        Ok((response.generated_texts, response.batch))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
77
78
79
80
81
    }

    #[instrument(skip(self))]
    pub async fn generate_with_cache(
        &mut self,
Olivier Dehaene's avatar
Olivier Dehaene committed
82
83
84
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let request = tonic::Request::new(GenerateWithCacheRequest { batches });
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
85
86
87
88
89
90
        let response = self
            .stub
            .generate_with_cache(request)
            .instrument(info_span!("generate_with_cache"))
            .await?
            .into_inner();
Olivier Dehaene's avatar
Olivier Dehaene committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
        Ok((response.generated_texts, response.batch))
    }

    /// Sends `batch` to the server in a `GenerateUntilFinishedRequest`
    ///
    /// Returns the `generated_texts` of the response together with its optional
    /// `batch`.
    ///
    /// # Errors
    ///
    /// Returns an error if the `generate_until_finished` RPC fails.
    #[instrument(skip(self))]
    pub async fn generate_until_finished(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let payload = GenerateUntilFinishedRequest { batch: Some(batch) };
        let rpc = self
            .stub
            .generate_until_finished(tonic::Request::new(payload));
        // Await the call inside its own span, then unwrap the tonic envelope.
        let reply = rpc
            .instrument(info_span!("generate_until_finished"))
            .await?
            .into_inner();
        Ok((reply.generated_texts, reply.batch))
    }

    /// Sends `batches` to the server in a `GenerateUntilFinishedWithCacheRequest`
    ///
    /// Returns the `generated_texts` of the response together with its optional
    /// `batch`.
    ///
    /// # Errors
    ///
    /// Returns an error if the `generate_until_finished_with_cache` RPC fails.
    #[instrument(skip(self))]
    pub async fn generate_until_finished_with_cache(
        &mut self,
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let request = tonic::Request::new(GenerateUntilFinishedWithCacheRequest { batches });
        let response = self
            .stub
            .generate_until_finished_with_cache(request)
            .instrument(info_span!("generate_until_finished_with_cache"))
            .await?
            .into_inner();
        Ok((response.generated_texts, response.batch))
    }
}