sharded_client.rs 4.81 KB
Newer Older
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
1
use crate::Result;
Olivier Dehaene's avatar
Olivier Dehaene committed
2
use crate::{Batch, Client, GeneratedText};
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
3
4
5
6
7
8
9
10
use futures::future::join_all;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;

#[derive(Clone, Debug)]
enum Command {
    Generate(
        Batch,
Olivier Dehaene's avatar
Olivier Dehaene committed
11
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
12
13
    ),
    GenerateWithCache(
Olivier Dehaene's avatar
Olivier Dehaene committed
14
15
16
17
18
19
20
21
22
23
        Vec<Batch>,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
    ),
    GenerateUntilFinished(
        Batch,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
    ),
    GenerateUntilFinishedWithCache(
        Vec<Batch>,
        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
24
25
26
27
28
29
30
31
32
33
34
    ),
    ClearCache(mpsc::Sender<Result<()>>),
}

/// Drives a single shard: executes each broadcast [`Command`] on `client` and
/// answers on the sender carried inside the command.
///
/// The loop stops the first time `recv` on `request_subscriber` yields an error.
/// NOTE(review): for a tokio broadcast receiver that covers both a closed channel
/// and a receiver that fell behind (`RecvError::Lagged`) — confirm that stopping
/// this shard's task on lag is intended.
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
    loop {
        let command = match request_subscriber.recv().await {
            Ok(command) => command,
            Err(_) => break,
        };

        // Replies go out with `try_send` and any failure is discarded on purpose:
        // callers open a capacity-1 response channel and read a single reply, so
        // once one shard has answered (or the caller has gone away) the remaining
        // shards' replies have nowhere to go.
        match command {
            Command::Generate(batch, reply) => {
                let _ = reply.try_send(client.generate(batch).await);
            }
            Command::GenerateWithCache(batches, reply) => {
                let _ = reply.try_send(client.generate_with_cache(batches).await);
            }
            Command::GenerateUntilFinished(batch, reply) => {
                let _ = reply.try_send(client.generate_until_finished(batch).await);
            }
            Command::GenerateUntilFinishedWithCache(batches, reply) => {
                let outcome = client.generate_until_finished_with_cache(batches).await;
                let _ = reply.try_send(outcome);
            }
            Command::ClearCache(reply) => {
                let _ = reply.try_send(client.clear_cache().await);
            }
        }
    }
}

/// Handle that fans each request out to one background task per shard client.
///
/// Cloning the handle is cheap: clones share the same broadcast channel, and
/// therefore the same set of shard tasks.
#[derive(Clone, Debug)]
pub struct ShardedClient {
    /// Broadcast side of the command channel; every shard task holds its own
    /// subscription to it.
    request_tx: broadcast::Sender<Command>,
}

impl ShardedClient {
    fn new(mut clients: Vec<Client>) -> Self {
        let (request_tx, _) = broadcast::channel(1);

        for client in clients.drain(..) {
            let request_subscriber = request_tx.subscribe();
            tokio::spawn(client_task(client, request_subscriber));
        }

        Self { request_tx }
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
71
    async fn from_master_client(mut master_client: Client) -> Result<Self> {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
72
        let uris = master_client.service_discovery().await.unwrap();
Olivier Dehaene's avatar
Olivier Dehaene committed
73
74
75
        let futures = uris.into_iter().map(|path| Client::connect_uds(path));
        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
        Ok(Self::new(clients?))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
76
77
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
78
79
80
    /// Returns a client connected to the given url
    pub async fn connect(uri: Uri) -> Result<Self> {
        let master_client = Client::connect(uri).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
81
82
83
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
84
85
86
    /// Returns a client connected to the given unix socket
    pub async fn connect_uds(path: String) -> Result<Self> {
        let master_client = Client::connect_uds(path).await?;
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
87
88
89
        Self::from_master_client(master_client).await
    }

Olivier Dehaene's avatar
Olivier Dehaene committed
90
    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
91
92
93
94
95
96
97
98
99
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::Generate(batch, response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn generate_with_cache(
        &self,
Olivier Dehaene's avatar
Olivier Dehaene committed
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::GenerateWithCache(batches, response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn generate_until_finished(
        &self,
        batch: Batch,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::GenerateUntilFinished(batch, response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn generate_until_finished_with_cache(
        &self,
        batches: Vec<Batch>,
    ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
124
125
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
Olivier Dehaene's avatar
Olivier Dehaene committed
126
127
128
129
            .send(Command::GenerateUntilFinishedWithCache(
                batches,
                response_tx,
            ))
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
130
131
132
133
134
135
136
137
138
139
140
141
            .unwrap();
        response_rx.recv().await.unwrap()
    }

    pub async fn clear_cache(&self) -> Result<()> {
        let (response_tx, mut response_rx) = mpsc::channel(1);
        self.request_tx
            .send(Command::ClearCache(response_tx))
            .unwrap();
        response_rx.recv().await.unwrap()
    }
}