Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
5437d49b
Unverified
Commit
5437d49b
authored
Feb 15, 2023
by
OlivierDehaene
Committed by
GitHub
Feb 15, 2023
Browse files
feat(router): add max_total_tokens and empty_input validation (#68)
closes #65
parent
68455353
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
17 deletions
+74
-17
router/src/main.rs
router/src/main.rs
+8
-0
router/src/server.rs
router/src/server.rs
+10
-2
router/src/validation.rs
router/src/validation.rs
+56
-15
No files found.
router/src/main.rs
View file @
5437d49b
...
...
@@ -20,8 +20,12 @@ use tracing_subscriber::{EnvFilter, Layer};
struct
Args
{
#[clap(default_value
=
"128"
,
long,
env)]
max_concurrent_requests
:
usize
,
#[clap(default_value
=
"4"
,
long,
env)]
max_stop_sequences
:
usize
,
#[clap(default_value
=
"1000"
,
long,
env)]
max_input_length
:
usize
,
#[clap(default_value
=
"1512"
,
long,
env)]
max_total_tokens
:
usize
,
#[clap(default_value
=
"32"
,
long,
env)]
max_batch_size
:
usize
,
#[clap(default_value
=
"20"
,
long,
env)]
...
...
@@ -46,7 +50,9 @@ fn main() -> Result<(), std::io::Error> {
// Pattern match configuration
let
Args
{
max_concurrent_requests
,
max_stop_sequences
,
max_input_length
,
max_total_tokens
,
max_batch_size
,
max_waiting_tokens
,
port
,
...
...
@@ -92,7 +98,9 @@ fn main() -> Result<(), std::io::Error> {
// Run server
server
::
run
(
max_concurrent_requests
,
max_stop_sequences
,
max_input_length
,
max_total_tokens
,
max_batch_size
,
max_waiting_tokens
,
sharded_client
,
...
...
router/src/server.rs
View file @
5437d49b
...
...
@@ -28,7 +28,7 @@ use utoipa_swagger_ui::SwaggerUi;
async
fn
health
(
infer
:
Extension
<
Infer
>
)
->
Result
<
(),
(
StatusCode
,
Json
<
ErrorResponse
>
)
>
{
// TODO: while this is the best health check we can do, it is a bit on the heavy side and might
// be a bit too slow for a health check.
// What we should do instead i
f
check if the gRPC channels are still healthy.
// What we should do instead i
s
check if the gRPC channels are still healthy.
// Send a small inference request
infer
...
...
@@ -291,7 +291,9 @@ async fn generate_stream(
#[allow(clippy::too_many_arguments)]
pub
async
fn
run
(
max_concurrent_requests
:
usize
,
max_stop_sequences
:
usize
,
max_input_length
:
usize
,
max_total_tokens
:
usize
,
max_batch_size
:
usize
,
max_waiting_tokens
:
usize
,
client
:
ShardedClient
,
...
...
@@ -333,7 +335,13 @@ pub async fn run(
struct
ApiDoc
;
// Create state
let
validation
=
Validation
::
new
(
validation_workers
,
tokenizer
,
max_input_length
);
let
validation
=
Validation
::
new
(
validation_workers
,
tokenizer
,
max_stop_sequences
,
max_input_length
,
max_total_tokens
,
);
let
infer
=
Infer
::
new
(
client
,
validation
,
...
...
router/src/validation.rs
View file @
5437d49b
use
crate
::
validation
::
ValidationError
::
EmptyInput
;
/// Payload validation logic
use
crate
::{
GenerateParameters
,
GenerateRequest
};
use
rand
::
rngs
::
ThreadRng
;
...
...
@@ -8,9 +9,6 @@ use tokenizers::tokenizer::Tokenizer;
use
tokio
::
sync
::{
mpsc
,
oneshot
};
use
tracing
::{
instrument
,
Span
};
const
MAX_MAX_NEW_TOKENS
:
u32
=
512
;
const
MAX_STOP_SEQUENCES
:
usize
=
4
;
/// Validation
#[derive(Debug,
Clone)]
pub
struct
Validation
{
...
...
@@ -19,7 +17,13 @@ pub struct Validation {
}
impl
Validation
{
pub
(
crate
)
fn
new
(
workers
:
usize
,
tokenizer
:
Tokenizer
,
max_input_length
:
usize
)
->
Self
{
pub
(
crate
)
fn
new
(
workers
:
usize
,
tokenizer
:
Tokenizer
,
max_stop_sequences
:
usize
,
max_input_length
:
usize
,
max_total_tokens
:
usize
,
)
->
Self
{
// Create channel
let
(
validation_sender
,
validation_receiver
)
=
mpsc
::
channel
(
128
);
...
...
@@ -27,7 +31,9 @@ impl Validation {
tokio
::
spawn
(
validation_task
(
workers
,
tokenizer
,
max_stop_sequences
,
max_input_length
,
max_total_tokens
,
validation_receiver
,
));
...
...
@@ -61,7 +67,9 @@ impl Validation {
async
fn
validation_task
(
workers
:
usize
,
tokenizer
:
Tokenizer
,
max_stop_sequences
:
usize
,
max_input_length
:
usize
,
max_total_tokens
:
usize
,
mut
receiver
:
mpsc
::
Receiver
<
ValidationRequest
>
,
)
{
let
mut
workers_senders
=
Vec
::
with_capacity
(
workers
);
...
...
@@ -75,7 +83,13 @@ async fn validation_task(
// Spawn worker
tokio
::
task
::
spawn_blocking
(
move
||
{
validation_worker
(
tokenizer_clone
,
max_input_length
,
worker_receiver
)
validation_worker
(
tokenizer_clone
,
max_stop_sequences
,
max_input_length
,
max_total_tokens
,
worker_receiver
,
)
});
}
...
...
@@ -95,7 +109,9 @@ async fn validation_task(
/// the tokenizer
fn
validation_worker
(
tokenizer
:
Tokenizer
,
max_stop_sequences
:
usize
,
max_input_length
:
usize
,
max_total_tokens
:
usize
,
mut
receiver
:
mpsc
::
Receiver
<
ValidationRequest
>
,
)
{
// Seed rng
...
...
@@ -106,7 +122,15 @@ fn validation_worker(
parent_span
.in_scope
(||
{
response_tx
.send
(
validate
(
request
,
&
tokenizer
,
max_input_length
,
&
mut
rng
)
.map_err
(|
err
|
{
validate
(
request
,
&
tokenizer
,
max_stop_sequences
,
max_input_length
,
max_total_tokens
,
&
mut
rng
,
)
.map_err
(|
err
|
{
tracing
::
error!
(
"{err}"
);
err
}),
...
...
@@ -119,7 +143,9 @@ fn validation_worker(
fn
validate
(
request
:
GenerateRequest
,
tokenizer
:
&
Tokenizer
,
max_stop_sequences
:
usize
,
max_input_length
:
usize
,
max_total_tokens
:
usize
,
rng
:
&
mut
ThreadRng
,
)
->
Result
<
ValidGenerateRequest
,
ValidationError
>
{
let
GenerateParameters
{
...
...
@@ -161,13 +187,13 @@ fn validate(
}
}
?
;
if
max_new_tokens
==
0
||
max_new_tokens
>
MAX_MAX_NEW_TOKENS
{
return
Err
(
ValidationError
::
MaxNewTokens
(
MAX_MAX_NEW_TOKENS
)
);
if
max_new_tokens
==
0
{
return
Err
(
ValidationError
::
MaxNewTokens
);
}
if
stop_sequences
.len
()
>
MAX_STOP_SEQUENCES
{
if
stop_sequences
.len
()
>
max_stop_sequences
{
return
Err
(
ValidationError
::
StopSequence
(
MAX_STOP_SEQUENCES
,
max_stop_sequences
,
stop_sequences
.len
(),
));
}
...
...
@@ -178,13 +204,24 @@ fn validate(
Some
(
seed
)
=>
seed
,
};
// Check if inputs is empty
if
request
.inputs
.is_empty
()
{
return
Err
(
EmptyInput
);
}
// Get the number of tokens in the input
match
tokenizer
.encode
(
request
.inputs
.clone
(),
true
)
{
Ok
(
encoding
)
=>
{
let
input_length
=
encoding
.len
();
let
total_tokens
=
input_length
+
max_new_tokens
as
usize
;
if
input_length
>
max_input_length
{
Err
(
ValidationError
::
InputLength
(
input_length
,
max_input_length
))
Err
(
ValidationError
::
InputLength
(
max_input_length
,
input_length
))
}
else
if
total_tokens
>
max_total_tokens
{
Err
(
ValidationError
::
MaxTotalTokens
(
max_total_tokens
,
input_length
,
max_new_tokens
,
))
}
else
{
// Return ValidGenerateRequest
let
parameters
=
NextTokenChooserParameters
{
...
...
@@ -236,10 +273,14 @@ pub enum ValidationError {
TopP
,
#[error(
"top_k must be strictly positive"
)]
TopK
,
#[error(
"max_new_tokens must be strictly positive and <= {0}"
)]
MaxNewTokens
(
u32
),
#[error(
"inputs must have less than {1} tokens. Given: {0}"
)]
#[error(
"max_new_tokens must be strictly positive"
)]
MaxNewTokens
,
#[error(
"input tokens + max_new_tokens must be <= {0}. Given: {1} input tokens and {2} max_new_tokens"
)]
MaxTotalTokens
(
usize
,
usize
,
u32
),
#[error(
"inputs must have less than {0} tokens. Given: {1}"
)]
InputLength
(
usize
,
usize
),
#[error(
"inputs cannot be empty"
)]
EmptyInput
,
#[error(
"stop supports up to {0} stop sequences. Given: {1}"
)]
StopSequence
(
usize
,
usize
),
#[error(
"tokenizer error {0}"
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment