Unverified Commit 09b7c26b authored by OlivierDehaene, committed by GitHub

feat(server): add frequency penalty (#1541)

parent 39af000c
@@ -2787,7 +2787,7 @@ dependencies = [
  "tabled",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.14.1",
  "tokio",
  "tracing",
  "tracing-subscriber",
@@ -2850,7 +2850,7 @@ dependencies = [
  "serde_json",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.15.1",
  "tokio",
  "tokio-stream",
  "tower-http",
@@ -2972,6 +2972,40 @@ dependencies = [
  "unicode_categories",
 ]
 
+[[package]]
+name = "tokenizers"
+version = "0.15.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
+dependencies = [
+ "aho-corasick",
+ "clap",
+ "derive_builder",
+ "esaxx-rs",
+ "getrandom",
+ "hf-hub",
+ "indicatif",
+ "itertools 0.11.0",
+ "lazy_static",
+ "log",
+ "macro_rules_attribute",
+ "monostate",
+ "onig",
+ "paste",
+ "rand",
+ "rayon",
+ "rayon-cond",
+ "regex",
+ "regex-syntax 0.7.5",
+ "serde",
+ "serde_json",
+ "spm_precompiled",
+ "thiserror",
+ "unicode-normalization-alignments",
+ "unicode-segmentation",
+ "unicode_categories",
+]
+
 [[package]]
 name = "tokio"
 version = "1.35.1"
......
@@ -30,6 +30,7 @@ pub async fn run(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
     client: ShardedClient,
@@ -42,6 +43,7 @@ pub async fn run(
         do_sample,
         seed: 0,
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
+        frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
     };
@@ -140,6 +142,7 @@ pub async fn run(
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
        watermark,
        do_sample,
    );
......
@@ -84,6 +84,11 @@ struct Args {
     #[clap(long, env)]
     repetition_penalty: Option<f32>,
 
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    frequency_penalty: Option<f32>,
+
     /// Generation parameter in case you want to specifically test/debug particular
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
@@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         master_shard_uds_path,
@@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         sharded_client,
......
@@ -15,6 +15,7 @@ pub(crate) fn parameters_table(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
 ) -> Table {
@@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Top P", &format!("{top_p:?}")]);
     builder.push_record(["Typical P", &format!("{typical_p:?}")]);
     builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
+    builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
     builder.push_record(["Watermark", &watermark.to_string()]);
     builder.push_record(["Do Sample", &do_sample.to_string()]);
......
@@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
     assert response.generated_text == "\n\nDeep learning is a new type of machine"
     assert response == response_snapshot
 
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     )
 
     assert response.details.generated_tokens == 10
-    assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
+    assert (
+        response.generated_text
+        == "blue, red, yellow, \nand orange (in the order they appear in"
+    )
     assert response == response_snapshot
 
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-    responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4)
+    responses = await generate_load(
+        fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
+    )
 
     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
......
@@ -66,6 +66,8 @@ message NextTokenChooserParameters {
     uint64 seed = 6;
     /// repetition penalty
     float repetition_penalty = 7;
+    /// frequency penalty
+    float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
 }
......
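Reviewer note: the proto change above only adds the field (tag 9, so the existing `watermark = 8` tag is left untouched); the logits processing itself lives in the Python server and is not part of the hunks shown here. As a rough illustration only, an OpenAI-style frequency penalty is usually applied along the lines of the sketch below, where `apply_frequency_penalty` is a hypothetical helper and not code from this PR:

```rust
use std::collections::HashMap;

// Hypothetical helper (not part of this PR): subtract `penalty * count(token)`
// from the logit of every token that already appears in the generated sequence,
// which is the usual OpenAI-style definition of a frequency penalty.
fn apply_frequency_penalty(logits: &mut [f32], generated: &[usize], penalty: f32) {
    // Count how many times each token id has been generated so far.
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &token in generated {
        *counts.entry(token).or_insert(0) += 1;
    }
    // Penalize proportionally to the count; 0.0 is a no-op and negative
    // values make repetition more likely instead.
    for (&token, &count) in &counts {
        if let Some(logit) = logits.get_mut(token) {
            *logit -= penalty * count as f32;
        }
    }
}

fn main() {
    let mut logits = vec![1.0_f32, 2.0, 3.0, 4.0];
    let generated = [2usize, 2, 3];
    apply_frequency_penalty(&mut logits, &generated, 0.5);
    assert_eq!(logits, vec![1.0, 2.0, 2.0, 3.5]);
    println!("{logits:?}");
}
```

With a penalty of 0.0 the adjustment is a no-op, which is consistent with the 0.0 defaults used for the new field throughout this diff.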
@@ -125,6 +125,7 @@ impl Client {
                     do_sample: false,
                     seed: 0,
                     repetition_penalty: 1.2,
+                    frequency_penalty: 0.1,
                     watermark: true,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
......
@@ -43,6 +43,7 @@ impl Health {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 1.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
......
@@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters {
     )]
     pub repetition_penalty: Option<f32>,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = -2.0,
+        nullable = true,
+        default = "null",
+        example = 0.1
+    )]
+    pub frequency_penalty: Option<f32>,
+    #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
     pub top_k: Option<i32>,
     #[serde(default)]
@@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
         best_of: None,
         temperature: None,
         repetition_penalty: None,
+        frequency_penalty: None,
         top_k: None,
         top_p: None,
         typical_p: None,
@@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion {
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
-    pub logprobs: Option<Vec<f32>>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: String,
 }
 
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprobs {
+    content: Vec<ChatCompletionLogprob>,
+}
+
+impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
+    fn from(value: (Token, Vec<Token>)) -> Self {
+        let (token, top_tokens) = value;
+        Self {
+            content: vec![ChatCompletionLogprob {
+                token: token.text,
+                logprob: token.logprob,
+                top_logprobs: top_tokens
+                    .into_iter()
+                    .map(|t| ChatCompletionTopLogprob {
+                        token: t.text,
+                        logprob: t.logprob,
+                    })
+                    .collect(),
+            }],
+        }
+    }
+}
+
+impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
+    fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
+        let (tokens, top_tokens) = value;
+        Self {
+            content: tokens
+                .into_iter()
+                .zip(top_tokens)
+                .map(|(t, top_t)| ChatCompletionLogprob {
+                    token: t.text,
+                    logprob: t.logprob,
+                    top_logprobs: top_t
+                        .into_iter()
+                        .map(|t| ChatCompletionTopLogprob {
+                            token: t.text,
+                            logprob: t.logprob,
+                        })
+                        .collect(),
+                })
+                .collect(),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprob {
+    token: String,
+    logprob: f32,
+    top_logprobs: Vec<ChatCompletionTopLogprob>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionTopLogprob {
+    token: String,
+    logprob: f32,
+}
+
 #[derive(Clone, Deserialize, Serialize)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
@@ -238,7 +308,7 @@ impl ChatCompletion {
                     content: output,
                 },
                 logprobs: return_logprobs
-                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                    .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
                 finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
@@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk {
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
-    pub logprobs: Option<f32>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: Option<String>,
 }
@@ -285,7 +355,7 @@ impl ChatCompletionChunk {
         delta: String,
         created: u64,
         index: u32,
-        logprobs: Option<f32>,
+        logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
         Self {
@@ -319,8 +389,8 @@ pub(crate) struct ChatRequest {
     /// UNUSED
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
-    pub model: String, /* NOTE: UNUSED */
+    pub model: String,
+    /* NOTE: UNUSED */
     /// A list of messages comprising the conversation so far.
     #[serde(default = "default_request_messages")]
     pub messages: Vec<Message>,
@@ -346,7 +416,6 @@ pub(crate) struct ChatRequest {
     #[schema(example = "false")]
     pub logprobs: Option<bool>,
 
-    /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
@@ -365,7 +434,6 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,
 
-    /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
@@ -447,7 +515,7 @@ pub struct PrefillToken {
     logprob: f32,
 }
 
-#[derive(Debug, Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema, Clone)]
 pub struct Token {
     #[schema(example = 0)]
     id: u32,
......
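For context on the `ChatCompletionLogprobs` types added above: serialized, they produce an OpenAI-style `content` list of `{token, logprob, top_logprobs}` entries. A minimal sketch of the approximate JSON shape, assuming `serde_json` as a dependency and with made-up token strings and values:

```rust
// Illustrative only: the approximate JSON shape produced by the new
// `ChatCompletionLogprobs` type for one generated token with two alternatives.
// Token strings and logprob values here are invented for the example.
fn main() {
    let logprobs = serde_json::json!({
        "content": [{
            "token": " Paris",
            "logprob": -0.12,
            "top_logprobs": [
                { "token": " Paris", "logprob": -0.12 },
                { "token": " Lyon",  "logprob": -2.98 }
            ]
        }]
    });
    println!("{}", serde_json::to_string_pretty(&logprobs).unwrap());
}
```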
@@ -355,6 +355,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
......
@@ -4,9 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
-    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
+    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
+    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -570,8 +571,8 @@ async fn chat_completions(
     let stream = req.stream;
     let max_new_tokens = req.max_tokens.or(Some(100));
     let repetition_penalty = req
-        .frequency_penalty
-        // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
+        .presence_penalty
+        // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;
@@ -599,6 +600,7 @@ async fn chat_completions(
             best_of: None,
             temperature: req.temperature,
             repetition_penalty,
+            frequency_penalty: req.frequency_penalty,
             top_k: None,
             top_p: req.top_p,
             typical_p: None,
@@ -630,6 +632,10 @@ async fn chat_completions(
                 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                 .as_secs();
 
+            let logprobs = logprobs.then(|| {
+                ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
+            });
+
             event
                 .json_data(ChatCompletionChunk::new(
                     model_id.clone(),
@@ -637,7 +643,7 @@ async fn chat_completions(
                     stream_token.token.text,
                     current_time,
                     stream_token.index,
-                    logprobs.then_some(stream_token.token.logprob),
+                    logprobs,
                     stream_token.details.map(|d| d.finish_reason.to_string()),
                 ))
                 .map_or_else(
......
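The `chat_completions` change above maps the OpenAI-style `presence_penalty` (range -2.0 to 2.0) onto TGI's `repetition_penalty` by shifting it into (0.0, 4.0), while the new `frequency_penalty` request field is now forwarded unchanged. A standalone sketch of that rescaling (illustrative only, not code from this PR):

```rust
// Sketch of the rescaling used in `chat_completions` above: an OpenAI-style
// presence_penalty in [-2.0, 2.0] becomes a repetition_penalty in [0.0, 4.0];
// a missing value stays missing.
fn to_repetition_penalty(presence_penalty: Option<f32>) -> Option<f32> {
    presence_penalty.map(|x| x + 2.0)
}

fn main() {
    assert_eq!(to_repetition_penalty(Some(-2.0)), Some(0.0));
    assert_eq!(to_repetition_penalty(Some(0.0)), Some(2.0));
    assert_eq!(to_repetition_penalty(None), None);
}
```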
@@ -170,6 +170,7 @@ impl Validation {
             best_of,
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -206,6 +207,11 @@ impl Validation {
             return Err(ValidationError::RepetitionPenalty);
         }
 
+        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
+        if !(-2.0..=2.0).contains(&frequency_penalty) {
+            return Err(ValidationError::FrequencyPenalty);
+        }
+
         // Different because the proto default value is not a valid value
         // for the user
         let top_p = top_p
@@ -289,6 +295,7 @@ impl Validation {
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -420,6 +427,8 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
+    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
+    FrequencyPenalty,
     #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]
......
@@ -70,7 +70,7 @@ def test_batch_top_tokens():
     # Now let's make second member of the batch be speculated
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
     accepted_ids[1] = 2
     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
         top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )
......
@@ -86,6 +86,7 @@ except ImportError as e:
 if MAMBA_AVAILABLE:
     __all__.append(Mamba)
 
+
 def get_model(
     model_id: str,
     revision: Optional[str],
......
@@ -696,14 +696,17 @@ class CausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
......
@@ -19,6 +19,7 @@ from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math
 
+
 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
@@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig):
             **kwargs,
         )
 
+
 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -60,10 +62,14 @@ class MambaBlock(nn.Module):
         self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
         self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
         self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
-        self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
-        self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
+        self.dt_proj_no_bias = FastLinear.load(
+            config, f"{prefix}.dt_proj", weights, bias=False
+        )
+        self.out_proj = FastLinear.load(
+            config, f"{prefix}.out_proj", weights, bias=False
+        )
         self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
         self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
         self.D = weights.get_tensor(f"{prefix}.D")
         self.activation = "silu"
         self.dt_rank = config.dt_rank
@@ -80,12 +86,14 @@ class MambaBlock(nn.Module):
             out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
             return out, conv_state, ssm_state
 
-        projected_states = self.in_proj(hidden_states).transpose(1,2)
+        projected_states = self.in_proj(hidden_states).transpose(1, 2)
         x, z = projected_states.chunk(2, dim=1)
         conv_state = F.pad(x, (self.d_conv - seqlen, 0))
         x = causal_conv1d_fn(
             x=x,
-            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            weight=self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            ),
             bias=self.conv1d.bias,
             activation=self.activation,
         )
@@ -94,7 +102,9 @@ class MambaBlock(nn.Module):
         # We want dt to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
         x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
-        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+        dt, B, C = torch.split(
+            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
+        )
         dt = self.dt_proj.weight @ dt.t()
         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
@@ -118,28 +128,39 @@ class MambaBlock(nn.Module):
     def step(self, hidden_states, conv_state, ssm_state):
         _xz = self.in_proj(hidden_states)
         _x, _z = _xz.chunk(2, dim=-1)  # (B D)
-        conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1)
+        conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
         conv_out = causal_conv1d_fn(
             x=conv_state_new,
-            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
-            bias=self.conv1d.bias,
-            activation=self.activation
+            weight=self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            ),
+            bias=self.conv1d.bias,
+            activation=self.activation,
         )
         conv_state = conv_state_new[:, :, 1:]
         bsz, seqlen, dim = hidden_states.shape
         output_tensor = torch.zeros(
-            (bsz, seqlen, dim),
-            device=hidden_states.device,
-            dtype=hidden_states.dtype
+            (bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype
        )
         for i in range(0, bsz):
-            x = conv_out[i:i+1,:,-1]
-            z = _z[i:i+1, -1, :]
+            x = conv_out[i : i + 1, :, -1]
+            z = _z[i : i + 1, -1, :]
             x_db = self.x_proj(x)
-            dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+            dt, B, C = torch.split(
+                x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1
+            )
             dt = F.linear(dt, self.dt_proj.weight)
             y = selective_state_update(
-                ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+                ssm_state[i : i + 1, :, :],
+                x,
+                dt,
+                self.negA,
+                B,
+                C,
+                self.D,
+                z=z,
+                dt_bias=self.dt_proj.bias,
+                dt_softplus=True,
             )
             out = self.out_proj(y)
             output_tensor[i] = out
@@ -147,48 +168,70 @@ class MambaBlock(nn.Module):
         return output_tensor, conv_state, ssm_state
 
+
 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
-        self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
-        self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
+        self.mamba_block = MambaBlock(
+            prefix=f"{layer_id}.mixer", config=config, weights=weights
+        )
+        self.layer_norm = FastRMSNorm.load(
+            prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
+        )
 
     def forward(
         self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        inference_params: Optional[Any] = None,
    ):
         residual = (hidden_states + residual) if residual is not None else hidden_states
         shape = residual.shape
         hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
-        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(
+            hidden_states.view(*shape), inference_params
+        )
         return hidden_states, residual, conv_state, last_ssm_state
 
+
 class MambaModel(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         prefix = "backbone"
         self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
         self.blocks = nn.ModuleList(
-            [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
+            [
+                ResidualBlock(f"{prefix}.layers.{i}", config, weights)
+                for i in range(config.n_layer)
+            ]
+        )
+        self.norm_f = FastRMSNorm.load(
+            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
+        )
+        self.lm_head = FastLinear.load(
+            config, f"{prefix}.embedding", weights, bias=False
         )
-        self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
-        self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
         self.config = config
 
-    def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    def forward(
+        self, input_ids: torch.Tensor, inference_params=None, residual=None
+    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
         hidden_states = self.embed_tokens(input_ids)
         for block in self.blocks:
-            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
-            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
+            hidden_states, residual, conv_state, ssm_state = block(
+                hidden_states, residual, inference_params
+            )
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (
+                conv_state,
+                ssm_state,
+            )
 
-        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states = (
+            hidden_states + residual if residual is not None else hidden_states
+        )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
         logits = self.lm_head(hidden_states)
 
         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
-        return logits, input_ids, inference_params
\ No newline at end of file
+        return logits, input_ids, inference_params
@@ -842,7 +842,6 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
         speculate = get_speculate()
-
         (
             next_input_ids,
@@ -1064,14 +1063,17 @@ class FlashCausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
......
@@ -26,6 +26,7 @@ from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 from mamba_ssm.utils.generation import InferenceParams
 
+
 @dataclass
 class MambaBatch(Batch):
     batch_id: int
@@ -69,7 +70,7 @@ class MambaBatch(Batch):
             size=len(self),
             max_tokens=self.max_tokens,
         )
-
+
     @classmethod
     def from_pb(
         cls,
@@ -196,7 +197,7 @@ class MambaBatch(Batch):
             new_padding_right_offset = max(
                 new_padding_right_offset, remaining_decode_tokens
            )
-
+
         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
@@ -218,10 +219,13 @@ class MambaBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
         # TODO
         # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
         key_value_memory_dict = {}
-        for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
+        for i, (
+            conv_state,
+            ssm_state,
+        ) in self.inference_params.key_value_memory_dict.items():
             key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
         self.inference_params.key_value_memory_dict = key_value_memory_dict
@@ -305,8 +309,9 @@ class MambaBatch(Batch):
             start_index = end_index
 
-        (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
+        (_, d_model, d_conv) = (
+            batches[0].inference_params.key_value_memory_dict[0][0].shape
+        )
         (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
         n_blocks = len(batches[0].inference_params.key_value_memory_dict)
         dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
@@ -344,9 +349,15 @@ class MambaBatch(Batch):
             for i in range(n_blocks):
                 conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
                 batch_size = batch.inference_params.max_batch_size
-                inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state
-                inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state
-            inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample
+                inference_params.key_value_memory_dict[i][0][
+                    current_batch : current_batch + batch_size
+                ] = conv_state
+                inference_params.key_value_memory_dict[i][1][
+                    current_batch : current_batch + batch_size
+                ] = ssm_state
+            inference_params.lengths_per_sample[
+                current_batch : current_batch + batch_size
+            ] = batch.inference_params.lengths_per_sample
             current_batch += batch_size
 
         return cls(
@@ -366,12 +377,13 @@ class MambaBatch(Batch):
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
             max_tokens=max_tokens,
-            inference_params=inference_params
+            inference_params=inference_params,
         )
 
     def __len__(self):
         return len(self.requests)
 
+
 class Mamba(Model):
     def __init__(
         self,
@@ -428,7 +440,7 @@ class Mamba(Model):
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
         return None
-
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -441,7 +453,9 @@ class Mamba(Model):
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
-        input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+        input_ids = (
+            batch.input_ids
+        )  # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
 
         batch_size = input_ids.shape[0]
         max_seqlen = input_ids.shape[1]
@@ -450,8 +464,11 @@ class Mamba(Model):
         # Inference params
         seqlen_og = 0
         inf_cache = {}
-        lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
+        lengths_per_sample = (
+            torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
+            * max_seqlen
+        )
 
         if batch.inference_params is None:
             inference_params = InferenceParams(
                 max_seqlen=max_seqlen,
@@ -478,11 +495,16 @@ class Mamba(Model):
                     device=block.dt_proj.weight.device,
                     dtype=block.dt_proj.weight.dtype,
                 )
-                inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
+                inference_params.key_value_memory_dict[block.layer_idx] = (
+                    conv_state,
+                    ssm_state,
+                )
             batch.inference_params = inference_params
 
         # Forward pass
-        logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
+        logits, past_input_ids, new_inference_params = self.model(
+            input_ids, batch.inference_params
+        )
 
         batch.inference_params = new_inference_params
         # Results
@@ -564,7 +586,8 @@ class Mamba(Model):
                     prefix_offset=len(all_input_ids)
                     - stopping_criteria.current_tokens
                    - 1,
-                    read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                    read_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens,
                     skip_special_tokens=True,
                 )
                 # Get seed
......
@@ -750,14 +750,17 @@ class Seq2SeqLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
-                for (top_token_ids, top_token_logprobs) in zip(top_token_ids, top_token_logprobs):
+                for (top_token_ids, top_token_logprobs) in zip(
+                    top_token_ids, top_token_logprobs
+                ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
-                        token_id in self.all_special_ids for token_id in top_token_ids
+                        token_id in self.all_special_ids
+                        for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
......
@@ -95,5 +95,7 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
-            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens] if self.top_tokens is not None else None,
+            top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
+            if self.top_tokens is not None
+            else None,
         )