text-generation-inference

Commit c8378933
feat(router): Add max_waiting_tokens

Authored Oct 21, 2022 by OlivierDehaene
Parent: 895a341d

Showing 7 changed files with 121 additions and 79 deletions (+121 -79)
launcher/src/main.rs      +6  -6
router/src/batcher.rs     +36 -23
router/src/db.rs          +10 -18
router/src/lib.rs         +1  -3
router/src/main.rs        +6  -9
router/src/server.rs      +58 -19
router/src/validation.rs  +4  -1
launcher/src/main.rs

@@ -28,8 +28,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server", long, env)]

@@ -41,7 +41,7 @@ struct Args {
 }

 fn main() -> ExitCode {
-    tracing_subscriber::fmt::init();
+    tracing_subscriber::fmt().compact().with_ansi(false).init();

     // Pattern match configuration
     let Args {

@@ -51,7 +51,7 @@ fn main() -> ExitCode {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         shard_uds_path,
         master_addr,

@@ -148,8 +148,8 @@ fn main() -> ExitCode {
         &max_input_length.to_string(),
         "--max-batch-size",
         &max_batch_size.to_string(),
-        "--max-waiting-time",
-        &max_waiting_time.to_string(),
+        "--max-waiting-tokens",
+        &max_waiting_tokens.to_string(),
         "--port",
         &port.to_string(),
         "--master-shard-uds-path",
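The launcher only parses these values and forwards them to the router process as string arguments, so the new setting is driven end to end by `--max-waiting-tokens` (or the matching environment variable via clap's `env` attribute). A minimal sketch of that hand-off, assuming a hypothetical `text-generation-router` binary name and ignoring the launcher's real shard supervision:

    use std::process::{Child, Command};

    // Hypothetical helper for illustration: spawn the router with the same flags
    // the launcher forwards in the diff above.
    fn spawn_router(max_batch_size: usize, max_waiting_tokens: usize, port: u16) -> std::io::Result<Child> {
        Command::new("text-generation-router") // assumed binary name
            .arg("--max-batch-size")
            .arg(max_batch_size.to_string())
            .arg("--max-waiting-tokens")
            .arg(max_waiting_tokens.to_string())
            .arg("--port")
            .arg(port.to_string())
            .spawn()
    }

    fn main() {
        // With the defaults from the diff: max_batch_size = 32, max_waiting_tokens = 20, port = 3000.
        if let Err(e) = spawn_router(32, 20, 3000) {
            eprintln!("router binary not available in this sketch: {e}");
        }
    }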
router/src/batcher.rs

@@ -5,7 +5,6 @@ use axum::http::StatusCode;
 use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
 use std::future::Future;
 use std::sync::Arc;
-use std::time::Duration;
 use thiserror::Error;
 use tokio::sync::{oneshot, Notify};
 use tokio::time::Instant;

@@ -30,7 +29,7 @@ impl Batcher {
     pub(crate) fn new(
         client: ShardedClient,
         max_batch_size: usize,
-        max_waiting_time: Duration,
+        max_waiting_tokens: usize,
     ) -> Self {
         // Batcher shared state
         let db = Db::new();

@@ -41,7 +40,7 @@ impl Batcher {
         // Spawn batching background task that contains all the inference logic
         tokio::spawn(batching_task(
             max_batch_size,
-            max_waiting_time,
+            max_waiting_tokens,
             client,
             db.clone(),
             shared.clone(),

@@ -55,7 +54,7 @@ impl Batcher {
         &self,
         input_length: usize,
         request: GenerateRequest,
-    ) -> Result<String, InferError> {
+    ) -> Result<InferResponse, InferError> {
         // One shot channel to communicate with the background batching task
         let (response_tx, response_rx) = oneshot::channel();

@@ -65,6 +64,7 @@ impl Batcher {
             response_tx,
             input_length,
             time: Instant::now(),
+            batch_time: None,
         });

         // Notify the background task that we have a new entry in the database that needs

@@ -87,7 +87,7 @@ impl Batcher {
 #[instrument(skip(client, db, shared))]
 async fn batching_task(
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     db: Db,
     shared: Arc<Shared>,

@@ -103,8 +103,10 @@ async fn batching_task(
         // Get the next batch from the DB
         // This batch might be smaller than the maximum batch size if there are not enough requests
         // waiting in the DB
-        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+        let mut waiting_tokens = 0;
+        if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
             let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+            waiting_tokens += 1;

             // We loop until we do not receive any cached batch from the inference server (== until
             // all requests have met their stopping criteria)

@@ -116,10 +118,20 @@ async fn batching_task(
                 // If the current batch is too small, we try to add more requests to it
                 if batch_size <= limit_min_batch_size {
-                    // Get the next batch from the DB that meet our minimum size criteria
+                    let min_size = match waiting_tokens {
+                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                        // to add a new batch even though its size might be small
+                        _ if waiting_tokens >= max_waiting_tokens => None,
+                        // Minimum size criteria
+                        _ => Some(limit_min_batch_size as usize),
+                    };
+
+                    // Try to get a new batch
                     if let Some((new_request_ids, new_batch)) =
-                        db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+                        db.next_batch(min_size, max_batch_size)
                     {
+                        // Reset waiting counter
+                        waiting_tokens = 0;
                         // Generate one token for this new batch to have the attention past in cache
                         let new_cached_batch =
                             wrap_future(client.generate(new_batch), new_request_ids, &db).await;

@@ -129,24 +141,11 @@ async fn batching_task(
                             batches.push(new_cached_batch);
                         }
                     }
                 }
-                // If we don't have enough requests to meet the minimum size criteria, we
-                // try to get the next batch from the DB that have been waiting over
-                // the max_waiting_time
-                else if let Some((new_request_ids, new_batch)) =
-                    db.next_batch(None, max_batch_size, Some(max_waiting_time))
-                {
-                    let new_cached_batch =
-                        wrap_future(client.generate(new_batch), new_request_ids, &db).await;
-                    // Extend current batch with the new batch
-                    if let Some(new_cached_batch) = new_cached_batch {
-                        request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
-                        batches.push(new_cached_batch);
-                    }
-                }

                 cached_batch =
                     wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+                waiting_tokens += 1;
             }
         }
     }

@@ -188,11 +187,25 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
+        let response = InferResponse {
+            output: output.output,
+            queued: entry.time,
+            start: entry.batch_time.unwrap(), // unwrap is always valid
+            end: Instant::now(),
+        };
         // unwrap_or is valid here as we don't care if the receiver is gone.
-        entry.response_tx.send(Ok(output.output)).unwrap_or(());
+        entry.response_tx.send(Ok(response)).unwrap_or(());
     });
 }

+#[derive(Debug)]
+pub(crate) struct InferResponse {
+    pub(crate) output: String,
+    pub(crate) queued: Instant,
+    pub(crate) start: Instant,
+    pub(crate) end: Instant,
+}
+
 #[derive(Debug, Error)]
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
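The core of the change is the `min_size` decision inside `batching_task`: as long as fewer than `max_waiting_tokens` decode iterations have passed since new requests were last onboarded, a follow-up batch is only accepted if it meets the minimum size; once the counter reaches `max_waiting_tokens`, whatever is waiting gets batched regardless of size. A minimal, self-contained sketch of just that policy (the function name is hypothetical; the client/DB plumbing from the diff is omitted):

    /// Minimum size a follow-up batch must have, given how many decode steps we
    /// have already waited. `None` means "take whatever is queued, even one request".
    fn follow_up_min_size(
        waiting_tokens: usize,
        max_waiting_tokens: usize,
        limit_min_batch_size: u32,
    ) -> Option<usize> {
        if waiting_tokens >= max_waiting_tokens {
            None
        } else {
            Some(limit_min_batch_size as usize)
        }
    }

    fn main() {
        // With the new default max_waiting_tokens = 20, a small queue is ignored for
        // up to 20 decode steps, then flushed unconditionally.
        assert_eq!(follow_up_min_size(3, 20, 8), Some(8));
        assert_eq!(follow_up_min_size(20, 20, 8), None);
    }

Compared with the removed `max_waiting_time` branch, this makes the trigger a count of decode steps rather than wall-clock time, so the decision is the same at every iteration regardless of how fast the model runs.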
router/src/db.rs

+use crate::InferResponse;
 /// This code is massively inspired by Tokio mini-redis
 use crate::{GenerateParameters, GenerateRequest};
 use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use std::time::Duration;
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;

@@ -14,11 +14,13 @@ pub(crate) struct Entry {
     /// Request
     pub request: GenerateRequest,
     /// Response sender to communicate between the Batcher and the batching_task
-    pub response_tx: Sender<Result<String, ClientError>>,
+    pub response_tx: Sender<Result<InferResponse, ClientError>>,
     /// Number of tokens in the input
     pub input_length: usize,
     /// Instant when this entry was created
     pub time: Instant,
+    /// Instant when this entry was added to a batch
+    pub batch_time: Option<Instant>,
 }

 /// Request Database

@@ -51,11 +53,7 @@ struct State {
 impl State {
     /// Get the next requests
-    fn next_requests(
-        &self,
-        max_size: usize,
-        min_waiting_time: Option<Duration>,
-    ) -> Option<(Vec<u64>, Vec<Request>)> {
+    fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
         // Iterates for max_size over the BTreemap starting from next_batch_start_id
         let mut requests = Vec::new();
         let mut ids = Vec::new();

@@ -67,15 +65,6 @@ impl State {
             // Take max_size
             .take(max_size)
         {
-            if let Some(min_waiting_time) = min_waiting_time {
-                // Only take entries that waited for at least min_waiting_time
-                if entry.time.elapsed() < min_waiting_time {
-                    // Since entries are ordered, we already know that all following entries won't
-                    // satisfy the condition
-                    break;
-                }
-            }
             requests.push(Request {
                 id: *id,
                 inputs: entry.request.inputs.clone(),

@@ -134,19 +123,22 @@ impl Db {
         &self,
         min_size: Option<usize>,
         max_size: usize,
-        min_waiting_time: Option<Duration>,
     ) -> Option<(Vec<u64>, Batch)> {
         // Acquire lock
         let mut state = self.shared.state.lock();

         // Get requests from the database
-        if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
+        if let Some((ids, requests)) = state.next_requests(max_size) {
             if let Some(min_size) = min_size {
                 // If min_size is set, only return a batch if there are enough requests
                 if requests.len() < min_size {
                     return None;
                 }
             }
+            ids.iter().for_each(|id| {
+                // Set batch_time for each request
+                state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
+            });

             // Batch size
             let size = requests.len();
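With `min_waiting_time` gone, `next_batch` reduces to a size check: take up to `max_size` queued requests, drop the result if an optional `min_size` is not met, and stamp `batch_time` on everything that ships. A stripped-down sketch of that gating over a plain queue of request IDs (the real code works on a locked `BTreeMap` of entries; the names here are illustrative):

    // Hypothetical, simplified stand-in for Db::next_batch: the queue is just request IDs.
    fn next_batch(queue: &mut Vec<u64>, min_size: Option<usize>, max_size: usize) -> Option<Vec<u64>> {
        let take = queue.len().min(max_size);
        if take == 0 {
            return None;
        }
        if let Some(min_size) = min_size {
            // If min_size is set, only return a batch if there are enough requests.
            if take < min_size {
                return None;
            }
        }
        // In the real Db, this is also where batch_time is set to Instant::now() for each entry.
        Some(queue.drain(..take).collect())
    }

    fn main() {
        let mut queue = vec![1, 2, 3];
        assert_eq!(next_batch(&mut queue.clone(), Some(8), 32), None); // too few requests for min_size = 8
        assert_eq!(next_batch(&mut queue, None, 32), Some(vec![1, 2, 3])); // min_size = None: flush anyway
    }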
router/src/lib.rs

@@ -4,7 +4,7 @@ mod db;
 pub mod server;
 mod validation;

-use batcher::Batcher;
+use batcher::{Batcher, InferResponse};
 use db::{Db, Entry};
 use serde::{Deserialize, Serialize};
 use validation::Validation;

@@ -64,5 +64,3 @@ pub(crate) struct GenerateRequest {
 pub(crate) struct GeneratedText {
     pub generated_text: String,
 }
-
-pub(crate) type GenerateResponse = Vec<GeneratedText>;
router/src/main.rs

@@ -2,7 +2,6 @@
 use bloom_inference_client::ShardedClient;
 use clap::Parser;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
-use std::time::Duration;
 use text_generation_router::server;
 use tokenizers::Tokenizer;

@@ -16,8 +15,8 @@ struct Args {
     max_input_length: usize,
     #[clap(default_value = "32", long, env)]
     max_batch_size: usize,
-    #[clap(default_value = "5", long, env)]
-    max_waiting_time: u64,
+    #[clap(default_value = "20", long, env)]
+    max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
     #[clap(default_value = "/tmp/bloom-inference-0", long, env)]

@@ -36,19 +35,19 @@ fn main() -> Result<(), std::io::Error> {
         max_concurrent_requests,
         max_input_length,
         max_batch_size,
-        max_waiting_time,
+        max_waiting_tokens,
         port,
         master_shard_uds_path,
         tokenizer_name,
         validation_workers,
     } = args;

+    tracing_subscriber::fmt().compact().with_ansi(false).init();
+
     if validation_workers == 1 {
         panic!("validation_workers must be > 0");
     }

-    let max_waiting_time = Duration::from_secs(max_waiting_time);
-
     // Download and instantiate tokenizer
     // This will only be used to validate payloads
     //

@@ -61,8 +60,6 @@ fn main() -> Result<(), std::io::Error> {
         .build()
         .unwrap()
         .block_on(async {
-            tracing_subscriber::fmt::init();
-
             // Instantiate sharded client from the master unix socket
             let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await

@@ -82,7 +79,7 @@ fn main() -> Result<(), std::io::Error> {
                 max_concurrent_requests,
                 max_input_length,
                 max_batch_size,
-                max_waiting_time,
+                max_waiting_tokens,
                 sharded_client,
                 tokenizer,
                 validation_workers,
router/src/server.rs

-use crate::{
-    Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
-};
+use crate::{Batcher, GenerateParameters, GenerateRequest, GeneratedText, Validation};
 use axum::extract::Extension;
-use axum::http::StatusCode;
+use axum::http::{HeaderMap, StatusCode};
 use axum::response::IntoResponse;
 use axum::routing::{get, post};
 use axum::{Json, Router};
 use bloom_inference_client::ShardedClient;
 use std::net::SocketAddr;
 use std::sync::Arc;
-use std::time::Duration;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::Semaphore;

@@ -59,12 +57,21 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String
 }

 /// Generate method
-#[instrument(skip(state), fields(time, time_per_token))]
+#[instrument(
+    skip(state),
+    fields(
+        total_time,
+        validation_time,
+        queue_time,
+        inference_time,
+        time_per_token
+    )
+)]
 async fn generate(
     state: Extension<ServerState>,
     req: Json<GenerateRequest>,
-) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
-    let start = Instant::now();
+) -> Result<impl IntoResponse, (StatusCode, String)> {
+    let start_time = Instant::now();
     // Limit concurrent requests by acquiring a permit from the semaphore
     let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
         (

@@ -84,19 +91,51 @@ async fn generate(
         .await?;

     // Inference
-    let generated_text = state.batcher.infer(input_length, validated_request).await?;
+    let response = state.batcher.infer(input_length, validated_request).await?;
+
+    // Timings
+    let total_time = start_time.elapsed();
+    let validation_time = response.queued - start_time;
+    let queue_time = response.start - response.queued;
+    let inference_time = response.end - response.start;
+    let time_per_token = inference_time / req.parameters.max_new_tokens;
+
+    // Headers
+    let mut headers = HeaderMap::new();
+    headers.insert(
+        "x-total-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-validation-time",
+        validation_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-time",
+        queue_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-inference-time",
+        inference_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-time-per-token",
+        time_per_token.as_millis().to_string().parse().unwrap(),
+    );

     // Tracing metadata
-    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
-    tracing::Span::current().record(
-        "time_per_token",
-        format!("{:?}", start.elapsed() / req.parameters.max_new_tokens),
-    );
-    tracing::info!("response: {}", generated_text);
+    tracing::Span::current().record("total_time", format!("{:?}", total_time));
+    tracing::Span::current().record("validation_time", format!("{:?}", validation_time));
+    tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
+    tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
+    tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
+    tracing::info!("Output: {}", response.output);

     // Send response
-    let response = vec![GeneratedText { generated_text }];
-    Ok(Json(response))
+    let response = vec![GeneratedText {
+        generated_text: response.output,
+    }];
+    Ok((headers, Json(response)))
 }

 /// Serving method

@@ -105,14 +144,14 @@ pub async fn run(
     max_concurrent_requests: usize,
     max_input_length: usize,
     max_batch_size: usize,
-    max_waiting_time: Duration,
+    max_waiting_tokens: usize,
     client: ShardedClient,
     tokenizer: Tokenizer,
     validation_workers: usize,
     addr: SocketAddr,
 ) {
     // Create state
-    let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
+    let batcher = Batcher::new(client, max_batch_size, max_waiting_tokens);
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let shared_state = ServerState {
         validation,
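The `InferResponse` timestamps let `generate` split a request into validation, queue, and inference phases and derive a per-token time from `max_new_tokens`, which is what ends up in the `x-*` response headers. A minimal sketch of just that arithmetic, using `std::time::Instant` and illustrative values (the router itself uses `tokio::time::Instant`):

    use std::time::{Duration, Instant};

    // Mirror of the timestamp fields the batcher fills in; `output` is omitted here.
    struct InferResponse {
        queued: Instant,
        start: Instant,
        end: Instant,
    }

    fn main() {
        let start_time = Instant::now();
        // Illustrative timestamps: validation took ~2 ms, queueing ~8 ms, decoding ~500 ms.
        let response = InferResponse {
            queued: start_time + Duration::from_millis(2),
            start: start_time + Duration::from_millis(10),
            end: start_time + Duration::from_millis(510),
        };
        let max_new_tokens: u32 = 100;

        let validation_time = response.queued - start_time;   // request received -> queued
        let queue_time = response.start - response.queued;    // queued -> added to a batch
        let inference_time = response.end - response.start;   // batch start -> last token
        let time_per_token = inference_time / max_new_tokens; // 500 ms / 100 = 5 ms

        println!("{validation_time:?} {queue_time:?} {inference_time:?} {time_per_token:?}");
    }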
router/src/validation.rs

@@ -127,7 +127,10 @@ fn validation_worker(
             if input_length > max_input_length {
                 response_tx
-                    .send(Err(ValidationError::InputLength(input_length, max_input_length)))
+                    .send(Err(ValidationError::InputLength(
+                        input_length,
+                        max_input_length,
+                    )))
                     .unwrap_or(());
                 continue;
             }