OpenDAS / text-generation-inference

Commit 1a2d6825 (unverified)
Authored Mar 09, 2023 by OlivierDehaene; committed by GitHub on Mar 09, 2023

feat: support typical sampling (#114)
closes #112
Parent: 941cd42e
Showing 7 changed files with 55 additions and 18 deletions
proto/generate.proto                            +6  -4
router/src/lib.rs                              +10  -0
router/src/queue.rs                             +1  -0
router/src/server.rs                            +1  -0
router/src/validation.rs                       +30 -14
server/tests/conftest.py                        +1  -0
server/text_generation_server/utils/tokens.py   +6  -0
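
For reviewers new to the feature: typical sampling (locally typical sampling, Meister et al. 2022) keeps the tokens whose surprisal (-log p) is closest to the conditional entropy of the next-token distribution until typical_p of the probability mass is covered, then samples from the renormalized set. The server does not reimplement this; it chains transformers' TypicalLogitsWarper (see the tokens.py diff below). The following Python sketch of the underlying filter is illustrative only:

import torch

def typical_filter(logits: torch.Tensor, mass: float = 0.95) -> torch.Tensor:
    # Keep the tokens whose surprisal is closest to the entropy H of the
    # distribution until `mass` probability is covered; mask the rest.
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum()
    deviation = (entropy + log_probs).abs()   # |(-log p) - H|
    order = torch.argsort(deviation)          # most typical first
    cutoff = int((probs[order].cumsum(0) < mass).sum()) + 1
    keep = order[:cutoff]
    filtered = torch.full_like(logits, float("-inf"))
    filtered[keep] = logits[keep]
    return filtered

# Example: draw the next token from the typical set only.
logits = torch.randn(32)
probs = torch.softmax(typical_filter(logits), dim=-1)
token = torch.multinomial(probs, num_samples=1)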
proto/generate.proto

@@ -34,14 +34,16 @@ message NextTokenChooserParameters {
     uint32 top_k = 2;
     /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
     float top_p = 3;
+    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
+    float typical_p = 4;
     /// apply sampling on the logits
-    bool do_sample = 4;
+    bool do_sample = 5;
     /// random seed for sampling
-    uint64 seed = 5;
+    uint64 seed = 6;
     /// repetition penalty
-    float repetition_penalty = 6;
+    float repetition_penalty = 7;
     /// token watermarking using "A Watermark for Large Language Models"
-    bool watermark = 7;
+    bool watermark = 8;
 }

 message StoppingCriteriaParameters {
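
The new field takes tag 4, so every later field shifts up by one. A sketch of building the updated message from Python, assuming the generated stubs are importable as text_generation_server.pb.generate_pb2 (the path the server tests use):

from text_generation_server.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    typical_p=0.95,          # new field, tag 4
    do_sample=True,          # renumbered 4 -> 5
    seed=42,                 # renumbered 5 -> 6
    repetition_penalty=1.0,  # renumbered 6 -> 7
    watermark=False,         # renumbered 7 -> 8
)

Renumbering existing tags breaks protobuf wire compatibility with anything built against the old schema; that is acceptable here because the router and the Python server are released in lockstep.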
router/src/lib.rs

@@ -41,6 +41,15 @@ pub(crate) struct GenerateParameters {
     )]
     pub top_p: Option<f32>,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0.0,
+        maximum = 1.0,
+        nullable = true,
+        default = "null",
+        example = 0.95
+    )]
+    pub typical_p: Option<f32>,
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]

@@ -72,6 +81,7 @@ fn default_parameters() -> GenerateParameters {
         repetition_penalty: None,
         top_k: None,
         top_p: None,
+        typical_p: None,
         do_sample: false,
         max_new_tokens: default_max_new_tokens(),
         return_full_text: None,
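
With the schema change above, typical_p becomes an optional parameter of the router's /generate route. A hedged usage sketch in Python (the host and port are assumptions; adjust them to your deployment):

import requests

response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {
            "typical_p": 0.95,   # must satisfy 0.0 < typical_p < 1.0
            "do_sample": True,
            "max_new_tokens": 20,
        },
    },
)
print(response.json())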
router/src/queue.rs

@@ -231,6 +231,7 @@ mod tests {
             temperature: 0.0,
             top_k: 0,
             top_p: 0.0,
+            typical_p: 0.0,
             do_sample: false,
             seed: 0,
             repetition_penalty: 0.0,
router/src/server.rs

@@ -68,6 +68,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         repetition_penalty: None,
         top_k: None,
         top_p: None,
+        typical_p: None,
         do_sample: false,
         max_new_tokens: 1,
         return_full_text: None,
router/src/validation.rs

@@ -153,6 +153,7 @@ fn validate(
         repetition_penalty,
         top_k,
         top_p,
+        typical_p,
         do_sample,
         max_new_tokens,
         stop: stop_sequences,

@@ -171,22 +172,34 @@ fn validate(
         return Err(ValidationError::RepetitionPenalty);
     }

-    let top_p = top_p.unwrap_or(1.0);
-    if top_p <= 0.0 || top_p > 1.0 {
-        return Err(ValidationError::TopP);
-    }
-
-    // Different because the proto default value is 0 while it is not a valid value
+    // Different because the proto default value is not a valid value
     // for the user
-    let top_k: u32 = match top_k {
-        None => Ok(0),
-        Some(top_k) => {
-            if top_k <= 0 {
-                return Err(ValidationError::TopK);
-            }
-            Ok(top_k as u32)
-        }
-    }?;
+    let top_p = top_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TopP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let typical_p = typical_p
+        .map(|value| {
+            if value <= 0.0 || value >= 1.0 {
+                return Err(ValidationError::TypicalP);
+            }
+            Ok(value)
+        })
+        .unwrap_or(Ok(1.0))?;
+
+    let top_k: u32 = top_k
+        .map(|value| {
+            if value <= 0 {
+                return Err(ValidationError::TopK);
+            }
+            Ok(value as u32)
+        })
+        .unwrap_or(Ok(0))?;

     if max_new_tokens == 0 {
         return Err(ValidationError::MaxNewTokens);

@@ -231,6 +244,7 @@ fn validate(
         repetition_penalty,
         top_k,
         top_p,
+        typical_p,
         do_sample,
         seed,
         watermark,

@@ -275,10 +289,12 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
-    #[error("`top_p` must be > 0.0 and <= 1.0")]
+    #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]
     TopK,
+    #[error("`typical_p` must be > 0.0 and < 1.0")]
+    TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
     MaxNewTokens,
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
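
Two behavioural details are worth calling out: top_p and typical_p now use exclusive upper bounds (an explicit 1.0 is rejected; omit the field to disable the warper, since None maps to the neutral value 1.0), and validation failures are returned as errors by the router. A sketch of the resulting boundary behaviour; the endpoint, port, and 422 status code are assumptions based on the router's usual error mapping:

import requests

def generate_status(parameters: dict) -> int:
    # Hypothetical helper: POST against a local router, return the status.
    response = requests.post(
        "http://127.0.0.1:8080/generate",
        json={"inputs": "Hello", "parameters": parameters},
    )
    return response.status_code

assert generate_status({"typical_p": 0.5, "max_new_tokens": 1}) == 200   # inside (0, 1)
assert generate_status({"typical_p": 1.0, "max_new_tokens": 1}) == 422   # bound is exclusive
assert generate_status({"max_new_tokens": 1}) == 200                     # omitted -> disabled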
server/tests/conftest.py

@@ -10,6 +10,7 @@ def default_pb_parameters():
         repetition_penalty=1.0,
         top_k=0,
         top_p=1.0,
+        typical_p=1.0,
         do_sample=False,
     )
server/text_generation_server/utils/tokens.py

@@ -6,6 +6,7 @@ from transformers import (
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
+    TypicalLogitsWarper,
     RepetitionPenaltyLogitsProcessor,
     PreTrainedTokenizerBase,
 )

@@ -41,6 +42,7 @@ class NextTokenChooser:
         repetition_penalty=1.0,
         top_k=None,
         top_p=None,
+        typical_p=None,
         do_sample=False,
         seed=0,
         device="cpu",

@@ -64,6 +66,9 @@ class NextTokenChooser:
         if top_p is not None and top_p < 1.0:
             warpers.append(TopPLogitsWarper(top_p=top_p))
             sampling = True
+        if typical_p is not None and typical_p < 1.0:
+            warpers.append(TypicalLogitsWarper(mass=typical_p))
+            sampling = True

         self.warpers = warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()

@@ -92,6 +97,7 @@ class NextTokenChooser:
             repetition_penalty=pb.repetition_penalty,
             top_k=pb.top_k,
             top_p=pb.top_p,
+            typical_p=pb.typical_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
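
The server-side change is deliberately thin: typical_p adds one more warper to the chain, and any value below 1.0 flips the chooser from Greedy to Sampling. A standalone sketch of what the new warper does, using the same transformers class with dummy shapes (illustrative only):

import torch
from transformers import TypicalLogitsWarper

warper = TypicalLogitsWarper(mass=0.95)   # mirrors typical_p=0.95

input_ids = torch.tensor([[0, 1, 2]])     # dummy prompt ids
scores = torch.randn(1, 32)               # dummy next-token logits
filtered = warper(input_ids, scores)

# Tokens outside the typical set are pushed to -inf, so Sampling can
# never draw them after the softmax.
print(int((filtered == -float("inf")).sum()), "of 32 tokens masked")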