sharded_client.rs 5.62 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Multi shard Client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
use crate::Result;
Olivier Dehaene's avatar
Olivier Dehaene committed
3
use crate::{Batch, Client, GeneratedText};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
4
5
6
7
use futures::future::join_all;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
8
/// List of all available commands that can be sent through the command channel
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
9
10
11
12
#[derive(Clone, Debug)]
enum Command {
    Generate(
        Batch,
Olivier Dehaene's avatar
Olivier Dehaene committed
13
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
14
15
    ),
    GenerateWithCache(
Olivier Dehaene's avatar
Olivier Dehaene committed
16
17
18
        Vec<Batch>,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
    ),
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
19
20
21
    ClearCache(mpsc::Sender<Result<()>>),
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
22
23
24
25
26
27
28
29
30
/// Tokio task that handles the communication with a single shard
///
/// We subscribe on a broadcast channel to receive commands that will be sent by
/// the ShardedClient.
///
/// Each command is fan out to all shards.
///
/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
/// producer = the shards, single consumer = the ShardedClient).
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
31
32
33
34
35
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
    while let Ok(message) = request_subscriber.recv().await {
        match message {
            Command::Generate(batch, response_tx) => {
                let result = client.generate(batch).await;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
36
37
38
                // We can unwrap_or(()) here because the only error that can happen is if the
                // receiver is dropped, which means that the ShardedClient already received a
                // response from another shard
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
39
40
                response_tx.try_send(result).unwrap_or(());
            }
Olivier Dehaene's avatar
Olivier Dehaene committed
41
42
43
44
            Command::GenerateWithCache(batches, response_tx) => {
                let result = client.generate_with_cache(batches).await;
                response_tx.try_send(result).unwrap_or(());
            }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
45
46
47
48
49
50
51
52
            Command::ClearCache(response_tx) => {
                let result = client.clear_cache().await;
                response_tx.try_send(result).unwrap_or(());
            }
        };
    }
}

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
53
/// Text Generation Inference gRPC multi client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
54
pub struct ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
55
    _clients: Vec<Client>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
56
57
58
59
    request_tx: broadcast::Sender<Command>,
}

impl ShardedClient {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
60
61
62
63
    fn new(clients: Vec<Client>) -> Self {
        // The broadcast channel to communicate with the shards
        // We use a capacity of one as the shards are not asynchronous and can only process one
        // command at a time
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
64
65
        let (request_tx, _) = broadcast::channel(1);

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
66
67
        // Spawn client tasks
        for client in clients.iter() {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
68
            let request_subscriber = request_tx.subscribe();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
69
            tokio::spawn(client_task(client.clone(), request_subscriber));
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
70
71
        }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
72
73
74
75
        Self {
            _clients: clients,
            request_tx,
        }
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
76
77
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
78
79
    /// Create a new ShardedClient from a master client. The master client will communicate with
    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
Olivier Dehaene's avatar
Olivier Dehaene committed
80
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
81
        // Get all uris/unix sockets from the master client
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
82
        let uris = master_client.service_discovery().await.unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
83
        let futures = uris.into_iter().map(Client::connect_uds);
Olivier Dehaene's avatar
Olivier Dehaene committed
84
85
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
86
87
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
88
    /// Returns a client connected to the given uri
Olivier Dehaene's avatar
Olivier Dehaene committed
89
90
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
91
92
93
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
94
95
96
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
97
98
99
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
100
101
102
103
    /// Generate one token for each request in the given batch
    ///
    /// Returns a list of generated texts of request that met their stopping criteria
    /// and the next cached batch
Olivier Dehaene's avatar
Olivier Dehaene committed
104
    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
105
106
        // Create a channel to receive the response from the shards
        // We will only ever receive one message on this channel
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
107
108
109
110
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::Generate(batch, response_tx))
            .unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
111
        // As soon as we receive one response, we can return as all shards will return the same
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
112
113
114
        response_rx.recv().await.unwrap()
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
115
116
117
118
    /// Generate one token for each request in the given cached batch
    ///
    /// Returns a list of generated texts of request that met their stopping criteria
    /// and the next cached batch
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
119
120
    pub async fn generate_with_cache(
        &self,
Olivier Dehaene's avatar
Olivier Dehaene committed
121
122
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
123
124
        // Create a channel to receive the response from the shards
        // We will only ever receive one message on this channel
Olivier Dehaene's avatar
Olivier Dehaene committed
125
126
127
128
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::GenerateWithCache(batches, response_tx))
            .unwrap();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
129
        // As soon as we receive one response, we can return as all shards will return the same
Olivier Dehaene's avatar
Olivier Dehaene committed
130
131
132
        response_rx.recv().await.unwrap()
    }

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
133
    /// Clear the past generations cache
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
134
    pub async fn clear_cache(&self) -> Result<()> {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
135
136
        // Create a channel to receive the response from the shards
        // We will only ever receive one message on this channel
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
137
138
139
140
141
142
143
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::ClearCache(response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }
}