/// Multi shard Client
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
use crate::{ClientError, Result};
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;

#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
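    /// One gRPC client per model shard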
    clients: Vec<Client>,
}

impl ShardedClient {
    fn new(clients: Vec<Client>) -> Self {
        Self { clients }
    }

    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and return all uris/unix sockets with the `service_discovery` gRPC method.
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
        // Get all uris/unix sockets from the master client
        let uris = master_client.service_discovery().await?;
        let futures = uris.into_iter().map(Client::connect_uds);
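        // connect to all discovered shards concurrently; any connection error is propagated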
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
    }

    /// Returns a client connected to the given uri
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
        Self::from_master_client(master_client).await
    }

    /// Returns a client connected to the given unix socket
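    ///
    /// A minimal usage sketch (illustrative, not part of the original docs); it assumes the
    /// crate is importable as `text_generation_client` and uses a hypothetical socket path.
    ///
    /// ```no_run
    /// # use text_generation_client::{Result, ShardedClient};
    /// # async fn example() -> Result<()> {
    /// // connect to the master shard; the remaining shards are discovered via `service_discovery`
    /// let mut client = ShardedClient::connect_uds("/tmp/text-generation-server-0".to_string()).await?;
    /// let _info = client.info().await?;
    /// # Ok(())
    /// # }
    /// ```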
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
        Self::from_master_client(master_client).await
    }

    /// Get the model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<ShardInfo> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
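        // all shards return the same message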
        join_all(futures).await.pop().unwrap()
    }

    /// GRPC health check
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.health())
            .collect();
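        // all shards return the same message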
        join_all(futures).await.pop().unwrap()
    }

    /// Clear the past generations cache
    #[instrument(skip(self))]
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.clear_cache(batch_id))
            .collect();
        join_all(futures).await.into_iter().collect()
    }

    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip(self))]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
    ) -> Result<Option<u32>> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

    /// Generate one token for each request in the given batch
    ///
    /// Returns a Generation for each request in the batch
    /// and the next cached batch
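    ///
    /// A hedged sketch of the calling loop (illustrative only; crate paths are assumed):
    /// prefill the new batch once, then decode the returned cached batch until it is exhausted.
    ///
    /// ```no_run
    /// # use text_generation_client::{Batch, Result, ShardedClient};
    /// # async fn generate(mut client: ShardedClient, batch: Batch) -> Result<()> {
    /// // first forward pass over the prompts
    /// let (mut generations, mut cached) = client.prefill(batch).await?;
    /// // keep generating one token per step while the cached batch still has requests
    /// while let Some(batch) = cached {
    ///     let (more, next) = client.decode(vec![batch]).await?;
    ///     generations.extend(more);
    ///     cached = next;
    /// }
    /// # Ok(())
    /// # }
    /// ```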
    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
    pub async fn prefill(
        &mut self,
        batch: Batch,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.prefill(batch.clone())))
            .collect();
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
            join_all(futures).await.into_iter().collect();
        merge_generations(results?)
    }

    /// Generate one token for each request in the given cached batches
    ///
    /// Returns a Generation for each request in the batches
    /// and the next cached batch
    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
    pub async fn decode(
        &mut self,
        batches: Vec<CachedBatch>,
    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| Box::pin(client.decode(batches.clone())))
            .collect();
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
            join_all(futures).await.into_iter().collect();
        merge_generations(results?)
    }
}

/// Merge generations from the different model shards
fn merge_generations(
    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;

    for (mut shard_generations, _) in results.into_iter() {
        generations.append(&mut shard_generations);
    }
    Ok((generations, next_batch))
}