OpenDAS / text-generation-inference · Commits

Commit 1539d3cb (unverified), authored Jan 26, 2023 by OlivierDehaene, committed by GitHub on Jan 26, 2023
Parent: ce960be0

feat(router): Remove second lock from batcher hot path (#27)

@njhill
Showing 4 changed files with 39 additions and 34 deletions:

  Cargo.lock              +7  −0
  router/Cargo.toml       +1  −0
  router/src/batcher.rs   +17 −21
  router/src/db.rs        +14 −13
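Editorial summary (a gloss added here, not part of the commit message): before this change, Db::next_batch returned the batch together with a plain Vec<u64> of request ids, so delivering each response or error required a second call back into the shared Db (db.remove(&id)), re-acquiring its parking_lot::Mutex once per request on the generation hot path. After this change, next_batch removes the selected entries while it already holds the lock and returns them as an owned IntMap<u64, Entry> (a HashMap keyed with an identity hasher from the new nohash-hasher dependency), so send_generated and send_error complete without touching the shared state again. Hedged sketches of the pattern follow the diffs below.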
Cargo.lock

@@ -1087,6 +1087,12 @@ dependencies = [
  "libc",
 ]

+[[package]]
+name = "nohash-hasher"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2bf50223579dc7cdcfb3bfcacf7069ff68243f8c363f62ffa99cf000a6b9c451"
+
 [[package]]
 name = "nom"
 version = "7.1.1"

@@ -1826,6 +1832,7 @@ dependencies = [
  "axum",
  "clap 4.0.22",
  "futures",
+ "nohash-hasher",
  "parking_lot",
  "serde",
  "serde_json",
router/Cargo.toml

@@ -17,6 +17,7 @@ axum = { version = "0.5.16", features = ["json", "serde_json"] }
 text-generation-client = { path = "client" }
 clap = { version = "4.0.15", features = ["derive", "env"] }
 futures = "0.3.24"
+nohash-hasher = "0.2.0"
 parking_lot = "0.12.1"
 serde = "1.0.145"
 serde_json = "1.0.85"
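As context for the new dependency: nohash_hasher::IntMap<u64, V> is simply std's HashMap with an identity hasher, a good fit for keys that are already unique machine integers, such as the router's u64 request ids. A minimal usage sketch (the &str value type here is illustrative, not the router's Entry):

use nohash_hasher::{BuildNoHashHasher, IntMap};

fn main() {
    // IntMap<u64, V> is std's HashMap with an identity hasher: a u64 key
    // "hashes" to itself, skipping the default SipHash work on every lookup.
    let mut entries: IntMap<u64, &str> =
        IntMap::with_capacity_and_hasher(8, BuildNoHashHasher::default());
    entries.insert(42, "queued request");
    assert_eq!(entries.remove(&42), Some("queued request"));
    assert!(entries.is_empty());
}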
router/src/batcher.rs

@@ -3,6 +3,7 @@ use crate::{Db, Entry};
 use crate::{ErrorResponse, GenerateRequest};
 use axum::http::StatusCode;
 use axum::Json;
+use nohash_hasher::IntMap;
 use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{Batch, ClientError, GeneratedText, ShardedClient};

@@ -104,8 +105,8 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        while let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
-            let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+        while let Some((mut entries, batch)) = db.next_batch(None, max_batch_size) {
+            let mut cached_batch = wrap_future(client.generate(batch), &mut entries).await;
             let mut waiting_tokens = 1;

             // We loop until we do not receive any cached batch from the inference server (== until

@@ -113,7 +114,6 @@ async fn batching_task(
             while let Some(batch) = cached_batch {
                 // Get current batch info
                 let batch_size = batch.size;
-                let mut request_ids: Vec<u64> = batch.requests.iter().map(|req| req.id).collect();
                 let mut batches = vec![batch];

                 // If the current batch is too small, we try to add more requests to it

@@ -127,24 +127,23 @@ async fn batching_task(
                     };

                     // Try to get a new batch
-                    if let Some((new_request_ids, new_batch)) =
+                    if let Some((mut new_entries, new_batch)) =
                         db.next_batch(min_size, max_batch_size - batch_size as usize)
                     {
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
-                            wrap_future(client.generate(new_batch), new_request_ids, &db).await;
+                            wrap_future(client.generate(new_batch), &mut new_entries).await;
                         // Reset waiting counter
                         waiting_tokens = 1;
                         // Extend current batch with the new batch
                         if let Some(new_cached_batch) = new_cached_batch {
-                            request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
+                            entries.extend(new_entries);
                             batches.push(new_cached_batch);
                         }
                     }
                 }

                 cached_batch =
-                    wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                    wrap_future(client.generate_with_cache(batches), &mut entries).await;
                 waiting_tokens += 1;
             }
         }

@@ -154,39 +153,36 @@ async fn batching_task(
 /// Wrap a future inside a match statement to handle errors and send the response to the Batcher
 async fn wrap_future(
     future: impl Future<Output = Result<(Vec<GeneratedText>, Option<Batch>), ClientError>>,
-    request_ids: Vec<u64>,
-    db: &Db,
+    entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
     match future.await {
         Ok((generated_texts, next_batch)) => {
-            send_generated(generated_texts, db);
+            send_generated(generated_texts, entries);
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
-            send_error(err, request_ids, db);
+            send_error(err, entries);
             None
         }
     }
 }

-/// Send errors to the Batcher for all `request_ids`
-fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
-    request_ids.into_iter().for_each(|id| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db.remove(&id).expect("ID not found in db. This is a bug.");
+/// Send errors to the Batcher for all `entries`
+fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
+    entries.drain().for_each(|(_, entry)| {
         // unwrap_or is valid here as we don't care if the receiver is gone.
         entry.response_tx.send(Err(error.clone())).unwrap_or(());
     });
 }

 /// Send `generated_text` to the Batcher for all `finished`
-fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
+fn send_generated(finished: Vec<GeneratedText>, entries: &mut IntMap<u64, Entry>) {
     finished.into_iter().for_each(|output| {
-        // We can `expect` here as the request id should always be in the DB
-        let entry = db
+        // We can `expect` here as the request id should always be in the entries
+        let entry = entries
             .remove(&output.request.unwrap().id)
-            .expect("ID not found in db. This is a bug.");
+            .expect("ID not found in entries. This is a bug.");

         let response = InferResponse {
             output_text: output.output_text,
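To make the new error path concrete, here is a minimal, self-contained sketch of the send_error shape (std HashMap and mpsc stand in for the router's IntMap and per-request response channel, so the Entry definition and channel type here are assumptions, not the crate's own): because the batcher owns the entries map, failing the whole batch drains it locally and never re-acquires the Db lock.

use std::collections::HashMap;
use std::sync::mpsc;

// Illustrative stand-in for the router's Entry; only the response channel
// matters for this sketch.
struct Entry {
    response_tx: mpsc::Sender<Result<String, String>>,
}

// Mirrors the new send_error: draining an owned map fans the error out to
// every waiting request without re-acquiring any shared lock.
fn send_error(error: String, entries: &mut HashMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry.response_tx.send(Err(error.clone())).unwrap_or(());
    });
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let mut entries = HashMap::new();
    entries.insert(0u64, Entry { response_tx: tx });

    send_error("shard disconnected".to_string(), &mut entries);
    assert!(entries.is_empty());
    assert_eq!(rx.recv().unwrap(), Err("shard disconnected".to_string()));
}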
router/src/db.rs

@@ -1,6 +1,7 @@
 /// This code is massively inspired by Tokio mini-redis
 use crate::InferResponse;
 use crate::{GenerateParameters, GenerateRequest};
+use nohash_hasher::{BuildNoHashHasher, IntMap};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;

@@ -112,18 +113,12 @@ impl Db {
         state.entries.insert(id, entry);
     }

-    /// Remove an entry from the database if it exists
-    pub(crate) fn remove(&self, id: &u64) -> Option<Entry> {
-        let mut state = self.shared.state.lock();
-        state.entries.remove(id)
-    }
-
     // Get the next batch
     pub(crate) fn next_batch(
         &self,
         min_size: Option<usize>,
         max_size: usize,
-    ) -> Option<(Vec<u64>, Batch)> {
+    ) -> Option<(IntMap<u64, Entry>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();

@@ -135,13 +130,19 @@ impl Db {
                     return None;
                 }
             }

+            // Batch size
+            let size = requests.len();
+
+            let mut entries = IntMap::with_capacity_and_hasher(size, BuildNoHashHasher::default());
             ids.iter().for_each(|id| {
-                // Set batch_time for each request
-                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
+                // Remove entry from db
+                let mut entry = state.entries.remove(id).unwrap();
+                // Set batch_time
+                entry.batch_time = Some(Instant::now());
+                // Insert in entries IntMap
+                entries.insert(*id, entry);
             });

-            // Batch size
-            let size = requests.len();
-
             let batch = Batch {
                 id: state.next_batch_id,
                 requests,

@@ -152,7 +153,7 @@ impl Db {
             // Increment batch id
             state.next_batch_id += 1;

-            return Some((ids, batch));
+            return Some((entries, batch));
         }
         None
     }
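On the database side, a sketch of the single-lock hot path this commit establishes (the Db shape below is a simplified stand-in, with std::sync::Mutex in place of parking_lot::Mutex and String entries instead of the router's Entry): next_batch acquires the lock once, removes the selected entries, and returns them by value, which is exactly what makes the separate Db::remove method, and its extra lock acquisition per finished request, unnecessary.

use std::collections::{BTreeMap, HashMap};
use std::sync::Mutex;

// Simplified stand-in for the router's Db/State pair.
struct Db {
    state: Mutex<BTreeMap<u64, String>>,
}

impl Db {
    // One lock acquisition: the selected entries are removed from the shared
    // map and returned by value, so callers never need a second lock to
    // remove them later.
    fn next_batch(&self, max_size: usize) -> Option<HashMap<u64, String>> {
        let mut state = self.state.lock().unwrap();
        let ids: Vec<u64> = state.keys().take(max_size).copied().collect();
        if ids.is_empty() {
            return None;
        }
        let mut entries = HashMap::with_capacity(ids.len());
        for id in ids {
            let entry = state.remove(&id).unwrap();
            entries.insert(id, entry);
        }
        Some(entries)
    } // guard dropped here; the caller now owns `entries` outright
}

fn main() {
    let db = Db {
        state: Mutex::new(BTreeMap::from([(1, "a".to_string()), (2, "b".to_string())])),
    };
    let batch = db.next_batch(2).unwrap();
    assert_eq!(batch.len(), 2);
    assert!(db.state.lock().unwrap().is_empty());
}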