OpenDAS / text-generation-inference — commit 313194f6 (unverified)

feat(server): support repetition penalty (#47)

Authored Feb 01, 2023 by OlivierDehaene; committed via GitHub on Feb 01, 2023. Parent: 2ad895a6.

Showing 9 changed files with 30 additions and 3 deletions (+30 / -3):
README.md                                      +1  -1
proto/generate.proto                           +2  -0
router/src/lib.rs                              +6  -0
router/src/server.rs                           +1  -0
router/src/validation.rs                       +7  -0
server/tests/conftest.py                       +1  -0
server/text_generation/models/causal_lm.py     +1  -1
server/text_generation/models/seq2seq_lm.py    +3  -1
server/text_generation/utils.py                +8  -0
README.md

@@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
-- Logits warpers (temperature scaling, topk ...)
+- Logits warpers (temperature scaling, topk, repetition penalty ...)
 - Stop sequences
 - Log probabilities
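For context on the new README bullet: the penalty wired through in this commit is the CTRL-style repetition penalty that transformers implements as RepetitionPenaltyLogitsProcessor (imported in server/text_generation/utils.py below). The idea, as a minimal sketch rather than the library code: logits of tokens already present in the sequence are divided by the penalty when positive and multiplied by it when negative, so a penalty above 1.0 makes repeats less likely.

import torch

def apply_repetition_penalty(input_ids, scores, penalty):
    # input_ids: (batch, seq_len) token ids already generated
    # scores:    (batch, vocab_size) next-token logits
    score = torch.gather(scores, 1, input_ids)
    # Shrink positive logits and push negative logits further down (penalty > 1.0)
    score = torch.where(score < 0, score * penalty, score / penalty)
    scores.scatter_(1, input_ids, score)
    return scores

scores = torch.tensor([[2.0, -1.0, 0.5]])
seen = torch.tensor([[0, 1]])
apply_repetition_penalty(seen, scores, penalty=1.2)
# token 0: 2.0 / 1.2, token 1: -1.0 * 1.2, token 2 untouched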
proto/generate.proto

@@ -38,6 +38,8 @@ message NextTokenChooserParameters {
     bool do_sample = 4;
     /// random seed for sampling
     uint64 seed = 5;
+    /// repetition penalty
+    float repetition_penalty = 6;
 }

 message StoppingCriteriaParameters {
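Because this is proto3, the new field is wire-compatible with older clients, but an unset float decodes as 0.0, not 1.0, so callers need to set repetition_penalty explicitly (the router's serde default, added below, takes care of that). A small sketch of building the updated message from Python, using the generated module the tests import:

from text_generation.pb import generate_pb2

params = generate_pb2.NextTokenChooserParameters(
    temperature=0.9,
    repetition_penalty=1.2,  # new field 6; would decode as 0.0 if left unset
    top_k=50,
    top_p=0.95,
    do_sample=True,
    seed=42,
)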
router/src/lib.rs

@@ -13,6 +13,8 @@ use validation::Validation;
 pub(crate) struct GenerateParameters {
     #[serde(default = "default_temperature")]
     pub temperature: f32,
+    #[serde(default = "default_repetition_penalty")]
+    pub repetition_penalty: f32,
     #[serde(default = "default_top_k")]
     pub top_k: i32,
     #[serde(default = "default_top_p")]

@@ -32,6 +34,9 @@ pub(crate) struct GenerateParameters {
 fn default_temperature() -> f32 {
     1.0
 }

+fn default_repetition_penalty() -> f32 {
+    1.0
+}

 fn default_top_k() -> i32 {
     0

@@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 {
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
         temperature: default_temperature(),
+        repetition_penalty: default_repetition_penalty(),
         top_k: default_top_k(),
         top_p: default_top_p(),
         do_sample: default_do_sample(),
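With the serde defaults above, the HTTP field stays optional: omitting it yields 1.0, which the server treats as a no-op. A hedged sketch of a request that opts in (the address and /generate route here are deployment assumptions, not part of this diff):

import requests

response = requests.post(
    "http://localhost:3000/generate",  # assumed local router address and route
    json={
        "inputs": "The quick brown fox",
        "parameters": {
            "temperature": 0.9,
            "repetition_penalty": 1.2,  # omitted -> default_repetition_penalty() = 1.0
            "top_k": 50,
            "top_p": 0.95,
            "do_sample": True,
        },
    },
)
print(response.json())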
router/src/server.rs

@@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         inputs: "liveness".to_string(),
         parameters: GenerateParameters {
             temperature: 1.0,
+            repetition_penalty: 1.0,
             top_k: 0,
             top_p: 1.0,
             do_sample: false,
router/src/validation.rs

@@ -113,6 +113,9 @@ fn validate(
     if request.parameters.temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
+    if request.parameters.repetition_penalty <= 0.0 {
+        return Err(ValidationError::RepetitionPenalty);
+    }
     if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
         return Err(ValidationError::TopP);
     }

@@ -146,6 +149,7 @@ fn validate(
     // Return ValidGenerateRequest
     let GenerateParameters {
         temperature,
+        repetition_penalty,
         top_k,
         top_p,
         do_sample,

@@ -156,6 +160,7 @@ fn validate(
     let parameters = NextTokenChooserParameters {
         temperature,
+        repetition_penalty,
         top_k: top_k as u32,
         top_p,
         do_sample,

@@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest {
 pub enum ValidationError {
     #[error("temperature must be strictly positive")]
     Temperature,
+    #[error("repetition_penalty must be strictly positive")]
+    RepetitionPenalty,
     #[error("top_p must be > 0.0 and <= 1.0")]
     TopP,
     #[error("top_k must be strictly positive")]
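The new guard mirrors the temperature check: any strictly positive value passes, including values in (0, 1), which encourage rather than discourage repetition; only zero and negatives are rejected. A hypothetical Python mirror of the same rule, for illustration only:

def validate_repetition_penalty(value: float) -> float:
    # Mirrors the router's check: strictly positive, no upper bound.
    # 1.0 is a no-op; > 1.0 penalizes repeats; (0, 1) rewards them.
    if value <= 0.0:
        raise ValueError("repetition_penalty must be strictly positive")
    return value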
server/tests/conftest.py

@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
+        repetition_penalty=1.0,
         top_k=0,
         top_p=1.0,
         do_sample=False,
server/text_generation/models/causal_lm.py

@@ -336,7 +336,7 @@ class CausalLM(Model):
             all_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            tokens, logprobs = next_token_chooser(all_input_ids, logits)
+            tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
             next_token_id = tokens[-1].view(1, 1)

             # Append next token to all tokens
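The .view(1, -1) reshape is what lets the new processor run at all: RepetitionPenaltyLogitsProcessor looks up already-seen tokens with torch.gather(scores, 1, input_ids), which requires input_ids to have the same number of dimensions as the (batch, vocab) logits. A small sketch of that shape contract:

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)

all_input_ids = torch.tensor([5, 7, 5])  # 1-D ids for a single sequence
logits = torch.randn(1, 32)              # (batch=1, vocab_size)

# gather() fails on 1-D ids against 2-D scores; view(1, -1) supplies
# the (batch, seq_len) shape the processor expects.
warped = processor(all_input_ids.view(1, -1), logits)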
server/text_generation/models/seq2seq_lm.py

@@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
             decoder_input_ids,
         ) in enumerate(iterator):
             # Select next token
-            next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits)
+            next_token_id, logprobs = next_token_chooser(
+                decoder_input_ids.view(1, -1), logits
+            )

             # Append next token to decoder tokens
             decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
...
server/text_generation/utils.py
View file @
313194f6
...
@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
...
@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
from
transformers.generation.logits_process
import
(
from
transformers.generation.logits_process
import
(
LogitsProcessorList
,
LogitsProcessorList
,
RepetitionPenaltyLogitsProcessor
,
TemperatureLogitsWarper
,
TemperatureLogitsWarper
,
TopPLogitsWarper
,
TopPLogitsWarper
,
TopKLogitsWarper
,
TopKLogitsWarper
,
...
@@ -48,6 +49,7 @@ class NextTokenChooser:
...
@@ -48,6 +49,7 @@ class NextTokenChooser:
def
__init__
(
def
__init__
(
self
,
self
,
temperature
=
1.0
,
temperature
=
1.0
,
repetition_penalty
=
1.0
,
top_k
=
None
,
top_k
=
None
,
top_p
=
None
,
top_p
=
None
,
do_sample
=
False
,
do_sample
=
False
,
...
@@ -68,6 +70,9 @@ class NextTokenChooser:
...
@@ -68,6 +70,9 @@ class NextTokenChooser:
if
top_p
is
not
None
and
top_p
<
1.0
:
if
top_p
is
not
None
and
top_p
<
1.0
:
warpers
.
append
(
TopPLogitsWarper
(
top_p
=
top_p
))
warpers
.
append
(
TopPLogitsWarper
(
top_p
=
top_p
))
sampling
=
True
sampling
=
True
if
repetition_penalty
is
not
None
and
repetition_penalty
!=
1.0
:
warpers
.
append
(
RepetitionPenaltyLogitsProcessor
(
penalty
=
repetition_penalty
))
sampling
=
True
self
.
warpers
=
warpers
self
.
warpers
=
warpers
self
.
choice
=
Sampling
(
seed
,
device
)
if
sampling
else
Greedy
()
self
.
choice
=
Sampling
(
seed
,
device
)
if
sampling
else
Greedy
()
...
@@ -75,8 +80,10 @@ class NextTokenChooser:
...
@@ -75,8 +80,10 @@ class NextTokenChooser:
def
__call__
(
self
,
input_ids
,
scores
):
def
__call__
(
self
,
input_ids
,
scores
):
# Warp logits
# Warp logits
scores
=
self
.
warpers
(
input_ids
,
scores
)
scores
=
self
.
warpers
(
input_ids
,
scores
)
# Compute logprobs
# Compute logprobs
logprobs
=
torch
.
log_softmax
(
scores
,
-
1
)
logprobs
=
torch
.
log_softmax
(
scores
,
-
1
)
# Choose tokens
# Choose tokens
next_ids
=
self
.
choice
(
scores
)
next_ids
=
self
.
choice
(
scores
)
return
next_ids
,
logprobs
return
next_ids
,
logprobs
...
@@ -87,6 +94,7 @@ class NextTokenChooser:
...
@@ -87,6 +94,7 @@ class NextTokenChooser:
)
->
"NextTokenChooser"
:
)
->
"NextTokenChooser"
:
return
NextTokenChooser
(
return
NextTokenChooser
(
temperature
=
pb
.
temperature
,
temperature
=
pb
.
temperature
,
repetition_penalty
=
pb
.
repetition_penalty
,
top_k
=
pb
.
top_k
,
top_k
=
pb
.
top_k
,
top_p
=
pb
.
top_p
,
top_p
=
pb
.
top_p
,
do_sample
=
pb
.
do_sample
,
do_sample
=
pb
.
do_sample
,
...
...
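Putting the server pieces together, a usage sketch of the updated chooser (hedged: the seed and device defaults are not shown in this diff, and note that a penalty != 1.0 flips the chooser from greedy to sampling, per the constructor above):

import torch
from text_generation.utils import NextTokenChooser

# Token 11 appears twice in the context, so with penalty 1.3 its logit is
# down-weighted before the choice is made.
chooser = NextTokenChooser(temperature=1.0, repetition_penalty=1.3)

input_ids = torch.tensor([[11, 42, 11]])  # (batch=1, seq_len)
logits = torch.randn(1, 100)              # (batch=1, vocab_size)

next_ids, logprobs = chooser(input_ids, logits)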