Unverified Commit cef0553d authored by drbh, committed by GitHub

Outlines guided generation (#1539)

This WIP PR starts to add grammar support via outlines. It currently supports only very simple regex grammars and does not yet optimize for precompiling or caching grammar FSMs.

todo:
- [X] add simple outlines guidance to `NextTokenChooser`
- [X] update protos for grammar
- [X] update generation params API
- [X] constrain simple grammar
- [ ] support parsing more complex grammars into an FSM
- [ ] support all grammar types that outlines supports
- [ ] explore optimizations to avoid recompiling grammars

guided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data-raw '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6,
        "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+"
    }
}' | jq
```
response
```json
{
  "generated_text": "david@example.com"
}
```

unguided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6
    }
}' | jq
```
response
```json
{
  "generated_text": "    email = 'david"
}
```
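
How the constraint works, in short: the regex is compiled into an outlines `RegexFSM` over the tokenizer vocabulary; at every decoding step, logits for tokens the FSM does not allow from its current state are masked to `-inf`, and the FSM state is advanced with the chosen token. A rough standalone sketch of that loop (not part of this PR; the tokenizer name is a placeholder and the tokenizer shim mirrors the `_cached_adapt_tokenizer` helper added below):

```python
# Minimal sketch of FSM-guided decoding with outlines (pinned here at 0.0.27),
# mirroring the GrammarLogitProcessor added in this PR. "gpt2" is only a
# placeholder; batching, sampling and error handling are omitted.
import math
import torch
from transformers import AutoTokenizer
from outlines.fsm.fsm import RegexFSM

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Adapt the transformers tokenizer to the API outlines expects
# (same idea as GrammarLogitProcessor._cached_adapt_tokenizer).
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
tokenizer.convert_token_to_string = lambda t: tokenizer.convert_tokens_to_string([t])

# Compile the regex grammar into an FSM over the token vocabulary
fsm = RegexFSM(r"[\w-]+@([\w-]+\.)+[\w-]+", tokenizer)

fsm_state = 0  # start state; corresponds to fsm_grammar_state in the server code
logits = torch.randn(len(tokenizer))  # stand-in for the model's next-token logits

# Bias the logits so only grammar-legal tokens can be chosen
mask = torch.full((logits.shape[-1],), -math.inf)
mask[fsm.allowed_token_ids(fsm_state)] = 0
next_id = int(torch.argmax(logits + mask))

# Advance the FSM with the chosen token before the next decoding step
fsm_state = fsm.next_state(fsm_state, next_id)
```

In the server, this logic lives in `GrammarLogitProcessor` / `HeterogeneousGrammarLogitProcessor`, and the FSM state is carried on the `NextTokenChooser` and advanced via `advance_grammar`.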
parent 4c2848b2
```diff
 /// Payload validation logic
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
-use crate::{GenerateParameters, GenerateRequest};
+use crate::{GenerateParameters, GenerateRequest, GrammarType};
 use rand::{thread_rng, Rng};
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{
+    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
@@ -19,6 +21,7 @@ pub struct Validation {
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
+    disable_grammar_support: bool,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }
@@ -32,6 +35,7 @@ impl Validation {
         max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
+        disable_grammar_support: bool,
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
@@ -66,6 +70,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
+            disable_grammar_support,
         }
     }
@@ -182,6 +187,7 @@ impl Validation {
            watermark,
            decoder_input_details,
            top_n_tokens,
+           grammar,
            ..
        } = request.parameters;
@@ -292,6 +298,28 @@ impl Validation {
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

+        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
+        // NOTE: this is currently difficult because we need the tokenizer in Python to build
+        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
+        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
+        // compiler and use that to build the FSM here.
+
+        // Validate grammar and unpack the grammar and type for the proto message
+        let (grammar, grammar_type) = match grammar {
+            Some(grammar) => {
+                // Ensure that grammar is not set if it's not supported
+                if self.disable_grammar_support {
+                    return Err(ValidationError::Grammar);
+                }
+                match grammar {
+                    // currently both are handled the same way since compilation is done in Python
+                    GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
+                }
+            }
+            None => (String::new(), ProtoGrammarType::None.into()),
+        };
+
        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
@@ -302,6 +330,8 @@ impl Validation {
            do_sample,
            seed,
            watermark,
+           grammar,
+           grammar_type,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
@@ -453,6 +483,8 @@ pub enum ValidationError {
    StopSequence(usize, usize),
    #[error("tokenizer error {0}")]
    Tokenizer(String),
+   #[error("grammar is not supported")]
+   Grammar,
 }

 #[cfg(test)]
@@ -470,6 +502,7 @@ mod tests {
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
+       let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
@@ -478,6 +511,7 @@ mod tests {
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
+           disable_grammar_support,
        );
        let max_new_tokens = 10;
@@ -498,6 +532,7 @@ mod tests {
        let max_top_n_tokens = 4;
        let max_input_length = 5;
        let max_total_tokens = 6;
+       let disable_grammar_support = true;
        let workers = 1;
        let validation = Validation::new(
            workers,
@@ -507,6 +542,7 @@ mod tests {
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
+           disable_grammar_support,
        );
        let max_new_tokens = 10;
@@ -528,6 +564,7 @@ mod tests {
        let max_input_length = 5;
        let max_total_tokens = 6;
        let workers = 1;
+       let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
@@ -536,6 +573,7 @@ mod tests {
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
+           disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
@@ -562,6 +600,7 @@ mod tests {
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
+       let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
@@ -570,6 +609,7 @@ mod tests {
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
+           disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
@@ -625,6 +665,7 @@ mod tests {
        let max_input_length = 5;
        let max_total_tokens = 106;
        let workers = 1;
+       let disable_grammar_support = true;
        let validation = Validation::new(
            workers,
            tokenizer,
@@ -633,6 +674,7 @@ mod tests {
            max_top_n_tokens,
            max_input_length,
            max_total_tokens,
+           disable_grammar_support,
        );
        match validation
            .validate(GenerateRequest {
```
```diff
@@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true }
 torch = { version = "^2.1.1", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
+outlines="^0.0.27"

 [tool.poetry.extras]
 torch = ["torch"]
```
```diff
@@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(
+                NextTokenChooser.from_pb(r.parameters, device, tokenizer)
+            )
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
                 # We slice the keys to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
-                    padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
-                    ] = past_keys[:, :, -past_seq_len:, :]
+                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
+                        past_keys[:, :, -past_seq_len:, :]
+                    )
                 else:
                     # BLOOM case
-                    padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
-                    ] = past_keys[:, :, :, -past_seq_len:]
+                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
+                        past_keys[:, :, :, -past_seq_len:]
+                    )
                 del past_keys

                 start_index = end_index
@@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
                 end_index = start_index + len(batch)
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
-                padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
-                ] = past_values[:, :, -past_seq_len:, :]
+                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
+                    past_values[:, :, -past_seq_len:, :]
+                )
                 del past_values

                 # Update values
@@ -504,9 +506,11 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
@@ -696,7 +700,7 @@ class CausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
@@ -735,6 +739,9 @@ class CausalLM(Model):
             generations.append(generation)

             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
```
```diff
@@ -237,7 +237,7 @@ class FlashCausalLMBatch(Batch):
         )

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -593,6 +593,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
             device=batches[0].next_token_chooser.device,
+            tokenizer=batches[0].next_token_chooser.tokenizer,
         )

         speculative_ids = (
@@ -869,7 +870,11 @@ class FlashCausalLM(Model):
         # Try to find an associated cuda graph
         cuda_graph = self.cuda_graphs.get(padded_bs, None)

-        if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
+        if (
+            cu_seqlen_prefill is not None
+            or cuda_graph is None
+            or batch.speculative_ids is not None
+        ):
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1013,9 +1018,9 @@ class FlashCausalLM(Model):
                 # Copy batch.input_ids to prefill_token_indices
                 if prefill_logprobs:
                     if len(batch) > 1:
-                        prefill_tokens_indices[
-                            out_start_index : out_end_index - 1
-                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
+                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
+                            batch.input_ids[start_index + 1 : start_index + out_length]
+                        )
                     else:
                         # Set prefill_tokens_indices to the correct slice
                         prefill_tokens_indices = batch.input_ids[
@@ -1028,6 +1033,7 @@ class FlashCausalLM(Model):

             cumulative_length += input_length

+        # Update values
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
@@ -1166,7 +1172,7 @@ class FlashCausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(
+                for top_token_ids, top_token_logprobs in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
@@ -1204,6 +1210,12 @@ class FlashCausalLM(Model):

             generations.append(generation)

+            # accept each new token for this specific request since we may
+            # have more than one new token per request with speculative decoding
+            for next_token_id in _next_token_ids:
+                batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+
             # Update values
             batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
```
```diff
@@ -192,7 +192,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         )

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device
+            next_token_chooser_parameters, dtype, device, tokenizer
         )

         start_slots = torch.tensor(start_slots, dtype=torch.int64)
```
```diff
@@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
```
```diff
@@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -815,6 +815,9 @@ class IdeficsCausalLM(Model):
             generations.append(generation)

             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
```
```diff
@@ -124,7 +124,7 @@ class MambaBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -694,6 +694,9 @@ class Mamba(Model):
             generations.append(generation)

             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
```
```diff
@@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
             inputs.append(r.inputs)
             requests_idx_mapping[r.id] = i
             decoder_input_lengths.append(1)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -789,6 +789,9 class Seq2SeqLM(Model):
             generations.append(generation)

             # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
             batch.decoder_input_ids[i] = next_token_id
             batch.all_decoder_input_ids[i] = all_decoder_input_ids
             batch.input_lengths[i] = input_length
```
```diff
 import math
 import torch
+import json
+from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
+from text_generation_server.pb.generate_pb2 import GrammarType
+
+from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.json_schema import build_regex_from_object
+from functools import lru_cache
+from typing import List, Optional, DefaultDict
+import time

 from transformers import (
     LogitsWarper,
@@ -135,9 +144,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
     ) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty, score / self.penalty
-        )
+        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)

         return scores.scatter_add_(1, input_ids, score)
@@ -464,3 +471,132 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
             self.processors = new_processors
             return self
         return None
+
+
+class GrammarLogitProcessor(LogitsProcessor):
+    fsm_state: DefaultDict[int, int]
+    fsm: RegexFSM
+
+    def __init__(self, tokenizer, device, grammar, grammar_type):
+        self.device = device
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
+            grammar_type, grammar, self.tokenizer
+        )
+
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        fsm_grammar_state: int,
+    ):
+        if fsm_grammar_state == -1 or self.fsm is None:
+            return logits
+        allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
+        mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
+        mask[allowed_tokens] = 0
+        biased_scores = logits + mask
+        return biased_scores
+
+    def advance(self, next_token_id, fsm_grammar_state):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsm
+        )
+
+    @staticmethod
+    def _advance(next_token_id, fsm_grammar_state, fsm):
+        if fsm_grammar_state == -1:
+            return fsm_grammar_state
+        return fsm.next_state(fsm_grammar_state, next_token_id)
+
+    # TODO: move grammar compilation into the router
+    @staticmethod
+    @lru_cache(maxsize=32, typed=True)
+    def _cached_compile_fsm(grammar_type, schema, tokenizer):
+        start_time = time.time()
+        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
+            schema = build_regex_from_object(schema)
+        elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
+            pass  # schema is already a regex just here for clarity
+        fsm = RegexFSM(schema, tokenizer)
+        logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
+        return fsm
+
+    @staticmethod
+    @lru_cache(maxsize=32, typed=True)
+    def _cached_adapt_tokenizer(tokenizer):
+        """Adapt tokenizer to work with the FSM.
+
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. In addition we need to handle the missing spaces to
+        Llama's tokenizer to be able to compile FSMs for this model.
+        """
+        start_time = time.time()
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces to HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+        logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
+        return tokenizer
+
+    def filter(self, indices):
+        new_fsms = []
+        for i in indices:
+            new_fsms.append(self.fsms[i])
+        self.fsms = new_fsms
+        return self
+
+
+class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
+    def __init__(self, tokenizer, device, grammars, grammar_type):
+        self.device = device
+        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
+        self.fsms = []
+        for i in range(len(grammars)):
+            fsm = GrammarLogitProcessor._cached_compile_fsm(
+                grammar_type[i], grammars[i], self.tokenizer
+            )
+            self.fsms.append(fsm)
+
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        fsm_grammar_states: List[int],
+        mask: torch.Tensor,
+    ):
+        mask = torch.full_like(logits, -math.inf)
+        for i in range(logits.shape[0]):
+            fsm = self.fsms[i]
+            if fsm_grammar_states[i] == -1 or fsm is None:
+                continue
+            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
+            mask[i, allowed_tokens] = 0
+        logits += mask
+        return logits
+
+    def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
+        return [
+            GrammarLogitProcessor._advance(
+                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
+            )
+            for i in range(len(next_token_ids))
+        ]
+
+    def advance_at_index(self, next_token_id, fsm_grammar_state, index):
+        return GrammarLogitProcessor._advance(
+            next_token_id, fsm_grammar_state, self.fsms[index]
+        )
+
+    def filter(self, indices):
+        return GrammarLogitProcessor.filter(self, indices)
```
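
A note on the JSON path above: `_cached_compile_fsm` first lowers the JSON schema to an equivalent regex via outlines' `build_regex_from_object`, then compiles that regex exactly like a plain regex grammar. A tiny illustration (the schema here is hypothetical, not taken from this PR):

```python
# Illustration of the JSON-grammar path used by _cached_compile_fsm above:
# a JSON schema string is converted to a regex, which is then compiled into
# a RegexFSM just like a regex grammar would be.
from outlines.fsm.json_schema import build_regex_from_object

schema = '{"type": "object", "properties": {"email": {"type": "string"}}, "required": ["email"]}'
email_regex = build_regex_from_object(schema)
print(email_regex[:80], "...")  # a (large) regex matching JSON documents that satisfy the schema
```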
```diff
 import re
 from typing import List, Optional, Tuple
+import math

 import torch
 from text_generation_server.pb import generate_pb2
-from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
+    GrammarLogitProcessor,
     HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousFrequencyPenaltyLogitsProcessor,
@@ -13,6 +15,7 @@ from text_generation_server.utils.logits_process import (
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
+    HeterogeneousGrammarLogitProcessor,
     static_warper,
 )
 from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@@ -22,16 +25,20 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcess
 class NextTokenChooser:
     def __init__(
         self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        frequency_penalty=0.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
+        watermark: bool = False,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.0,
+        frequency_penalty: float = 0.0,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        typical_p: Optional[float] = None,
+        do_sample: bool = False,
+        seed: int = 0,
+        device: str = "cpu",
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        grammar: str = "",
+        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
+        fsm_grammar_state: int = 0,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -46,6 +53,12 @@ class NextTokenChooser:
             if frequency_penalty and frequency_penalty != 0.0
             else None
         )
+        self.grammar_processor = (
+            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+            if grammar != ""
+            else None
+        )
+        self.tokenizer = tokenizer

         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -61,7 +74,10 @@ class NextTokenChooser:
             self.static_warper = None

         sampling = do_sample or has_warpers
         self.choice = Sampling(seed, device) if sampling else Greedy()
+        self.fsm_grammar_state = fsm_grammar_state
+        self.grammar = grammar

     def __call__(self, input_ids, scores):
         if self.watermark_processor is not None:
@@ -70,6 +86,8 @@ class NextTokenChooser:
             scores = self.repetition_processor(input_ids, scores)
         if self.frequency_processor is not None:
             scores = self.frequency_processor(input_ids, scores)
+        if self.grammar_processor is not None:
+            scores = self.grammar_processor(scores, self.fsm_grammar_state)

         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -80,11 +98,19 @@ class NextTokenChooser:

         return next_id, next_logprob

+    def advance_grammar(self, next_id: int):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_state = self.grammar_processor.advance(
+                next_id, self.fsm_grammar_state
+            )
+        return self
+
     @classmethod
     def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         return NextTokenChooser(
             watermark=pb.watermark,
@@ -97,6 +123,9 @@ class NextTokenChooser:
             do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            tokenizer=tokenizer,
+            grammar=pb.grammar,
+            grammar_type=pb.grammar_type,
         )
@@ -201,6 +230,10 @@ class HeterogeneousNextTokenChooser:
         typical_p: List[float],
         do_sample: List[bool],
         seeds: List[int],
+        tokenizer: PreTrainedTokenizerBase,
+        grammars: List[str],
+        grammar_types: List[int],
+        fsm_grammar_states=List[int],
     ):
         warpers = []
@@ -232,6 +265,14 @@ class HeterogeneousNextTokenChooser:
             else None
         )

+        self.grammar_processor = (
+            HeterogeneousGrammarLogitProcessor(
+                tokenizer, device, grammars, grammar_types
+            )
+            if any([grammar != "" for grammar in grammars])
+            else None
+        )
+
         if any([x != 1.0 for x in temperature]):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@@ -263,6 +304,10 @@ class HeterogeneousNextTokenChooser:
         self.do_sample = do_sample
         self.dtype = dtype
         self.device = device
+        self.tokenizer = tokenizer
+        self.fsm_grammar_states = fsm_grammar_states
+        self.grammars = grammars
+        self.grammar_types = grammar_types

     def __call__(
         self,
@@ -283,6 +328,8 @@ class HeterogeneousNextTokenChooser:
             scores = scores.view(B, S, -1)

         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
+
         for j in range(S):
             _scores = scores[:, j]
             if self.watermark_processor is not None:
@@ -291,10 +338,10 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
-
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
-
+            if self.grammar_processor is not None:
+                _scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
@@ -352,6 +399,21 @@ class HeterogeneousNextTokenChooser:

         return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids

+    def advance_grammar(self, next_ids: List[int]):
+        if self.grammar_processor is not None:
+            other_new_states = self.grammar_processor.advance_batch(
+                next_ids, self.fsm_grammar_states, self.grammars
+            )
+            self.fsm_grammar_states = other_new_states
+        return self
+
+    def advance_grammar_single(self, grammar_state_index: int, next_id: int):
+        if self.grammar_processor is not None:
+            self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
+                next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
+            )
+        return self
+
     def filter(self, indices):
         if self.watermark_processor is not None:
             self.watermark_processor = self.watermark_processor.filter(indices)
@@ -362,6 +424,9 @@ class HeterogeneousNextTokenChooser:
         if self.frequency_processor is not None:
             self.frequency_processor = self.frequency_processor.filter(indices)

+        if self.grammar_processor is not None:
+            self.grammar_processor = self.grammar_processor.filter(indices)
+
         filtered_warpers = []
         for warper in self.warpers:
             filtered_warper = warper.filter(indices)
@@ -372,6 +437,18 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]

+        new_grammars = []
+        new_fsm_grammar_states = []
+        new_grammar_types = []
+        for i in indices:
+            new_grammars.append(self.grammars[i])
+            new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+            new_grammar_types.append(self.grammar_types[i])
+
+        self.grammars = new_grammars
+        self.fsm_grammar_states = new_fsm_grammar_states
+        self.grammar_types = new_grammar_types
+
         if any(self.do_sample):
             self.choice.filter(indices)
         else:
@@ -385,6 +462,7 @@ class HeterogeneousNextTokenChooser:
         pb: List[generate_pb2.NextTokenChooserParameters],
         dtype: torch.dtype,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "HeterogeneousNextTokenChooser":
         return HeterogeneousNextTokenChooser(
             watermark=[pb_.watermark for pb_ in pb],
@@ -398,6 +476,10 @@ class HeterogeneousNextTokenChooser:
             seeds=[pb_.seed for pb_ in pb],
             device=device,
             dtype=dtype,
+            tokenizer=tokenizer,
+            grammars=[pb_.grammar for pb_ in pb],
+            grammar_types=[pb_.grammar_type for pb_ in pb],
+            fsm_grammar_states=[0] * len(pb),
         )
```