Unverified Commit 313194f6 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(server): support repetition penalty (#47)

parent 2ad895a6
...@@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets. ...@@ -15,7 +15,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
- [Safetensors](https://github.com/huggingface/safetensors) weight loading - [Safetensors](https://github.com/huggingface/safetensors) weight loading
- 45ms per token generation for BLOOM with 8xA100 80GB - 45ms per token generation for BLOOM with 8xA100 80GB
- Logits warpers (temperature scaling, topk ...) - Logits warpers (temperature scaling, topk, repetition penalty ...)
- Stop sequences - Stop sequences
- Log probabilities - Log probabilities
......
...@@ -38,6 +38,8 @@ message NextTokenChooserParameters { ...@@ -38,6 +38,8 @@ message NextTokenChooserParameters {
bool do_sample = 4; bool do_sample = 4;
/// random seed for sampling /// random seed for sampling
uint64 seed = 5; uint64 seed = 5;
/// repetition penalty
float repetition_penalty = 6;
} }
message StoppingCriteriaParameters { message StoppingCriteriaParameters {
......
...@@ -13,6 +13,8 @@ use validation::Validation; ...@@ -13,6 +13,8 @@ use validation::Validation;
pub(crate) struct GenerateParameters { pub(crate) struct GenerateParameters {
#[serde(default = "default_temperature")] #[serde(default = "default_temperature")]
pub temperature: f32, pub temperature: f32,
#[serde(default = "default_repetition_penalty")]
pub repetition_penalty: f32,
#[serde(default = "default_top_k")] #[serde(default = "default_top_k")]
pub top_k: i32, pub top_k: i32,
#[serde(default = "default_top_p")] #[serde(default = "default_top_p")]
...@@ -32,6 +34,9 @@ pub(crate) struct GenerateParameters { ...@@ -32,6 +34,9 @@ pub(crate) struct GenerateParameters {
fn default_temperature() -> f32 { fn default_temperature() -> f32 {
1.0 1.0
} }
/// Serde default for `repetition_penalty`: 1.0 disables the penalty.
fn default_repetition_penalty() -> f32 { 1.0 }
fn default_top_k() -> i32 { fn default_top_k() -> i32 {
0 0
...@@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 { ...@@ -52,6 +57,7 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters { fn default_parameters() -> GenerateParameters {
GenerateParameters { GenerateParameters {
temperature: default_temperature(), temperature: default_temperature(),
repetition_penalty: default_repetition_penalty(),
top_k: default_top_k(), top_k: default_top_k(),
top_p: default_top_p(), top_p: default_top_p(),
do_sample: default_do_sample(), do_sample: default_do_sample(),
......
...@@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe ...@@ -33,6 +33,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
inputs: "liveness".to_string(), inputs: "liveness".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
temperature: 1.0, temperature: 1.0,
repetition_penalty: 1.0,
top_k: 0, top_k: 0,
top_p: 1.0, top_p: 1.0,
do_sample: false, do_sample: false,
......
...@@ -113,6 +113,9 @@ fn validate( ...@@ -113,6 +113,9 @@ fn validate(
if request.parameters.temperature <= 0.0 { if request.parameters.temperature <= 0.0 {
return Err(ValidationError::Temperature); return Err(ValidationError::Temperature);
} }
if request.parameters.repetition_penalty <= 0.0 {
return Err(ValidationError::RepetitionPenalty);
}
if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 { if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
return Err(ValidationError::TopP); return Err(ValidationError::TopP);
} }
...@@ -146,6 +149,7 @@ fn validate( ...@@ -146,6 +149,7 @@ fn validate(
// Return ValidGenerateRequest // Return ValidGenerateRequest
let GenerateParameters { let GenerateParameters {
temperature, temperature,
repetition_penalty,
top_k, top_k,
top_p, top_p,
do_sample, do_sample,
...@@ -156,6 +160,7 @@ fn validate( ...@@ -156,6 +160,7 @@ fn validate(
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
temperature, temperature,
repetition_penalty,
top_k: top_k as u32, top_k: top_k as u32,
top_p, top_p,
do_sample, do_sample,
...@@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest { ...@@ -195,6 +200,8 @@ pub(crate) struct ValidGenerateRequest {
pub enum ValidationError { pub enum ValidationError {
#[error("temperature must be strictly positive")] #[error("temperature must be strictly positive")]
Temperature, Temperature,
#[error("repetition_penalty must be strictly positive")]
RepetitionPenalty,
#[error("top_p must be > 0.0 and <= 1.0")] #[error("top_p must be > 0.0 and <= 1.0")]
TopP, TopP,
#[error("top_k must be strictly positive")] #[error("top_k must be strictly positive")]
......
...@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2 ...@@ -7,6 +7,7 @@ from text_generation.pb import generate_pb2
def default_pb_parameters(): def default_pb_parameters():
return generate_pb2.NextTokenChooserParameters( return generate_pb2.NextTokenChooserParameters(
temperature=1.0, temperature=1.0,
repetition_penalty=1.0,
top_k=0, top_k=0,
top_p=1.0, top_p=1.0,
do_sample=False, do_sample=False,
......
...@@ -336,7 +336,7 @@ class CausalLM(Model): ...@@ -336,7 +336,7 @@ class CausalLM(Model):
all_input_ids, all_input_ids,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
tokens, logprobs = next_token_chooser(all_input_ids, logits) tokens, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits)
next_token_id = tokens[-1].view(1, 1) next_token_id = tokens[-1].view(1, 1)
# Append next token to all tokens # Append next token to all tokens
......
...@@ -418,7 +418,9 @@ class Seq2SeqLM(Model): ...@@ -418,7 +418,9 @@ class Seq2SeqLM(Model):
decoder_input_ids, decoder_input_ids,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
next_token_id, logprobs = next_token_chooser(decoder_input_ids, logits) next_token_id, logprobs = next_token_chooser(
decoder_input_ids.view(1, -1), logits
)
# Append next token to decoder tokens # Append next token to decoder tokens
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id]) decoder_input_ids = torch.cat([decoder_input_ids, next_token_id])
......
...@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple ...@@ -17,6 +17,7 @@ from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.generation.logits_process import ( from transformers.generation.logits_process import (
LogitsProcessorList, LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopPLogitsWarper, TopPLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
...@@ -48,6 +49,7 @@ class NextTokenChooser: ...@@ -48,6 +49,7 @@ class NextTokenChooser:
def __init__( def __init__(
self, self,
temperature=1.0, temperature=1.0,
repetition_penalty=1.0,
top_k=None, top_k=None,
top_p=None, top_p=None,
do_sample=False, do_sample=False,
...@@ -68,6 +70,9 @@ class NextTokenChooser: ...@@ -68,6 +70,9 @@ class NextTokenChooser:
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p)) warpers.append(TopPLogitsWarper(top_p=top_p))
sampling = True sampling = True
if repetition_penalty is not None and repetition_penalty != 1.0:
warpers.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
sampling = True
self.warpers = warpers self.warpers = warpers
self.choice = Sampling(seed, device) if sampling else Greedy() self.choice = Sampling(seed, device) if sampling else Greedy()
...@@ -75,8 +80,10 @@ class NextTokenChooser: ...@@ -75,8 +80,10 @@ class NextTokenChooser:
def __call__(self, input_ids, scores): def __call__(self, input_ids, scores):
# Warp logits # Warp logits
scores = self.warpers(input_ids, scores) scores = self.warpers(input_ids, scores)
# Compute logprobs # Compute logprobs
logprobs = torch.log_softmax(scores, -1) logprobs = torch.log_softmax(scores, -1)
# Choose tokens # Choose tokens
next_ids = self.choice(scores) next_ids = self.choice(scores)
return next_ids, logprobs return next_ids, logprobs
...@@ -87,6 +94,7 @@ class NextTokenChooser: ...@@ -87,6 +94,7 @@ class NextTokenChooser:
) -> "NextTokenChooser": ) -> "NextTokenChooser":
return NextTokenChooser( return NextTokenChooser(
temperature=pb.temperature, temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
top_k=pb.top_k, top_k=pb.top_k,
top_p=pb.top_p, top_p=pb.top_p,
do_sample=pb.do_sample, do_sample=pb.do_sample,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment