Unverified Commit 211b54ac authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Rebased #617 (#868)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------
Co-authored-by: Vincent Brouwers <vincent.brouwers@ing.com>
parent 4486f78c
...@@ -37,6 +37,7 @@ pub(crate) async fn generation_task( ...@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
batch_size: Vec<u32>, batch_size: Vec<u32>,
sequence_length: u32, sequence_length: u32,
decode_length: u32, decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize, n_runs: usize,
warmups: usize, warmups: usize,
parameters: NextTokenChooserParameters, parameters: NextTokenChooserParameters,
...@@ -48,7 +49,7 @@ pub(crate) async fn generation_task( ...@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
// End task if a message is received on shutdown_receiver // End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished // _shutdown_guard_sender will be dropped once the task is finished
tokio::select! { tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => { res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
if let Err(err) = res { if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(()); run_sender.send(Err(err)).await.unwrap_or(());
} }
...@@ -64,6 +65,7 @@ async fn generate_runs( ...@@ -64,6 +65,7 @@ async fn generate_runs(
batch_size: Vec<u32>, batch_size: Vec<u32>,
sequence_length: u32, sequence_length: u32,
decode_length: u32, decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize, n_runs: usize,
warmups: usize, warmups: usize,
parameters: NextTokenChooserParameters, parameters: NextTokenChooserParameters,
...@@ -82,6 +84,7 @@ async fn generate_runs( ...@@ -82,6 +84,7 @@ async fn generate_runs(
b, b,
decode_length, decode_length,
parameters.clone(), parameters.clone(),
top_n_tokens,
&mut client, &mut client,
) )
.await?; .await?;
...@@ -97,6 +100,7 @@ async fn generate_runs( ...@@ -97,6 +100,7 @@ async fn generate_runs(
b, b,
decode_length, decode_length,
parameters.clone(), parameters.clone(),
top_n_tokens,
&mut client, &mut client,
) )
.await?; .await?;
...@@ -130,6 +134,7 @@ async fn prefill( ...@@ -130,6 +134,7 @@ async fn prefill(
batch_size: u32, batch_size: u32,
decode_length: u32, decode_length: u32,
parameters: NextTokenChooserParameters, parameters: NextTokenChooserParameters,
top_n_tokens: Option<u32>,
client: &mut ShardedClient, client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> { ) -> Result<(Prefill, CachedBatch), ClientError> {
// Create requests // Create requests
...@@ -145,6 +150,7 @@ async fn prefill( ...@@ -145,6 +150,7 @@ async fn prefill(
stop_sequences: vec![], stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if a eos token is generated ignore_eos_token: true, // Will not stop even if a eos token is generated
}), }),
top_n_tokens: top_n_tokens.unwrap_or(0),
}) })
.collect(); .collect();
......
...@@ -22,6 +22,7 @@ pub async fn run( ...@@ -22,6 +22,7 @@ pub async fn run(
batch_size: Vec<u32>, batch_size: Vec<u32>,
sequence_length: u32, sequence_length: u32,
decode_length: u32, decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize, n_runs: usize,
warmups: usize, warmups: usize,
temperature: Option<f32>, temperature: Option<f32>,
...@@ -70,6 +71,7 @@ pub async fn run( ...@@ -70,6 +71,7 @@ pub async fn run(
batch_size.clone(), batch_size.clone(),
sequence_length, sequence_length,
decode_length, decode_length,
top_n_tokens,
n_runs, n_runs,
warmups, warmups,
parameters, parameters,
...@@ -130,6 +132,7 @@ pub async fn run( ...@@ -130,6 +132,7 @@ pub async fn run(
tokenizer_name, tokenizer_name,
sequence_length, sequence_length,
decode_length, decode_length,
top_n_tokens,
n_runs, n_runs,
warmups, warmups,
temperature, temperature,
......
...@@ -93,6 +93,11 @@ struct Args { ...@@ -93,6 +93,11 @@ struct Args {
/// decoding strategies, for full doc refer to the `text-generation-server` /// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)] #[clap(long, env)]
do_sample: bool, do_sample: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_n_tokens: Option<u32>,
} }
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
...@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { ...@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
watermark, watermark,
do_sample, do_sample,
master_shard_uds_path, master_shard_uds_path,
top_n_tokens,
} = args; } = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
...@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { ...@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
batch_size, batch_size,
sequence_length, sequence_length,
decode_length, decode_length,
top_n_tokens,
runs, runs,
warmups, warmups,
temperature, temperature,
......
...@@ -7,6 +7,7 @@ pub(crate) fn parameters_table( ...@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
tokenizer_name: String, tokenizer_name: String,
sequence_length: u32, sequence_length: u32,
decode_length: u32, decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize, n_runs: usize,
warmups: usize, warmups: usize,
temperature: Option<f32>, temperature: Option<f32>,
...@@ -24,6 +25,7 @@ pub(crate) fn parameters_table( ...@@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
builder.push_record(["Model", &tokenizer_name]); builder.push_record(["Model", &tokenizer_name]);
builder.push_record(["Sequence Length", &sequence_length.to_string()]); builder.push_record(["Sequence Length", &sequence_length.to_string()]);
builder.push_record(["Decode Length", &decode_length.to_string()]); builder.push_record(["Decode Length", &decode_length.to_string()]);
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
builder.push_record(["N Runs", &n_runs.to_string()]); builder.push_record(["N Runs", &n_runs.to_string()]);
builder.push_record(["Warmups", &warmups.to_string()]); builder.push_record(["Warmups", &warmups.to_string()]);
builder.push_record(["Temperature", &format!("{temperature:?}")]); builder.push_record(["Temperature", &format!("{temperature:?}")]);
......
...@@ -75,6 +75,7 @@ class Client: ...@@ -75,6 +75,7 @@ class Client:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text Given a prompt, generate the following text
...@@ -113,6 +114,8 @@ class Client: ...@@ -113,6 +114,8 @@ class Client:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`): decoder_input_details (`bool`):
Return the decoder input token logprobs and ids Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns: Returns:
Response: generated response Response: generated response
...@@ -134,6 +137,7 @@ class Client: ...@@ -134,6 +137,7 @@ class Client:
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
decoder_input_details=decoder_input_details, decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
...@@ -164,6 +168,7 @@ class Client: ...@@ -164,6 +168,7 @@ class Client:
truncate: Optional[int] = None, truncate: Optional[int] = None,
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None,
) -> Iterator[StreamResponse]: ) -> Iterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens Given a prompt, generate the following stream of tokens
...@@ -198,6 +203,8 @@ class Client: ...@@ -198,6 +203,8 @@ class Client:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`): watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns: Returns:
Iterator[StreamResponse]: stream of generated tokens Iterator[StreamResponse]: stream of generated tokens
...@@ -219,6 +226,7 @@ class Client: ...@@ -219,6 +226,7 @@ class Client:
truncate=truncate, truncate=truncate,
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)
...@@ -317,6 +325,7 @@ class AsyncClient: ...@@ -317,6 +325,7 @@ class AsyncClient:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text asynchronously Given a prompt, generate the following text asynchronously
...@@ -355,6 +364,8 @@ class AsyncClient: ...@@ -355,6 +364,8 @@ class AsyncClient:
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
decoder_input_details (`bool`): decoder_input_details (`bool`):
Return the decoder input token logprobs and ids Return the decoder input token logprobs and ids
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns: Returns:
Response: generated response Response: generated response
...@@ -376,6 +387,7 @@ class AsyncClient: ...@@ -376,6 +387,7 @@ class AsyncClient:
truncate=truncate, truncate=truncate,
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens,
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
...@@ -404,6 +416,7 @@ class AsyncClient: ...@@ -404,6 +416,7 @@ class AsyncClient:
truncate: Optional[int] = None, truncate: Optional[int] = None,
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None,
) -> AsyncIterator[StreamResponse]: ) -> AsyncIterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens asynchronously Given a prompt, generate the following stream of tokens asynchronously
...@@ -438,6 +451,8 @@ class AsyncClient: ...@@ -438,6 +451,8 @@ class AsyncClient:
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`): watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
top_n_tokens (`int`):
Return the `n` most likely tokens at each step
Returns: Returns:
AsyncIterator[StreamResponse]: stream of generated tokens AsyncIterator[StreamResponse]: stream of generated tokens
...@@ -459,6 +474,7 @@ class AsyncClient: ...@@ -459,6 +474,7 @@ class AsyncClient:
truncate=truncate, truncate=truncate,
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)
......
...@@ -39,6 +39,8 @@ class Parameters(BaseModel): ...@@ -39,6 +39,8 @@ class Parameters(BaseModel):
details: bool = False details: bool = False
# Get decoder input token logprobs and ids # Get decoder input token logprobs and ids
decoder_input_details: bool = False decoder_input_details: bool = False
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
@validator("best_of") @validator("best_of")
def valid_best_of(cls, field_value, values): def valid_best_of(cls, field_value, values):
...@@ -101,6 +103,12 @@ class Parameters(BaseModel): ...@@ -101,6 +103,12 @@ class Parameters(BaseModel):
raise ValidationError("`typical_p` must be > 0.0 and < 1.0") raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
return v return v
@validator("top_n_tokens")
def valid_top_n_tokens(cls, v):
if v is not None and v <= 0:
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
class Request(BaseModel): class Request(BaseModel):
# Prompt # Prompt
...@@ -125,9 +133,7 @@ class Request(BaseModel): ...@@ -125,9 +133,7 @@ class Request(BaseModel):
and parameters.best_of > 1 and parameters.best_of > 1
and field_value and field_value
): ):
raise ValidationError( raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
"`best_of` != 1 is not supported when `stream` == True"
)
return field_value return field_value
...@@ -179,6 +185,8 @@ class BestOfSequence(BaseModel): ...@@ -179,6 +185,8 @@ class BestOfSequence(BaseModel):
prefill: List[InputToken] prefill: List[InputToken]
# Generated tokens # Generated tokens
tokens: List[Token] tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# `generate` details # `generate` details
...@@ -193,6 +201,8 @@ class Details(BaseModel): ...@@ -193,6 +201,8 @@ class Details(BaseModel):
prefill: List[InputToken] prefill: List[InputToken]
# Generated tokens # Generated tokens
tokens: List[Token] tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter # Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]] best_of_sequences: Optional[List[BestOfSequence]]
...@@ -219,6 +229,8 @@ class StreamDetails(BaseModel): ...@@ -219,6 +229,8 @@ class StreamDetails(BaseModel):
class StreamResponse(BaseModel): class StreamResponse(BaseModel):
# Generated token # Generated token
token: Token token: Token
# Most likely tokens
top_tokens: Optional[List[Token]]
# Complete generated text # Complete generated text
# Only available when the generation is finished # Only available when the generation is finished
generated_text: Optional[str] generated_text: Optional[str]
......
...@@ -159,6 +159,14 @@ struct Args { ...@@ -159,6 +159,14 @@ struct Args {
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_stop_sequences: usize, max_stop_sequences: usize,
/// This is the maximum allowed value for clients to set `top_n_tokens`.
/// `top_n_tokens` is used to return information about the `n` most likely
/// tokens at each generation step, instead of just the sampled token. This
/// information can be used for downstream tasks such as classification or
/// ranking.
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
/// This is the maximum allowed input length (expressed in number of tokens) /// This is the maximum allowed input length (expressed in number of tokens)
/// for users. The larger this value, the longer prompt users can send which /// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load. /// can impact the overall memory required to handle the load.
...@@ -929,6 +937,8 @@ fn spawn_webserver( ...@@ -929,6 +937,8 @@ fn spawn_webserver(
args.max_best_of.to_string(), args.max_best_of.to_string(),
"--max-stop-sequences".to_string(), "--max-stop-sequences".to_string(),
args.max_stop_sequences.to_string(), args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(),
"--max-input-length".to_string(), "--max-input-length".to_string(),
args.max_input_length.to_string(), args.max_input_length.to_string(),
"--max-total-tokens".to_string(), "--max-total-tokens".to_string(),
......
...@@ -91,6 +91,8 @@ message Request { ...@@ -91,6 +91,8 @@ message Request {
StoppingCriteriaParameters stopping_parameters = 5; StoppingCriteriaParameters stopping_parameters = 5;
/// Return prefill logprobs /// Return prefill logprobs
bool prefill_logprobs = 6; bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
} }
message Batch { message Batch {
...@@ -141,6 +143,17 @@ message PrefillTokens { ...@@ -141,6 +143,17 @@ message PrefillTokens {
repeated string texts = 3; repeated string texts = 3;
} }
/// Most likely tokens attached to a `Generation` (see `Generation.top_tokens`).
/// The four repeated fields are parallel lists: the router zips `ids`,
/// `logprobs`, `texts` and `is_special` together position by position to
/// build one token per index, so all four are expected to have equal length.
message TopTokens {
/// Top Token IDs
repeated uint32 ids = 1;
/// Top Logprobs
repeated float logprobs = 2;
/// Top Token Texts
repeated string texts = 3;
/// If the tokens are special
/// NOTE(review): the field number jumps from 3 to 6 (4 and 5 are unused in
/// this message) — confirm this is intentional before reusing those numbers.
repeated bool is_special = 6;
}
message Generation { message Generation {
/// Request ID /// Request ID
uint64 request_id = 1; uint64 request_id = 1;
...@@ -156,6 +169,8 @@ message Generation { ...@@ -156,6 +169,8 @@ message Generation {
bool token_is_special = 6; bool token_is_special = 6;
/// Complete generated text /// Complete generated text
optional GeneratedText generated_text = 7; optional GeneratedText generated_text = 7;
/// Top tokens
TopTokens top_tokens = 8;
} }
message FilterBatchRequest { message FilterBatchRequest {
......
...@@ -131,6 +131,7 @@ impl Client { ...@@ -131,6 +131,7 @@ impl Client {
ignore_eos_token: false, ignore_eos_token: false,
}), }),
prefill_logprobs: true, prefill_logprobs: true,
top_n_tokens: 20,
}); });
n_tokens += max_input_length; n_tokens += max_input_length;
} }
......
...@@ -50,6 +50,7 @@ impl Health { ...@@ -50,6 +50,7 @@ impl Health {
stop_sequences: vec![], stop_sequences: vec![],
ignore_eos_token: false, ignore_eos_token: false,
}), }),
top_n_tokens: 0,
}; };
let batch = Batch { let batch = Batch {
id: BATCH_ID, id: BATCH_ID,
......
...@@ -138,12 +138,15 @@ impl Infer { ...@@ -138,12 +138,15 @@ impl Infer {
&self, &self,
request: GenerateRequest, request: GenerateRequest,
) -> Result<InferResponse, InferError> { ) -> Result<InferResponse, InferError> {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives // Create stream and keep semaphore permit as long as generate lives
let (_permit, mut stream) = self.generate_stream(request).await?; let (_permit, mut stream) = self.generate_stream(request).await?;
// Return values // Return values
let mut result_prefill = Vec::new(); let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new(); let mut result_tokens = Vec::new();
let mut result_top_tokens = Vec::new();
let mut result_generated_text = None; let mut result_generated_text = None;
let mut result_start = None; let mut result_start = None;
let mut result_queued = None; let mut result_queued = None;
...@@ -164,7 +167,10 @@ impl Infer { ...@@ -164,7 +167,10 @@ impl Infer {
.collect(); .collect();
} }
// Push last token // Push last token
InferStreamResponse::Token(token) => result_tokens.push(token), InferStreamResponse::Intermediate { token, top_tokens } => {
result_tokens.push(token);
result_top_tokens.push(top_tokens);
}
// Final message // Final message
// Set return values // Set return values
InferStreamResponse::End { InferStreamResponse::End {
...@@ -172,8 +178,10 @@ impl Infer { ...@@ -172,8 +178,10 @@ impl Infer {
generated_text, generated_text,
start, start,
queued, queued,
top_tokens,
} => { } => {
result_tokens.push(token); result_tokens.push(token);
result_top_tokens.push(top_tokens);
result_generated_text = Some(generated_text); result_generated_text = Some(generated_text);
result_start = Some(start); result_start = Some(start);
result_queued = Some(queued) result_queued = Some(queued)
...@@ -191,6 +199,11 @@ impl Infer { ...@@ -191,6 +199,11 @@ impl Infer {
generated_text, generated_text,
queued, queued,
start, start,
top_tokens: if use_top_tokens {
result_top_tokens
} else {
Vec::new()
},
}) })
} else { } else {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
...@@ -520,6 +533,26 @@ fn send_responses( ...@@ -520,6 +533,26 @@ fn send_responses(
special: generation.token_is_special, special: generation.token_is_special,
}; };
// generation.top_tokens
let mut top_tokens = Vec::new();
if let Some(top_tokens_) = generation.top_tokens {
top_tokens.extend(
top_tokens_
.ids
.into_iter()
.zip(top_tokens_.logprobs.into_iter())
.zip(top_tokens_.texts.into_iter())
.zip(top_tokens_.is_special.into_iter())
.map(|(((id, logprob), text), special)| Token {
id,
text,
logprob,
special,
}),
)
}
if let Some(generated_text) = generation.generated_text { if let Some(generated_text) = generation.generated_text {
// Generation has ended // Generation has ended
stopped = true; stopped = true;
...@@ -527,6 +560,7 @@ fn send_responses( ...@@ -527,6 +560,7 @@ fn send_responses(
entry.response_tx.send_timeout( entry.response_tx.send_timeout(
Ok(InferStreamResponse::End { Ok(InferStreamResponse::End {
token, token,
top_tokens,
generated_text, generated_text,
queued: entry.queue_time, queued: entry.queue_time,
start: entry.batch_time.unwrap(), start: entry.batch_time.unwrap(),
...@@ -536,7 +570,7 @@ fn send_responses( ...@@ -536,7 +570,7 @@ fn send_responses(
} else { } else {
// Send message // Send message
entry.response_tx.send_timeout( entry.response_tx.send_timeout(
Ok(InferStreamResponse::Token(token)), Ok(InferStreamResponse::Intermediate { token, top_tokens }),
Duration::from_millis(10), Duration::from_millis(10),
)?; )?;
} }
...@@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse { ...@@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse {
// Optional first message // Optional first message
Prefill(PrefillTokens), Prefill(PrefillTokens),
// Intermediate messages // Intermediate messages
Token(Token), Intermediate {
token: Token,
top_tokens: Vec<Token>,
},
// Last message // Last message
End { End {
token: Token, token: Token,
top_tokens: Vec<Token>,
generated_text: GeneratedText, generated_text: GeneratedText,
start: Instant, start: Instant,
queued: Instant, queued: Instant,
...@@ -583,6 +621,7 @@ pub(crate) struct InferResponse { ...@@ -583,6 +621,7 @@ pub(crate) struct InferResponse {
pub(crate) generated_text: GeneratedText, pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant, pub(crate) queued: Instant,
pub(crate) start: Instant, pub(crate) start: Instant,
pub(crate) top_tokens: Vec<Vec<Token>>,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
......
...@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters { ...@@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters {
example = "null" example = "null"
)] )]
pub seed: Option<u64>, pub seed: Option<u64>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
} }
fn default_max_new_tokens() -> u32 { fn default_max_new_tokens() -> u32 {
...@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters { ...@@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters {
details: false, details: false,
decoder_input_details: false, decoder_input_details: false,
seed: None, seed: None,
top_n_tokens: None,
} }
} }
...@@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence { ...@@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence {
pub seed: Option<u64>, pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>, pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>, pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
...@@ -249,6 +255,8 @@ pub(crate) struct Details { ...@@ -249,6 +255,8 @@ pub(crate) struct Details {
pub tokens: Vec<Token>, pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>, pub best_of_sequences: Option<Vec<BestOfSequence>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Vec<Token>>,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
...@@ -272,6 +280,8 @@ pub(crate) struct StreamDetails { ...@@ -272,6 +280,8 @@ pub(crate) struct StreamDetails {
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse { pub(crate) struct StreamResponse {
pub token: Token, pub token: Token,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub top_tokens: Vec<Token>,
#[schema(nullable = true, default = "null", example = "test")] #[schema(nullable = true, default = "null", example = "test")]
pub generated_text: Option<String>, pub generated_text: Option<String>,
#[schema(nullable = true, default = "null")] #[schema(nullable = true, default = "null")]
......
...@@ -29,6 +29,8 @@ struct Args { ...@@ -29,6 +29,8 @@ struct Args {
max_best_of: usize, max_best_of: usize,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_stop_sequences: usize, max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)] #[clap(default_value = "1024", long, env)]
max_input_length: usize, max_input_length: usize,
#[clap(default_value = "2048", long, env)] #[clap(default_value = "2048", long, env)]
...@@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> { ...@@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
...@@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> { ...@@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> {
max_concurrent_requests, max_concurrent_requests,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
......
...@@ -235,6 +235,7 @@ impl State { ...@@ -235,6 +235,7 @@ impl State {
truncate: entry.request.truncate, truncate: entry.request.truncate,
parameters: Some(entry.request.parameters.clone()), parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()),
top_n_tokens: entry.request.top_n_tokens,
}); });
// Set batch_time // Set batch_time
entry.batch_time = Some(Instant::now()); entry.batch_time = Some(Instant::now());
...@@ -328,6 +329,7 @@ mod tests { ...@@ -328,6 +329,7 @@ mod tests {
max_new_tokens: 1, max_new_tokens: 1,
stop_sequences: vec![], stop_sequences: vec![],
}, },
top_n_tokens: 0,
}, },
response_tx, response_tx,
span: info_span!("entry"), span: info_span!("entry"),
......
...@@ -158,7 +158,7 @@ async fn generate( ...@@ -158,7 +158,7 @@ async fn generate(
add_prompt = Some(req.inputs.clone()); add_prompt = Some(req.inputs.clone());
} }
let details = req.parameters.details || req.parameters.decoder_input_details; let details: bool = req.parameters.details || req.parameters.decoder_input_details;
// Inference // Inference
let (response, best_of_responses) = match req.parameters.best_of { let (response, best_of_responses) = match req.parameters.best_of {
...@@ -191,6 +191,7 @@ async fn generate( ...@@ -191,6 +191,7 @@ async fn generate(
generated_tokens: response.generated_text.generated_tokens, generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill, prefill: response.prefill,
tokens: response.tokens, tokens: response.tokens,
top_tokens: response.top_tokens,
seed: response.generated_text.seed, seed: response.generated_text.seed,
} }
}) })
...@@ -204,6 +205,7 @@ async fn generate( ...@@ -204,6 +205,7 @@ async fn generate(
tokens: response.tokens, tokens: response.tokens,
seed: response.generated_text.seed, seed: response.generated_text.seed,
best_of_sequences, best_of_sequences,
top_tokens: response.top_tokens,
}) })
} }
false => None, false => None,
...@@ -385,12 +387,16 @@ async fn generate_stream( ...@@ -385,12 +387,16 @@ async fn generate_stream(
// Prefill is ignored // Prefill is ignored
InferStreamResponse::Prefill(_) => {} InferStreamResponse::Prefill(_) => {}
// Yield event for every new token // Yield event for every new token
InferStreamResponse::Token(token) => { InferStreamResponse::Intermediate{
token,
top_tokens,
} => {
tracing::debug!(parent: &span, "Token: {:?}", token); tracing::debug!(parent: &span, "Token: {:?}", token);
// StreamResponse // StreamResponse
let stream_token = StreamResponse { let stream_token = StreamResponse {
token, token,
top_tokens: top_tokens,
generated_text: None, generated_text: None,
details: None, details: None,
}; };
...@@ -403,6 +409,7 @@ async fn generate_stream( ...@@ -403,6 +409,7 @@ async fn generate_stream(
generated_text, generated_text,
start, start,
queued, queued,
top_tokens,
} => { } => {
// Token details // Token details
let details = match details { let details = match details {
...@@ -451,6 +458,7 @@ async fn generate_stream( ...@@ -451,6 +458,7 @@ async fn generate_stream(
let stream_token = StreamResponse { let stream_token = StreamResponse {
token, token,
top_tokens: top_tokens,
generated_text: Some(output_text), generated_text: Some(output_text),
details details
}; };
...@@ -509,6 +517,7 @@ pub async fn run( ...@@ -509,6 +517,7 @@ pub async fn run(
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
waiting_served_ratio: f32, waiting_served_ratio: f32,
...@@ -571,6 +580,7 @@ pub async fn run( ...@@ -571,6 +580,7 @@ pub async fn run(
tokenizer, tokenizer,
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
); );
......
...@@ -15,6 +15,7 @@ pub struct Validation { ...@@ -15,6 +15,7 @@ pub struct Validation {
/// Validation parameters /// Validation parameters
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
/// Channel to communicate with the background tokenization task /// Channel to communicate with the background tokenization task
...@@ -27,6 +28,7 @@ impl Validation { ...@@ -27,6 +28,7 @@ impl Validation {
tokenizer: Option<Tokenizer>, tokenizer: Option<Tokenizer>,
max_best_of: usize, max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
) -> Self { ) -> Self {
...@@ -54,6 +56,7 @@ impl Validation { ...@@ -54,6 +56,7 @@ impl Validation {
max_best_of, max_best_of,
sender, sender,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
} }
...@@ -142,6 +145,7 @@ impl Validation { ...@@ -142,6 +145,7 @@ impl Validation {
seed, seed,
watermark, watermark,
decoder_input_details, decoder_input_details,
top_n_tokens,
.. ..
} = request.parameters; } = request.parameters;
...@@ -218,6 +222,15 @@ impl Validation { ...@@ -218,6 +222,15 @@ impl Validation {
} }
}; };
let top_n_tokens = top_n_tokens
.map(|value| {
if value > self.max_top_n_tokens {
return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));
}
Ok(value)
})
.unwrap_or(Ok(0))?;
// Check if inputs is empty // Check if inputs is empty
if request.inputs.is_empty() { if request.inputs.is_empty() {
return Err(EmptyInput); return Err(EmptyInput);
...@@ -263,6 +276,7 @@ impl Validation { ...@@ -263,6 +276,7 @@ impl Validation {
truncate: truncate.unwrap_or(self.max_input_length) as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32,
parameters, parameters,
stopping_parameters, stopping_parameters,
top_n_tokens: top_n_tokens,
}) })
} }
...@@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest { ...@@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest {
pub decoder_input_details: bool, pub decoder_input_details: bool,
pub parameters: NextTokenChooserParameters, pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters, pub stopping_parameters: StoppingCriteriaParameters,
pub top_n_tokens: u32,
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
...@@ -350,6 +365,10 @@ pub enum ValidationError { ...@@ -350,6 +365,10 @@ pub enum ValidationError {
BestOfSeed, BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")] #[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream, BestOfStream,
#[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")]
TopNTokens(u32, u32),
#[error("`top_n_tokens` != 0 is not allowed for this endpoint")]
TopNTokensDisabled,
#[error("`decoder_input_details` == true is not supported when streaming tokens")] #[error("`decoder_input_details` == true is not supported when streaming tokens")]
PrefillDetailsStream, PrefillDetailsStream,
#[error("`temperature` must be strictly positive")] #[error("`temperature` must be strictly positive")]
...@@ -391,14 +410,16 @@ mod tests { ...@@ -391,14 +410,16 @@ mod tests {
let tokenizer = None; let tokenizer = None;
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_top_n_tokens = 4;
let max_total_tokens = 5; let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1; let workers = 1;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
); );
...@@ -418,14 +439,16 @@ mod tests { ...@@ -418,14 +439,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await); let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_top_n_tokens = 4;
let max_total_tokens = 5; let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1; let workers = 1;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
); );
...@@ -435,7 +458,7 @@ mod tests { ...@@ -435,7 +458,7 @@ mod tests {
.validate_input("Hello".to_string(), None, max_new_tokens) .validate_input("Hello".to_string(), None, max_new_tokens)
.await .await
{ {
Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (), Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
_ => panic!("Unexpected not max new tokens"), _ => panic!("Unexpected not max new tokens"),
} }
} }
...@@ -445,14 +468,16 @@ mod tests { ...@@ -445,14 +468,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await); let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_top_n_tokens = 4;
let max_total_tokens = 5; let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1; let workers = 1;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
); );
...@@ -477,14 +502,16 @@ mod tests { ...@@ -477,14 +502,16 @@ mod tests {
let tokenizer = Some(get_tokenizer().await); let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2; let max_best_of = 2;
let max_stop_sequence = 3; let max_stop_sequence = 3;
let max_input_length = 4; let max_top_n_tokens = 4;
let max_total_tokens = 5; let max_input_length = 5;
let max_total_tokens = 6;
let workers = 1; let workers = 1;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
max_best_of, max_best_of,
max_stop_sequence, max_stop_sequence,
max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
); );
...@@ -531,4 +558,75 @@ mod tests { ...@@ -531,4 +558,75 @@ mod tests {
// top_p == 1.0 is invalid for users to ask for but it's the default resolved value. // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
assert_eq!(valid_request.parameters.top_p, 1.0); assert_eq!(valid_request.parameters.top_p, 1.0);
} }
/// Exercises the `top_n_tokens` validation against a limit of 4:
/// a request above the limit is rejected with `ValidationError::TopNTokens`,
/// requests at the limit and at zero are accepted, and an omitted value
/// resolves to 0 on the validated request.
#[tokio::test]
async fn test_validation_top_n_tokens() {
    let tokenizer = Some(get_tokenizer().await);
    let validation = Validation::new(
        1, // workers
        tokenizer,
        2, // max_best_of
        3, // max_stop_sequences
        4, // max_top_n_tokens
        5, // max_input_length
        6, // max_total_tokens
    );

    // One above the configured maximum: must be refused, reporting (limit, given).
    let rejected = validation
        .validate(GenerateRequest {
            inputs: "Hello".to_string(),
            parameters: GenerateParameters {
                top_n_tokens: Some(5),
                ..default_parameters()
            },
        })
        .await;
    assert!(
        matches!(rejected, Err(ValidationError::TopNTokens(4, 5))),
        "Unexpected top_n_tokens"
    );

    // Builds the single-new-token request shared by every accepted case below.
    let request = |top_n_tokens: Option<u32>| GenerateRequest {
        inputs: "Hello".to_string(),
        parameters: GenerateParameters {
            top_n_tokens,
            max_new_tokens: 1,
            ..default_parameters()
        },
    };

    // Exactly at the limit, then explicitly zero: both must validate.
    for accepted in [Some(4), Some(0)] {
        validation.validate(request(accepted)).await.unwrap();
    }

    // Leaving the field out falls back to 0.
    let valid_request = validation.validate(request(None)).await.unwrap();
    assert_eq!(valid_request.top_n_tokens, 0);
}
} }
import torch
from text_generation_server.utils.tokens import ( from text_generation_server.utils.tokens import (
StopSequenceCriteria, StopSequenceCriteria,
StoppingCriteria, StoppingCriteria,
FinishReason, FinishReason,
batch_top_tokens,
) )
...@@ -42,3 +44,22 @@ def test_stopping_criteria_max(): ...@@ -42,3 +44,22 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
    """Check `batch_top_tokens` on five requests that share one logprob row.

    Each request asks for a different number of top tokens (0, 2, 3, 4, 5).
    The expectations below show that a request for 3 tokens yields four
    entries: indices 1 and 4 both have logprob -3, so both are expected.
    """
    requested = [0, 2, 3, 4, 5]
    requested_tensor = torch.tensor(requested)
    shared_row = [-1.0, -3.0, -4.0, -2.0, -3.0]
    inp_logprobs = torch.tensor([shared_row] * 5)

    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
        requested, requested_tensor, inp_logprobs
    )

    # Expected results, one entry per request, in request order.
    expected_ids = [
        [],
        [0, 3],
        [0, 3, 1, 4],
        [0, 3, 1, 4],
        [0, 3, 1, 4, 2],
    ]
    expected_logprobs = [
        [],
        [-1, -2],
        [-1, -2, -3, -3],
        [-1, -2, -3, -3],
        [-1, -2, -3, -3, -4],
    ]

    # Ids are checked for every request first, then logprobs.
    for position, want in enumerate(expected_ids):
        assert topn_tok_ids[position] == want

    for position, want in enumerate(expected_logprobs):
        assert topn_tok_logprobs[position] == want
from text_generation_server.utils.tokens import batch_top_tokens
import torch import torch
import inspect import inspect
...@@ -12,6 +13,7 @@ from text_generation_server.models.types import ( ...@@ -12,6 +13,7 @@ from text_generation_server.models.types import (
PrefillTokens, PrefillTokens,
Generation, Generation,
GeneratedText, GeneratedText,
TopTokens,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
...@@ -42,6 +44,8 @@ class CausalLMBatch(Batch): ...@@ -42,6 +44,8 @@ class CausalLMBatch(Batch):
# Generation helpers # Generation helpers
next_token_choosers: List[NextTokenChooser] next_token_choosers: List[NextTokenChooser]
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Metadata used for padding # Metadata used for padding
max_input_length: int max_input_length: int
...@@ -72,6 +76,7 @@ class CausalLMBatch(Batch): ...@@ -72,6 +76,7 @@ class CausalLMBatch(Batch):
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
prefix_offsets = [] prefix_offsets = []
read_offsets = [] read_offsets = []
requests_idx_mapping = {} requests_idx_mapping = {}
...@@ -88,6 +93,7 @@ class CausalLMBatch(Batch): ...@@ -88,6 +93,7 @@ class CausalLMBatch(Batch):
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
max_truncation = max(max_truncation, r.truncate) max_truncation = max(max_truncation, r.truncate)
max_decode_tokens += stopping_criteria.max_new_tokens max_decode_tokens += stopping_criteria.max_new_tokens
padding_right_offset = max( padding_right_offset = max(
...@@ -121,6 +127,9 @@ class CausalLMBatch(Batch): ...@@ -121,6 +127,9 @@ class CausalLMBatch(Batch):
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
max_tokens = len(inputs) * (max_input_length + max_decode_tokens) max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
...@@ -138,6 +147,8 @@ class CausalLMBatch(Batch): ...@@ -138,6 +147,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length.item(), max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
max_tokens=max_tokens, max_tokens=max_tokens,
...@@ -163,6 +174,7 @@ class CausalLMBatch(Batch): ...@@ -163,6 +174,7 @@ class CausalLMBatch(Batch):
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
total_remaining_decode_tokens = 0 total_remaining_decode_tokens = 0
new_padding_right_offset = 0 new_padding_right_offset = 0
...@@ -184,6 +196,7 @@ class CausalLMBatch(Batch): ...@@ -184,6 +196,7 @@ class CausalLMBatch(Batch):
next_token_choosers.append(self.next_token_choosers[idx]) next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx] stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_decode_tokens = ( remaining_decode_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
) )
...@@ -223,6 +236,7 @@ class CausalLMBatch(Batch): ...@@ -223,6 +236,7 @@ class CausalLMBatch(Batch):
layer[1] = past_values[keep_indices, :, -past_kv_length:, :] layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
del past_values del past_values
top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
self.requests = requests self.requests = requests
...@@ -235,6 +249,8 @@ class CausalLMBatch(Batch): ...@@ -235,6 +249,8 @@ class CausalLMBatch(Batch):
self.read_offsets = read_offsets self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias self.stopping_criterias = stopping_criterias
self.top_n_tokens = top_n_tokens
self.top_n_tokens_tensor = top_n_tokens_tensor
self.max_input_length = max_input_length self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens self.max_tokens = max_tokens
...@@ -262,6 +278,7 @@ class CausalLMBatch(Batch): ...@@ -262,6 +278,7 @@ class CausalLMBatch(Batch):
all_input_ids = [] all_input_ids = []
next_token_choosers = [] next_token_choosers = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
max_tokens = 0 max_tokens = 0
# Batch tensors # Batch tensors
...@@ -269,6 +286,7 @@ class CausalLMBatch(Batch): ...@@ -269,6 +286,7 @@ class CausalLMBatch(Batch):
attention_mask = None attention_mask = None
position_ids = None position_ids = None
past_key_values = [] past_key_values = []
top_n_tokens_tensor = None
# Used for slicing correctly inside the tensors # Used for slicing correctly inside the tensors
# Equivalent to a cumsum on batch sizes # Equivalent to a cumsum on batch sizes
...@@ -281,6 +299,7 @@ class CausalLMBatch(Batch): ...@@ -281,6 +299,7 @@ class CausalLMBatch(Batch):
all_input_ids.extend(batch.all_input_ids) all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers) next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
if i == 0: if i == 0:
requests_idx_mapping = batch.requests_idx_mapping requests_idx_mapping = batch.requests_idx_mapping
...@@ -310,6 +329,12 @@ class CausalLMBatch(Batch): ...@@ -310,6 +329,12 @@ class CausalLMBatch(Batch):
(total_batch_size, max_input_length + padding_right_offset), (total_batch_size, max_input_length + padding_right_offset),
) )
if top_n_tokens_tensor is None:
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
# We need to slice the attention mask to remove padding from previous steps # We need to slice the attention mask to remove padding from previous steps
# and to remove unused allocated space # and to remove unused allocated space
left_offset = max_input_length - batch.max_input_length left_offset = max_input_length - batch.max_input_length
...@@ -438,6 +463,8 @@ class CausalLMBatch(Batch): ...@@ -438,6 +463,8 @@ class CausalLMBatch(Batch):
read_offsets=read_offsets, read_offsets=read_offsets,
next_token_choosers=next_token_choosers, next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
max_input_length=max_input_length, max_input_length=max_input_length,
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last, keys_head_dim_last=batches[0].keys_head_dim_last,
...@@ -549,6 +576,12 @@ class CausalLM(Model): ...@@ -549,6 +576,12 @@ class CausalLM(Model):
generations: List[Generation] = [] generations: List[Generation] = []
stopped = True stopped = True
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens,
batch.top_n_tokens_tensor,
torch.softmax(logits[:, -1], -1),
)
# Zipped iterator # Zipped iterator
iterator = zip( iterator = zip(
batch.requests, batch.requests,
...@@ -559,6 +592,9 @@ class CausalLM(Model): ...@@ -559,6 +592,9 @@ class CausalLM(Model):
batch.next_token_choosers, batch.next_token_choosers,
batch.stopping_criterias, batch.stopping_criterias,
batch.all_input_ids, batch.all_input_ids,
batch.top_n_tokens,
batch_top_token_ids,
batch_top_token_logprobs,
) )
# For each member of the batch # For each member of the batch
...@@ -571,6 +607,9 @@ class CausalLM(Model): ...@@ -571,6 +607,9 @@ class CausalLM(Model):
next_token_chooser, next_token_chooser,
stopping_criteria, stopping_criteria,
all_input_ids, all_input_ids,
top_n_tokens,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
next_token_id, logprobs = next_token_chooser( next_token_id, logprobs = next_token_chooser(
...@@ -637,6 +676,24 @@ class CausalLM(Model): ...@@ -637,6 +676,24 @@ class CausalLM(Model):
else: else:
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation( generation = Generation(
request.id, request.id,
prefill_tokens, prefill_tokens,
...@@ -645,6 +702,7 @@ class CausalLM(Model): ...@@ -645,6 +702,7 @@ class CausalLM(Model):
next_token_text, next_token_text,
next_token_id_squeezed.item() in self.all_special_ids, next_token_id_squeezed.item() in self.all_special_ids,
generated_text, generated_text,
top_tokens,
) )
generations.append(generation) generations.append(generation)
......
import math import math
import itertools import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch import torch
import torch.distributed import torch.distributed
...@@ -16,6 +17,7 @@ from text_generation_server.models.types import ( ...@@ -16,6 +17,7 @@ from text_generation_server.models.types import (
PrefillTokens, PrefillTokens,
Generation, Generation,
GeneratedText, GeneratedText,
TopTokens,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
...@@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch): ...@@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch):
# Generation helpers # Generation helpers
next_token_chooser: HeterogeneousNextTokenChooser next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria] stopping_criterias: List[StoppingCriteria]
top_n_tokens: List[int]
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch # Number of blocks in this batch
blocks: int blocks: int
...@@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch): ...@@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = [] next_token_chooser_parameters = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
# Cumulative length # Cumulative length
cumulative_length = 0 cumulative_length = 0
...@@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch): ...@@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch):
) )
max_new_tokens = stopping_criteria.max_new_tokens max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention # Paged attention
# Remove one as the first token des not have a past # Remove one as the first token des not have a past
...@@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch): ...@@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch):
prefill_next_token_indices = torch.tensor( prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device prefill_next_token_indices, dtype=torch.int64, device=device
) )
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls( return cls(
batch_id=pb.id, batch_id=pb.id,
...@@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch): ...@@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
...@@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch): ...@@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = [] read_offsets = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
blocks = 0 blocks = 0
max_blocks = 0 max_blocks = 0
...@@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch): ...@@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx] stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria) stopping_criterias.append(stopping_criteria)
top_n_tokens.append(self.top_n_tokens[idx])
remaining_tokens = ( remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
) )
...@@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch): ...@@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch):
input_lengths_tensor = self.input_lengths_tensor[indices] input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices] slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices) next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64) start_slots = torch.tensor(start_slots, dtype=torch.int64)
...@@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch): ...@@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
...@@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch): ...@@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length) (total_batch_size, max_length)
) )
top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
total_batch_size,
)
start_slots = [] start_slots = []
block_tables = [] block_tables = []
...@@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch): ...@@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters = [] next_token_chooser_parameters = []
stopping_criterias = [] stopping_criterias = []
top_n_tokens = []
# Cumulative length # Cumulative length
cumulative_batch_size = 0 cumulative_batch_size = 0
...@@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch): ...@@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch):
position_ids[start_index:end_index] = batch.position_ids position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
slots[slots_start_index:slots_end_index] = batch.slots slots[slots_start_index:slots_end_index] = batch.slots
all_input_ids_tensor[ all_input_ids_tensor[
...@@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch): ...@@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias) stopping_criterias.extend(batch.stopping_criterias)
top_n_tokens.extend(batch.top_n_tokens)
# Update # Update
cumulative_batch_size += len(batch) cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots) cumulative_slots += len(batch.slots)
...@@ -666,6 +690,8 @@ class FlashCausalLMBatch(Batch): ...@@ -666,6 +690,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor, all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser, next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias, stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks, blocks=blocks,
max_blocks=max_blocks, max_blocks=max_blocks,
) )
...@@ -831,10 +857,14 @@ class FlashCausalLM(Model): ...@@ -831,10 +857,14 @@ class FlashCausalLM(Model):
else: else:
next_token_logits = out next_token_logits = out
next_input_ids, next_token_logprobs = batch.next_token_chooser( next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
) )
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
if prefill: if prefill:
if len(batch) > 1 and prefill_logprobs: if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
...@@ -931,8 +961,11 @@ class FlashCausalLM(Model): ...@@ -931,8 +961,11 @@ class FlashCausalLM(Model):
batch.all_input_ids, batch.all_input_ids,
batch.next_token_chooser.do_sample, batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds, batch.next_token_chooser.seeds,
batch.top_n_tokens,
next_token_ids, next_token_ids,
next_token_logprobs, next_token_logprobs,
batch_top_token_ids,
batch_top_token_logprobs,
) )
# For each member of the batch # For each member of the batch
...@@ -945,8 +978,11 @@ class FlashCausalLM(Model): ...@@ -945,8 +978,11 @@ class FlashCausalLM(Model):
all_input_ids, all_input_ids,
do_sample, do_sample,
seed, seed,
top_n_tokens,
next_token_id, next_token_id,
next_token_logprob, next_token_logprob,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator): ) in enumerate(iterator):
# Append next token to all tokens # Append next token to all tokens
all_input_ids.append(next_token_id) all_input_ids.append(next_token_id)
...@@ -1005,6 +1041,24 @@ class FlashCausalLM(Model): ...@@ -1005,6 +1041,24 @@ class FlashCausalLM(Model):
else: else:
prefill_tokens = None prefill_tokens = None
if top_n_tokens > 0:
toptoken_texts = self.tokenizer.batch_decode(
top_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
special_toptokens = [
token_id in self.all_special_ids for token_id in top_token_ids
]
top_tokens = TopTokens(
top_token_ids,
top_token_logprobs,
toptoken_texts,
special_toptokens,
)
else:
top_tokens = None
generation = Generation( generation = Generation(
request.id, request.id,
prefill_tokens, prefill_tokens,
...@@ -1013,6 +1067,7 @@ class FlashCausalLM(Model): ...@@ -1013,6 +1067,7 @@ class FlashCausalLM(Model):
next_token_text, next_token_text,
next_token_id in self.all_special_ids, next_token_id in self.all_special_ids,
generated_text, generated_text,
top_tokens,
) )
generations.append(generation) generations.append(generation)
......
...@@ -763,6 +763,8 @@ class IdeficsCausalLM(Model): ...@@ -763,6 +763,8 @@ class IdeficsCausalLM(Model):
else: else:
prefill_tokens = None prefill_tokens = None
top_tokens=None
generation = Generation( generation = Generation(
request.id, request.id,
prefill_tokens, prefill_tokens,
...@@ -771,6 +773,7 @@ class IdeficsCausalLM(Model): ...@@ -771,6 +773,7 @@ class IdeficsCausalLM(Model):
next_token_text, next_token_text,
next_token_id_squeezed.item() in self.all_special_ids, next_token_id_squeezed.item() in self.all_special_ids,
generated_text, generated_text,
top_tokens
) )
generations.append(generation) generations.append(generation)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment