Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
c4fb09f2
Unverified
Commit
c4fb09f2
authored
Apr 26, 2023
by
Nicolas Patry
Committed by
GitHub
Apr 26, 2023
Browse files
feat(router): add tests to validation (#237)
parent
77758f60
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
102 additions
and
10 deletions
+102
-10
.github/workflows/tests.yaml
.github/workflows/tests.yaml
+3
-0
router/src/lib.rs
router/src/lib.rs
+16
-0
router/src/queue.rs
router/src/queue.rs
+11
-0
router/src/server.rs
router/src/server.rs
+1
-0
router/src/validation.rs
router/src/validation.rs
+71
-10
No files found.
.github/workflows/tests.yaml
View file @
c4fb09f2
...
...
@@ -67,6 +67,9 @@ jobs:
run
:
|
pip install pytest
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -sv server/tests
-
name
:
Run Clippy
run
:
|
cargo clippy
-
name
:
Run Rust tests
run
:
|
cargo test
...
...
router/src/lib.rs
View file @
c4fb09f2
...
...
@@ -276,3 +276,19 @@ pub(crate) struct ErrorResponse {
pub
error
:
String
,
pub
error_type
:
String
,
}
#[cfg(test)]
mod
tests
{
use
std
::
io
::
Write
;
use
tokenizers
::
Tokenizer
;
pub
(
crate
)
async
fn
get_tokenizer
()
->
Tokenizer
{
if
!
std
::
path
::
Path
::
new
(
"tokenizer.json"
)
.exists
(){
let
content
=
reqwest
::
get
(
"https://huggingface.co/gpt2/raw/main/tokenizer.json"
)
.await
.unwrap
()
.bytes
()
.await
.unwrap
();
let
mut
file
=
std
::
fs
::
File
::
create
(
"tokenizer.json"
)
.unwrap
();
file
.write_all
(
&
content
)
.unwrap
();
}
Tokenizer
::
from_file
(
"tokenizer.json"
)
.unwrap
()
}
}
router/src/queue.rs
View file @
c4fb09f2
...
...
@@ -141,6 +141,7 @@ impl State {
// Get the next batch
fn
next_batch
(
&
mut
self
,
min_size
:
Option
<
usize
>
,
token_budget
:
u32
)
->
Option
<
NextBatch
>
{
if
self
.entries
.is_empty
()
{
return
None
;
}
...
...
@@ -430,7 +431,17 @@ mod tests {
let
(
entry3
,
_
guard3
)
=
default_entry
();
queue
.append
(
entry3
);
// Not enough requests pending
assert
!
(
queue
.next_batch
(
Some
(
2
),
2
)
.await
.is_none
());
// Not enough token budget
assert
!
(
queue
.next_batch
(
Some
(
1
),
0
)
.await
.is_none
());
// Ok
let
(
entries2
,
batch2
,
_
)
=
queue
.next_batch
(
Some
(
1
),
2
)
.await
.unwrap
();
assert_eq!
(
entries2
.len
(),
1
);
assert
!
(
entries2
.contains_key
(
&
2
));
assert
!
(
entries2
.get
(
&
2
)
.unwrap
()
.batch_time
.is_some
());
assert_eq!
(
batch2
.id
,
1
);
assert_eq!
(
batch2
.size
,
1
);
}
#[tokio::test]
...
...
router/src/server.rs
View file @
c4fb09f2
...
...
@@ -741,3 +741,4 @@ impl From<InferError> for Event {
.unwrap
()
}
}
router/src/validation.rs
View file @
c4fb09f2
...
...
@@ -382,7 +382,8 @@ pub enum ValidationError {
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
std
::
io
::
Write
;
use
crate
::
default_parameters
;
use
crate
::
tests
::
get_tokenizer
;
#[tokio::test]
async
fn
test_validation_max_new_tokens
(){
...
...
@@ -401,15 +402,6 @@ mod tests{
}
}
async
fn
get_tokenizer
()
->
Tokenizer
{
if
!
std
::
path
::
Path
::
new
(
"tokenizer.json"
)
.exists
(){
let
content
=
reqwest
::
get
(
"https://huggingface.co/gpt2/raw/main/tokenizer.json"
)
.await
.unwrap
()
.bytes
()
.await
.unwrap
();
let
mut
file
=
std
::
fs
::
File
::
create
(
"tokenizer.json"
)
.unwrap
();
file
.write_all
(
&
content
)
.unwrap
();
}
Tokenizer
::
from_file
(
"tokenizer.json"
)
.unwrap
()
}
#[tokio::test]
async
fn
test_validation_input_length
(){
let
tokenizer
=
Some
(
get_tokenizer
()
.await
);
...
...
@@ -426,4 +418,73 @@ mod tests{
_
=>
panic!
(
"Unexpected not max new tokens"
)
}
}
#[tokio::test]
async
fn
test_validation_best_of_sampling
(){
let
tokenizer
=
Some
(
get_tokenizer
()
.await
);
let
max_best_of
=
2
;
let
max_stop_sequence
=
3
;
let
max_input_length
=
4
;
let
max_total_tokens
=
5
;
let
workers
=
1
;
let
validation
=
Validation
::
new
(
workers
,
tokenizer
,
max_best_of
,
max_stop_sequence
,
max_input_length
,
max_total_tokens
);
match
validation
.validate
(
GenerateRequest
{
inputs
:
"Hello"
.to_string
(),
parameters
:
GenerateParameters
{
best_of
:
Some
(
2
),
do_sample
:
false
,
..
default_parameters
()
}
})
.await
{
Err
(
ValidationError
::
BestOfSampling
)
=>
(),
_
=>
panic!
(
"Unexpected not best of sampling"
)
}
}
#[tokio::test]
async
fn
test_validation_top_p
(){
let
tokenizer
=
Some
(
get_tokenizer
()
.await
);
let
max_best_of
=
2
;
let
max_stop_sequence
=
3
;
let
max_input_length
=
4
;
let
max_total_tokens
=
5
;
let
workers
=
1
;
let
validation
=
Validation
::
new
(
workers
,
tokenizer
,
max_best_of
,
max_stop_sequence
,
max_input_length
,
max_total_tokens
);
match
validation
.validate
(
GenerateRequest
{
inputs
:
"Hello"
.to_string
(),
parameters
:
GenerateParameters
{
top_p
:
Some
(
1.0
),
..
default_parameters
()
}
})
.await
{
Err
(
ValidationError
::
TopP
)
=>
(),
_
=>
panic!
(
"Unexpected top_p"
)
}
match
validation
.validate
(
GenerateRequest
{
inputs
:
"Hello"
.to_string
(),
parameters
:
GenerateParameters
{
top_p
:
Some
(
0.99
),
max_new_tokens
:
1
,
..
default_parameters
()
}
})
.await
{
Ok
(
_
)
=>
(),
_
=>
panic!
(
"Unexpected top_p error"
)
}
let
valid_request
=
validation
.validate
(
GenerateRequest
{
inputs
:
"Hello"
.to_string
(),
parameters
:
GenerateParameters
{
top_p
:
None
,
max_new_tokens
:
1
,
..
default_parameters
()
}
})
.await
.unwrap
();
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
assert_eq!
(
valid_request
.parameters.top_p
,
1.0
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment