OpenDAS / text-generation-inference
Commit beb55212 · authored Oct 22, 2022 by OlivierDehaene

feat(client): Simplify sharded logic

Parent: c8ce9b25
Showing 3 changed files with 29 additions and 88 deletions:

- router/client/src/sharded_client.rs (+26 −85)
- router/src/batcher.rs (+2 −2)
- router/src/main.rs (+1 −1)
router/client/src/sharded_client.rs

@@ -2,76 +2,18 @@
 use crate::Result;
 use crate::{Batch, Client, GeneratedText};
 use futures::future::join_all;
-use tokio::sync::{broadcast, mpsc};
+use futures::future::select_all;
 use tonic::transport::Uri;
 
-/// List of all available commands that can be sent through the command channel
-#[derive(Clone, Debug)]
-enum Command {
-    Generate(
-        Batch,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    GenerateWithCache(
-        Vec<Batch>,
-        mpsc::Sender<Result<(Vec<GeneratedText>, Option<Batch>)>>,
-    ),
-    ClearCache(mpsc::Sender<Result<()>>),
-}
-
-/// Tokio task that handles the communication with a single shard
-///
-/// We subscribe on a broadcast channel to receive commands that will be sent by
-/// the ShardedClient.
-///
-/// Each command is fan out to all shards.
-///
-/// The result of the command is sent back to the ShardedClient through a mpsc channel (multi
-/// producer = the shards, single consumer = the ShardedClient).
-async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
-    while let Ok(message) = request_subscriber.recv().await {
-        match message {
-            Command::Generate(batch, response_tx) => {
-                let result = client.generate(batch).await;
-                // We can unwrap_or(()) here because the only error that can happen is if the
-                // receiver is dropped, which means that the ShardedClient already received a
-                // response from another shard
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::GenerateWithCache(batches, response_tx) => {
-                let result = client.generate_with_cache(batches).await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-            Command::ClearCache(response_tx) => {
-                let result = client.clear_cache().await;
-                response_tx.try_send(result).unwrap_or(());
-            }
-        };
-    }
-}
-
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
-    _clients: Vec<Client>,
-    request_tx: broadcast::Sender<Command>,
+    clients: Vec<Client>,
 }
 
 impl ShardedClient {
     fn new(clients: Vec<Client>) -> Self {
-        // The broadcast channel to communicate with the shards
-        // We use a capacity of one as the shards are not asynchronous and can only process one
-        // command at a time
-        let (request_tx, _) = broadcast::channel(1);
-
-        // Spawn client tasks
-        for client in clients.iter() {
-            let request_subscriber = request_tx.subscribe();
-            tokio::spawn(client_task(client.clone(), request_subscriber));
-        }
-
         Self {
-            _clients: clients,
-            request_tx,
+            clients,
         }
     }
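For context, the plumbing deleted above is the classic tokio fan-out shape: one broadcast sender for commands, one subscriber task per shard, and a capacity-1 mpsc channel so only the first shard response is kept. Below is a minimal standalone sketch of that pattern, assuming only the tokio crate; the `Job` type and string payloads are illustrative stand-ins, not the router's types.

```rust
use tokio::sync::{broadcast, mpsc};

// Illustrative command payload; the real router broadcast a Command enum.
#[derive(Clone, Debug)]
struct Job(u32);

#[tokio::main]
async fn main() {
    // Capacity 1: the shards are synchronous and process one command at a time.
    let (request_tx, _) = broadcast::channel::<(Job, mpsc::Sender<String>)>(1);

    // One task per "shard", each subscribed to the command channel.
    for shard_id in 0..2u32 {
        let mut request_subscriber = request_tx.subscribe();
        tokio::spawn(async move {
            while let Ok((job, response_tx)) = request_subscriber.recv().await {
                // Ignore the error: it only fires if another shard already
                // answered and the capacity-1 receiver is full or dropped.
                let _ = response_tx.try_send(format!("shard {shard_id} ran job {}", job.0));
            }
        });
    }

    // Fan one command out and keep the first response only.
    let (response_tx, mut response_rx) = mpsc::channel(1);
    request_tx.send((Job(42), response_tx)).unwrap();
    println!("{}", response_rx.recv().await.unwrap());
}
```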
@@ -101,15 +43,15 @@ impl ShardedClient {
     ///
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
-    pub async fn generate(&self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::Generate(batch, response_tx))
-            .unwrap();
+    pub async fn generate(&mut self, batch: Batch) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate(batch.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Generate one token for each request in the given cached batch
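The replacement leans on `futures::future::select_all`, which polls every shard future and resolves as soon as the first one completes, dropping the rest; `Box::pin` is needed because `select_all` requires `Unpin` futures. Here is a self-contained sketch of those semantics, assuming the futures and tokio crates; the sleeps stand in for per-shard gRPC calls and are not the router's code.

```rust
use futures::future::select_all;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Three "shards" that answer at different speeds.
    let futures: Vec<_> = [30u64, 10, 20]
        .into_iter()
        .map(|ms| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(ms)).await;
                ms
            })
        })
        .collect();

    // Resolves with the first completed future; dropping the returned
    // remainder cancels the slower futures.
    let (result, index, _remaining) = select_all(futures).await;
    assert_eq!((result, index), (10, 1));
}
```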
@@ -117,27 +59,26 @@
     /// Returns a list of generated texts of request that met their stopping criteria
     /// and the next cached batch
     pub async fn generate_with_cache(
-        &self,
+        &mut self,
         batches: Vec<Batch>,
     ) -> Result<(Vec<GeneratedText>, Option<Batch>)> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::GenerateWithCache(batches, response_tx))
-            .unwrap();
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| Box::pin(client.generate_with_cache(batches.clone())))
+            .collect();
         // As soon as we receive one response, we can return as all shards will return the same
-        response_rx.recv().await.unwrap()
+        let (result, _, _) = select_all(futures).await;
+        result
     }
 
     /// Clear the past generations cache
-    pub async fn clear_cache(&self) -> Result<()> {
-        // Create a channel to receive the response from the shards
-        // We will only ever receive one message on this channel
-        let (response_tx, mut response_rx) = mpsc::channel(1);
-        self.request_tx
-            .send(Command::ClearCache(response_tx))
-            .unwrap();
-        response_rx.recv().await.unwrap()
+    pub async fn clear_cache(&mut self) -> Result<()> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.clear_cache())
+            .collect();
+        join_all(futures).await.into_iter().collect()
     }
 }
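clear_cache, by contrast, must wait for every shard, so it uses `join_all` and then folds the `Vec<Result<()>>` into a single `Result<()>`: collecting an iterator of `Result`s short-circuits to the first `Err`. A small sketch of that collect trick follows, with a hypothetical `ShardError` type and `fake_clear` helper standing in for the router's error type and gRPC call.

```rust
use futures::future::join_all;

// Hypothetical error type; the router uses its own Result alias.
#[derive(Debug)]
struct ShardError(&'static str);

async fn fake_clear(ok: bool) -> Result<(), ShardError> {
    if ok { Ok(()) } else { Err(ShardError("shard failed")) }
}

#[tokio::main]
async fn main() {
    let futures: Vec<_> = [true, true, false].into_iter().map(fake_clear).collect();

    // join_all waits for all shards; FromIterator for Result keeps the
    // first Err and discards the Ok(()) values.
    let merged: Result<(), ShardError> = join_all(futures).await.into_iter().collect();
    assert!(merged.is_err());
}
```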
router/src/batcher.rs

@@ -39,9 +39,9 @@ impl Batcher {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
+            client,
             max_batch_size,
             max_waiting_tokens,
-            client,
             db.clone(),
             shared.clone(),
         ));
@@ -86,9 +86,9 @@
 /// Batches requests and sends them to the inference server
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
+    mut client: ShardedClient,
     max_batch_size: usize,
     max_waiting_tokens: usize,
-    client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,
 ) {
router/src/main.rs

@@ -61,7 +61,7 @@ fn main() -> Result<(), std::io::Error> {
         .unwrap()
         .block_on(async {
             // Instantiate sharded client from the master unix socket
-            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
+            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
                 .expect("Could not connect to server");
 
             // Clear the cache; useful if the webserver rebooted
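The batcher and main changes follow mechanically: the new ShardedClient methods call `self.clients.iter_mut()`, so they take `&mut self`, and every owner (batching_task's `client` parameter, the `sharded_client` binding in main) must now be declared `mut`. A tiny sketch of that borrow-checker requirement, using stand-in types rather than the router's definitions:

```rust
// Stand-in types; not the router's definitions.
struct Client;
struct ShardedClient {
    clients: Vec<Client>,
}

impl ShardedClient {
    // Takes &mut self because iter_mut() needs a mutable borrow.
    fn clear_cache(&mut self) {
        for _client in self.clients.iter_mut() {}
    }
}

fn main() {
    // Without `mut`, the call below fails to compile:
    // "cannot borrow `sharded_client` as mutable".
    let mut sharded_client = ShardedClient { clients: vec![Client] };
    sharded_client.clear_cache();
}
```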