sharded_client.rs 3.02 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
/// Multi shard Client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
use crate::Result;
Olivier Dehaene's avatar
Olivier Dehaene committed
3
use crate::{Batch, Client, GeneratedText};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
use futures::future::join_all;
5
use futures::future::select_all;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
6
7
use tonic::transport::Uri;

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
8
/// Text Generation Inference gRPC multi client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
9
pub struct ShardedClient {
10
    clients: Vec<Client>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
11
12
13
}

impl ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
14
    fn new(clients: Vec<Client>) -> Self {
15
        Self { clients }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
16
17
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
18
19
    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
Olivier Dehaene's avatar
Olivier Dehaene committed
20
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
21
        // Get all uris/unix sockets from the master client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
22
        let uris = master_client.service_discovery().await.unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
23
        let futures = uris.into_iter().map(Client::connect_uds);
Olivier Dehaene's avatar
Olivier Dehaene committed
24
25
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
26
27
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
28
    /// Returns a client connected to the given uri
Olivier Dehaene's avatar
Olivier Dehaene committed
29
30
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
31
32
33
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
34
35
36
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
37
38
39
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
41
42
43
    /// Generate one token for each request in the given batch
    ///
    /// Returns a list of generated texts of request that met their stopping criteria
    /// and the next cached batch
44
45
46
47
48
49
    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.generate(batch.clone())))
            .collect();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
50
        // As soon as we receive one response, we can return as all shards will return the same
51
52
        let (result, _, _) = select_all(futures).await;
        result
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
53
54
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
55
56
57
58
    /// Generate one token for each request in the given cached batch
    ///
    /// Returns a list of generated texts of request that met their stopping criteria
    /// and the next cached batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
59
    pub async fn generate_with_cache(
60
        &mut self,
Olivier Dehaene's avatar
Olivier Dehaene committed
61
62
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
63
64
65
66
67
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.generate_with_cache(batches.clone())))
            .collect();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
68
        // As soon as we receive one response, we can return as all shards will return the same
69
70
        let (result, _, _) = select_all(futures).await;
        result
Olivier Dehaene's avatar
Olivier Dehaene committed
71
72
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
73
    /// Clear the past generations cache
74
75
76
77
78
79
80
    pub async fn clear_cache(&mut self) -> Result<()> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.clear_cache())
            .collect();
        join_all(futures).await.into_iter().collect()
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
81
82
    }
}