sharded_client.rs 4.94 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
use crate::Result;
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use crate::{Batch, Client, GeneratedText};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
3
4
5
6
7
8
9
10
11
use futures::future::join_all;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;

#[derive(Clone, Debug)]
enum Command {
    Generate(
        Batch,
Olivier Dehaene's avatar
Olivier Dehaene committed
12
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
13
14
    ),
    GenerateWithCache(
Olivier Dehaene's avatar
Olivier Dehaene committed
15
16
17
18
19
20
21
22
23
24
        Vec<Batch>,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
    ),
    GenerateUntilFinished(
        Batch,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
    ),
    GenerateUntilFinishedWithCache(
        Vec<Batch>,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
25
26
27
28
29
30
31
32
33
34
35
    ),
    ClearCache(mpsc::Sender<Result<()>>),
}

async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
    while let Ok(message) = request_subscriber.recv().await {
        match message {
            Command::Generate(batch, response_tx) => {
                let result = client.generate(batch).await;
                response_tx.try_send(result).unwrap_or(());
            }
Olivier Dehaene's avatar
Olivier Dehaene committed
36
37
38
39
40
41
42
43
44
45
            Command::GenerateWithCache(batches, response_tx) => {
                let result = client.generate_with_cache(batches).await;
                response_tx.try_send(result).unwrap_or(());
            }
            Command::GenerateUntilFinished(batch, response_tx) => {
                let result = client.generate_until_finished(batch).await;
                response_tx.try_send(result).unwrap_or(());
            }
            Command::GenerateUntilFinishedWithCache(batches, response_tx) => {
                let result = client.generate_until_finished_with_cache(batches).await;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
                response_tx.try_send(result).unwrap_or(());
            }
            Command::ClearCache(response_tx) => {
                let result = client.clear_cache().await;
                response_tx.try_send(result).unwrap_or(());
            }
        };
    }
}

pub struct ShardedClient {
    request_tx: broadcast::Sender<Command>,
}

impl ShardedClient {
    fn new(mut clients: Vec<Client>) -> Self {
        let (request_tx, _) = broadcast::channel(1);

        for client in clients.drain(..) {
            let request_subscriber = request_tx.subscribe();
            tokio::spawn(client_task(client, request_subscriber));
        }

        Self { request_tx }
    }

    async fn from_master_client(mut master_client: Client) -> Self {
        let uris = master_client.service_discovery().await.unwrap();
        let futures = uris
            .into_iter()
            .map(|path| Client::connect_uds(path, Duration::from_secs(5)));
        let clients = join_all(futures).await;
        Self::new(clients)
    }

    /// Returns a client connected to the given url. Requests exceeding timeout will fail.
    pub async fn connect(uri: Uri, timeout: Duration) -> Self {
        let master_client = Client::connect(uri, timeout).await;
        Self::from_master_client(master_client).await
    }

    /// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
    pub async fn connect_uds(path: String, timeout: Duration) -> Self {
        let master_client = Client::connect_uds(path, timeout).await;
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
93
    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
94
95
96
97
98
99
100
101
102
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::Generate(batch, response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn generate_with_cache(
        &self,
Olivier Dehaene's avatar
Olivier Dehaene committed
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::GenerateWithCache(batches, response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn generate_until_finished(
        &self,
        batch: Batch,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::GenerateUntilFinished(batch, response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn generate_until_finished_with_cache(
        &self,
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
127
128
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
Olivier Dehaene's avatar
Olivier Dehaene committed
129
130
131
132
            .send(Command::GenerateUntilFinishedWithCache(
                batches,
                response_tx,
            ))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
133
134
135
136
137
138
139
140
141
142
143
144
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn clear_cache(&self) -> Result<()> {
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::ClearCache(response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }
}