sharded_client.rs 6.36 KB
Newer Older
1
use crate::client::{DecodeTimings, PrefillTimings};
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
2
/// Multi shard Client
3
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
4
use crate::{ClientError, Result};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
5
6
use futures::future::join_all;
use tonic::transport::Uri;
7
use tracing::instrument;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
8

9
#[derive(Debug, Clone)]
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
10
/// Text Generation Inference gRPC multi client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
11
pub struct ShardedClient {
12
    clients: Vec<Client>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
13
14
15
}

impl ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
16
    fn new(clients: Vec<Client>) -> Self {
17
        Self { clients }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
18
19
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
20
21
    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
Olivier Dehaene's avatar
Olivier Dehaene committed
22
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
23
        // Get all uris/unix sockets from the master client
24
        let uris = master_client.service_discovery().await?;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
25
        let futures = uris.into_iter().map(Client::connect_uds);
Olivier Dehaene's avatar
Olivier Dehaene committed
26
27
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
28
29
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
30
    /// Returns a client connected to the given uri
Olivier Dehaene's avatar
Olivier Dehaene committed
31
32
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
33
34
35
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
36
37
38
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
39
40
41
        Self::from_master_client(master_client).await
    }

42
43
44
45
46
47
48
49
50
51
52
    /// Get the model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<ShardInfo> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
        join_all(futures).await.pop().unwrap()
    }

53
54
55
56
57
58
59
60
61
62
63
    /// GRPC health check
    #[instrument(skip(self))]
    pub async fn health(&mut self) -> Result<HealthResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.health())
            .collect();
        join_all(futures).await.pop().unwrap()
    }

64
    /// Clear the past generations cache
65
    #[instrument(skip(self))]
66
    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
67
68
69
        let futures: Vec<_> = self
            .clients
            .iter_mut()
70
            .map(|client| client.clear_cache(batch_id))
71
72
73
74
            .collect();
        join_all(futures).await.into_iter().collect()
    }

75
76
77
78
79
    /// Filter a cached batch
    #[instrument(skip(self))]
    pub async fn filter_batch(
        &mut self,
        batch_id: u64,
80
81
        request_ids: Vec<u64>,
    ) -> Result<Option<CachedBatch>> {
82
83
84
        let futures: Vec<_> = self
            .clients
            .iter_mut()
85
            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
86
87
88
89
90
            .collect();
        // all shards return the same message
        join_all(futures).await.pop().unwrap()
    }

91
92
93
94
95
96
97
98
    /// Warmup on a max size batch
    ///
    /// Returns the maximum amount of tokens supported by the hardware
    #[instrument(skip(self))]
    pub async fn warmup(
        &mut self,
        max_input_length: u32,
        max_prefill_tokens: u32,
OlivierDehaene's avatar
OlivierDehaene committed
99
        max_total_tokens: u32,
100
    ) -> Result<Option<u32>> {
101
102
103
        let futures: Vec<_> = self
            .clients
            .iter_mut()
OlivierDehaene's avatar
OlivierDehaene committed
104
105
106
            .map(|client| {
                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
            })
107
            .collect();
108
109
110
111
112
113
        // Take the minimum value
        let results = join_all(futures)
            .await
            .into_iter()
            .collect::<Result<Vec<Option<u32>>>>()?;
        Ok(results.into_iter().flatten().min())
114
115
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
116
117
    /// Generate one token for each request in the given batch
    ///
118
    /// Returns Generation for each request in batch
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
119
    /// and the next cached batch
120
    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
121
122
123
    pub async fn prefill(
        &mut self,
        batch: Batch,
124
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
125
126
127
        let futures: Vec<_> = self
            .clients
            .iter_mut()
128
            .map(|client| Box::pin(client.prefill(batch.clone())))
129
            .collect();
130
        #[allow(clippy::type_complexity)]
131
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
132
            join_all(futures).await.into_iter().collect();
133
134
135
136
137
138
139
140
141
142
143
144
145
146
        let mut results = results?;

        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;

        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
147
148
    }

149
    /// Generate one token for each request in the given cached batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
150
    ///
151
    /// Returns Generation for each request in batches
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
152
    /// and the next cached batch
153
    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
154
    pub async fn decode(
155
        &mut self,
156
        batches: Vec<CachedBatch>,
157
    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
158
159
160
        let futures: Vec<_> = self
            .clients
            .iter_mut()
161
            .map(|client| Box::pin(client.decode(batches.clone())))
162
            .collect();
163
        #[allow(clippy::type_complexity)]
164
        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
165
            join_all(futures).await.into_iter().collect();
166
        let mut results = results?;
167

168
169
        let (mut generations, next_batch, mut timings) =
            results.pop().ok_or(ClientError::EmptyResults)?;
170

171
172
173
174
175
176
177
178
179
        // Merge generations from different model shards
        for (mut shard_generations, _, shard_timings) in results.into_iter() {
            generations.append(&mut shard_generations);
            // Return the timings of the slowest shard
            if shard_timings.total > timings.total {
                timings = shard_timings;
            }
        }
        Ok((generations, next_batch, timings))
Olivier Dehaene's avatar
Olivier Dehaene committed
180
    }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
181
}