"vscode:/vscode.git/clone" did not exist on "71907be1ef41240b6eae6239eb29191d0aa4cd52"
Unverified Commit e415b690 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Lots of improvements (Still 2 allocators) (#2449)



* Making prefix/flashinfer the default and testing the full release tests.

* Include flashinfer in the docker.

* Using prebuilt.

* Allowing window_left_size (dummy version).

* Disabling flashinfer/prefix caching on odd head_dim

* Disable prefix caching for lora.

* More specific codes.

* Update lock

* Updating integration tests with new values with FI/FD.

Remove paged as a default too, and using FD everywhere.

* Update cargo lock ?

* Upgrade to 1.80 because of bitstream...

* Everywhere 1.80

* Forgot last default place.

* Apply suggestions from code review
Co-authored-by: default avatardrbh <david.richard.holtz@gmail.com>

* Updated flake lock

* Tmp

* Upgrade resolution system for less errors in resolution.

* Remove lambda for cleaner function.

* Handling debugger.

* OVerride the env in server tests.

* Is this enough to make it work ?

* This seems to be working.

* Downgrade some logs.

* Fixing the default for vlm.

* Don't enable prefix caching on VLM just yet.

* Change `add_special_tokens` in order to have the correct tokens for chat
input and not (since it's super important with the prefixing now)

* Fixing prefix caching for flashdecoding.

* Update all models.

* Fixed flashinfer version.

* add_special_tokens is internal only

* Fixing seqlen with the new vlms.

* Fixing the issue with `add_special_tokens` not being passed around.

* Fixing the test.

* Removing encoder_decoder (seq2seq).

* Update the chat test.

* Fixing the batching tokenization in flash causal lm.

* Truncating left for radix purposes.

* Oops this doesn't belong here.

* Put back default pure shell.

* Update server tests

- Default to throughput test in k6
- Use TGI_WIGGLE_ROOM to adjust wiggle room

* Only n_heads / process_group.size() are necessary.

* Revert the integrationt tests change (seem linked to head_size
modification).

* Adding error message when assert is violated.

* Fixing the free algorithm to handle times where the common prefix is
smaller.

* Apply suggestions from code review
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>

* Update server/text_generation_server/layers/attention/common.py
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>

* Fix disabling prefix caching - Fix windowing checks.

* Revert the Cohere tokenizer change (for now using a revision instead).

* Fmt.

---------
Co-authored-by: default avatardrbh <david.richard.holtz@gmail.com>
Co-authored-by: default avatarOlivierDehaene <olivier@huggingface.co>
parent 4e821c00
......@@ -16,7 +16,7 @@
},
{
"id": 100,
"logprob": -0.38549805,
"logprob": -0.38305664,
"text": "_"
},
{
......@@ -29,7 +29,7 @@
"tokens": [
{
"id": 2284,
"logprob": -0.31323242,
"logprob": -0.296875,
"special": false,
"text": "():"
},
......@@ -59,19 +59,19 @@
},
{
"id": 10914,
"logprob": -0.7817383,
"logprob": -0.7734375,
"special": false,
"text": " World"
},
{
"id": 16013,
"logprob": -0.6328125,
"logprob": -0.61816406,
"special": false,
"text": "!\")"
},
{
"id": 222,
"logprob": -0.0619812,
"logprob": -0.054870605,
"special": false,
"text": "\n"
},
......@@ -83,7 +83,7 @@
},
{
"id": 610,
"logprob": -0.4086914,
"logprob": -0.4152832,
"special": false,
"text": "def"
},
......@@ -113,7 +113,7 @@
},
{
"id": 444,
"logprob": -0.21826172,
"logprob": -0.21618652,
"special": false,
"text": "name"
},
......@@ -173,7 +173,7 @@
},
{
"id": 11571,
"logprob": -0.10021973,
"logprob": -0.08892822,
"special": false,
"text": "!\""
},
......
......@@ -30,19 +30,19 @@
},
{
"id": 264,
"logprob": -0.37573242,
"logprob": -0.38061523,
"special": false,
"text": " a"
},
{
"id": 633,
"logprob": -0.09161377,
"logprob": -0.09301758,
"special": false,
"text": " new"
},
{
"id": 4480,
"logprob": -0.26171875,
"logprob": -0.26782227,
"special": false,
"text": " feature"
},
......@@ -78,7 +78,7 @@
},
{
"id": 13,
"logprob": 0.0,
"logprob": -0.10632324,
"special": false,
"text": "\n"
}
......
......@@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
print(repr(response.choices[0].message.content))
assert (
response.choices[0].message.content
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas"
== "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C"
)
assert response == response_snapshot
......@@ -8,7 +8,7 @@ use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines};
use std::io::{BufRead, BufReader};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
......@@ -18,12 +18,103 @@ use std::sync::{mpsc, Arc};
use std::thread;
use std::thread::sleep;
use std::time::{Duration, Instant};
use std::{fs, io};
use std::{
fs, io,
io::{Read, Write},
};
use thiserror::Error;
use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
fn get_config(
model_id: &str,
revision: &Option<String>,
) -> Result<Config, Box<dyn std::error::Error>> {
let mut path = std::path::Path::new(model_id).to_path_buf();
let model_id = model_id.to_string();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
let config: Config = config.into();
Ok(config)
}
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
if config.vision_config.is_some() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
} else if config.is_encoder_decoder {
tracing::info!("Disabling prefix caching because of seq2seq model");
prefix_caching = Some("0".to_string());
}
}
match config.head_dim {
Some(h) if h == 64 || h == 128 || h == 256 => {
if lora_adapters.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of lora adapters");
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
config.model_type.as_ref().unwrap()
);
attention = Some("flashdecoding".to_string());
}
}
Some("t5") => {}
_ => {}
}
}
_ => {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
attention = Some("flashdecoding".to_string());
}
if prefix_caching.is_none() {
prefix_caching = Some("0".to_string());
}
}
}
}
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
let attention = attention.unwrap_or("flashinfer".to_string());
(prefix_caching, attention)
}
#[derive(Deserialize)]
struct RawConfig {
max_position_embeddings: Option<usize>,
......@@ -31,6 +122,12 @@ struct RawConfig {
model_type: Option<String>,
max_seq_len: Option<usize>,
quantization_config: Option<QuantizationConfig>,
n_embd: Option<usize>,
hidden_size: Option<usize>,
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
}
#[derive(Deserialize)]
......@@ -38,10 +135,17 @@ struct QuantizationConfig {
quant_method: Option<Quantization>,
}
#[derive(Deserialize)]
struct VisionConfig {}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
}
impl From<RawConfig> for Config {
......@@ -51,9 +155,32 @@ impl From<RawConfig> for Config {
.or(other.max_seq_len)
.or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method);
let head_dim = other.head_dim.or_else(|| {
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
(Some(hidden_size), _, Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
// Legacy
(_, Some(hidden_size), Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
_ => None,
}
});
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
vision_config,
is_encoder_decoder,
}
}
}
......@@ -731,6 +858,7 @@ fn shard_manager(
.args(shard_args)
.env_clear()
.envs(envs)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
......@@ -752,12 +880,13 @@ fn shard_manager(
};
// Redirect STDOUT to the console
let mut pstdin = p.stdin.take().unwrap();
let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
//stdout tracing thread
thread::spawn(move || {
log_lines(shard_stdout_reader.lines());
log_lines(shard_stdout_reader);
});
// We read stderr in another thread as it seems that lines() can block in some cases
let (err_sender, err_receiver) = mpsc::channel();
......@@ -766,6 +895,18 @@ fn shard_manager(
err_sender.send(line).unwrap_or(());
}
});
// We read stdin in another thread as it seems that lines() can block in some cases
thread::spawn(move || {
let mut stdin = io::stdin(); // We get `Stdin` here.
loop {
let mut buffer = vec![0; 4096];
if let Ok(n) = stdin.read(&mut buffer) {
if n > 0 {
let _ = pstdin.write_all(&buffer[..n]);
}
}
}
});
let mut ready = false;
let start_time = Instant::now();
......@@ -872,19 +1013,36 @@ impl PythonLogMessage {
}
}
impl TryFrom<&String> for PythonLogMessage {
impl TryFrom<&[u8]> for PythonLogMessage {
type Error = serde_json::Error;
fn try_from(value: &String) -> Result<Self, Self::Error> {
serde_json::from_str::<Self>(value)
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice::<Self>(value)
}
}
fn log_lines<S: Sized + BufRead>(lines: Lines<S>) {
for line in lines.map_while(Result::ok) {
match PythonLogMessage::try_from(&line) {
Ok(log) => log.trace(),
Err(_) => tracing::debug!("{line}"),
fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
let mut buffer = vec![0u8; 8 * 4096];
let mut stdout = std::io::stdout();
loop {
let n = bufread.read(&mut buffer);
if let Ok(n) = n {
if n > 0 {
let mut lines = buffer[..n].split(|i| *i == b'\n').peekable();
while let Some(line) = lines.next() {
match PythonLogMessage::try_from(line) {
Ok(log) => log.trace(),
// For interactive debugging ?
Err(_) => {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
}
stdout.flush().unwrap();
}
}
}
}
}
}
}
......@@ -1044,7 +1202,7 @@ fn download_convert_model(
let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
thread::spawn(move || {
log_lines(download_stdout.lines());
log_lines(download_stdout);
});
let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
......@@ -1439,68 +1597,35 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:#?}", args);
let get_max_positions_quantize =
|| -> Result<(usize, Option<Quantization>), Box<dyn std::error::Error>> {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = if let Ok(token) = std::env::var("HF_TOKEN") {
// env variable has precedence over on file token.
ApiBuilder::new().with_token(Some(token)).build()?
} else {
Api::new()?
};
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json")?
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename)?;
let config: RawConfig = serde_json::from_str(&content)?;
if config.model_type == Some("gemma2".to_string()) {
tracing::info!("Forcing flash decoding because of softcap usage");
std::env::set_var("ATTENTION", "flashdecoding");
}
let config: Config = config.into();
let quantize = config.quantize;
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
Ok((max_default, quantize))
} else {
Ok((max_position_embeddings, quantize))
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let quantize = config.as_ref().and_then(|c| c.quantize);
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = if let Some(config) = &config {
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
let max = max_position_embeddings;
if args.max_input_tokens.is_none()
&& args.max_total_tokens.is_none()
&& args.max_batch_prefill_tokens.is_none()
{
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
}
max_default
} else {
Err(Box::new(LauncherError::ArgumentValidation(
"no max defined".to_string(),
)))
max_position_embeddings
}
};
let (max_position_embeddings, quantize): (usize, Option<Quantization>) =
get_max_positions_quantize().unwrap_or((4096, None));
} else {
max_default
}
} else {
max_default
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
......
......@@ -33,13 +33,13 @@ export function get_options() {
// rate: 20,
// timeUnit: '1s',
// },
load_test: {
executor: 'constant-arrival-rate',
duration: '60s',
preAllocatedVUs: 100,
rate: 1,
timeUnit: '1s',
},
// load_test: {
// executor: 'constant-arrival-rate',
// duration: '60s',
// preAllocatedVUs: 100,
// rate: 1,
// timeUnit: '1s',
// },
// breakpoint: {
// executor: 'ramping-arrival-rate', //Assure load increase if the system slows
// preAllocatedVUs: 300,
......@@ -47,12 +47,12 @@ export function get_options() {
// { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load
// ],
// },
// throughput: {
// executor: 'shared-iterations',
// vus: 100,
// iterations: 200,
// maxDuration: '40s',
// },
throughput: {
executor: 'shared-iterations',
vus: 100,
iterations: 200,
maxDuration: '40s',
},
},
};
}
......
......@@ -137,6 +137,8 @@ message Request {
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Context truncation
bool add_special_tokens = 13;
}
message Batch {
......
......@@ -120,10 +120,11 @@ impl Infer {
) -> Result<Option<tokenizers::Encoding>, InferError> {
// Tokenize request
let inputs = request.inputs;
let add_special_tokens = request.add_special_tokens;
let truncate = request.parameters.truncate;
let encoding = self
.validation
.tokenize(inputs, truncate)
.tokenize(inputs, add_special_tokens, truncate)
.await
.map_err(|err| {
tracing::error!("Tokenization {err}");
......
......@@ -22,6 +22,16 @@ pub enum Attention {
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
......@@ -1072,6 +1082,16 @@ pub(crate) struct GenerateRequest {
pub inputs: String,
#[serde(default = "default_parameters")]
pub parameters: GenerateParameters,
/// This is used internally because some requests
/// already contain the templated input therefore
/// we shouldn't add the special tokens.
#[serde(default = "default_true", skip)]
pub add_special_tokens: bool,
}
fn default_true() -> bool {
true
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
......@@ -1089,6 +1109,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
fn from(req: CompatGenerateRequest) -> Self {
Self {
inputs: req.inputs,
add_special_tokens: true,
parameters: req.parameters,
}
}
......
......@@ -158,6 +158,7 @@ async fn get_chat_tokenize(
let generate_request = GenerateRequest {
inputs,
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
......@@ -754,6 +755,7 @@ async fn completions(
.iter()
.map(|prompt| GenerateRequest {
inputs: prompt.to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
best_of: None,
temperature,
......@@ -1180,6 +1182,7 @@ async fn chat_completions(
// build the request passing some parameters
let generate_request = GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
......@@ -1386,6 +1389,7 @@ async fn vertex_compatibility(
.map(|instance| {
let generate_request = GenerateRequest {
inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters {
do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
......
......@@ -95,6 +95,7 @@ impl Validation {
pub async fn tokenize(
&self,
inputs: String,
add_special_tokens: bool,
truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer
......@@ -104,7 +105,11 @@ impl Validation {
// Send request to the background validation task
// Unwrap is safe here
sender
.send(((inputs, truncate), response_sender, Span::current()))
.send((
(inputs, add_special_tokens, truncate),
response_sender,
Span::current(),
))
.unwrap();
// Await on response channel
......@@ -121,11 +126,15 @@ impl Validation {
async fn validate_input(
&self,
inputs: String,
add_special_tokens: bool,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
if let Some((encoding, inputs)) = self
.tokenize(inputs.clone(), add_special_tokens, truncate)
.await?
{
// Create response channel
let input_length = if let Some(truncate) = truncate {
std::cmp::min(encoding.len(), truncate)
......@@ -158,7 +167,8 @@ impl Validation {
));
}
let input_ids = encoding.get_ids()[..input_length].to_owned();
let ids = encoding.get_ids();
let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, Some(input_ids), input_length, max_new_tokens))
......@@ -324,7 +334,12 @@ impl Validation {
// Validate inputs
let (inputs, input_ids, input_length, max_new_tokens) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.validate_input(
request.inputs,
request.add_special_tokens,
truncate,
max_new_tokens,
)
.await?;
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
......@@ -401,6 +416,7 @@ impl Validation {
Ok(ValidGenerateRequest {
inputs,
input_ids: input_ids.map(Arc::new),
add_special_tokens: request.add_special_tokens,
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
......@@ -449,12 +465,15 @@ fn tokenizer_worker(
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
) {
// Loop over requests
while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
receiver.blocking_recv()
{
parent_span.in_scope(|| {
response_tx
.send(prepare_input(
inputs,
truncate,
add_special_tokens,
&tokenizer,
config.as_ref(),
preprocessor_config.as_ref(),
......@@ -591,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
fn prepare_input(
inputs: String,
_truncate: Option<usize>,
add_special_tokens: bool,
tokenizer: &Tokenizer,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
......@@ -628,14 +648,14 @@ fn prepare_input(
// Get the number of tokens in the input
let encoding = tokenizer
.encode(tokenizer_query, true)
.encode(tokenizer_query, add_special_tokens)
.map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
Ok((encoding, input_chunks))
}
type TokenizerRequest = (
(String, Option<usize>),
(String, bool, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span,
);
......@@ -720,6 +740,7 @@ pub struct ValidGenerateRequest {
pub input_ids: Option<Arc<Vec<u32>>>,
pub input_length: u32,
pub truncate: u32,
pub add_special_tokens: bool,
pub decoder_input_details: bool,
pub parameters: ValidParameters,
pub stopping_parameters: ValidStoppingParameters,
......@@ -826,7 +847,7 @@ mod tests {
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await
{
// Err(ValidationError::MaxNewTokens(1, 10)) => (),
......@@ -861,7 +882,7 @@ mod tests {
let max_new_tokens = 10;
match validation
.validate_input("Hello".to_string(), None, Some(max_new_tokens))
.validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
.await
{
Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
......@@ -895,6 +916,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
best_of: Some(2),
do_sample: false,
......@@ -934,6 +956,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: Some(1.0),
max_new_tokens: Some(5),
......@@ -949,6 +972,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: Some(0.99),
max_new_tokens: Some(5),
......@@ -964,6 +988,7 @@ mod tests {
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_p: None,
max_new_tokens: Some(5),
......@@ -1002,6 +1027,7 @@ mod tests {
match validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(5),
max_new_tokens: Some(5),
......@@ -1017,6 +1043,7 @@ mod tests {
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(4),
max_new_tokens: Some(5),
......@@ -1029,6 +1056,7 @@ mod tests {
validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: Some(0),
max_new_tokens: Some(5),
......@@ -1041,6 +1069,7 @@ mod tests {
let valid_request = validation
.validate(GenerateRequest {
inputs: "Hello".to_string(),
add_special_tokens: true,
parameters: GenerateParameters {
top_n_tokens: None,
max_new_tokens: Some(5),
......@@ -1089,6 +1118,7 @@ mod tests {
let chunks = match validation
.tokenize(
format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
true,
None,
)
.await
......@@ -1148,6 +1178,7 @@ mod tests {
"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
PIXEL_GIF, PIXEL_GIF
),
true,
None,
)
.await
......
[toolchain]
# Released on: June 13, 2024
# https://releases.rs/docs/1.79.0/
channel = "1.79.0"
channel = "1.80.0"
components = ["rustfmt", "clippy"]
......@@ -7,6 +7,7 @@ include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer
unit-tests:
pytest -s -vv -m "not private" tests
......
install-flashinfer:
pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4
import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"
@pytest.fixture
def default_pb_parameters():
......
......@@ -9,26 +9,46 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(self, input_lengths):
def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
......@@ -39,6 +59,11 @@ else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max):
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))
......@@ -222,18 +222,15 @@ if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
assert window_size_left == -1, "Windowing is not supported with flash infer"
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
......@@ -244,18 +241,17 @@ if ATTENTION == "flashinfer":
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)
elif V2:
def attention(
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
......@@ -266,17 +262,17 @@ elif V2:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
max_s,
max_s,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
......
......@@ -497,15 +497,14 @@ def get_model(
else -1
)
should_use_sliding_window = (
sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING
use_sliding_window = sliding_window is not None and sliding_window != -1
needs_sliding_window = (
max_input_tokens is not None and max_input_tokens > sliding_window
)
if should_use_sliding_window:
if max_input_tokens is not None and max_input_tokens > sliding_window:
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:
raise ValueError(
f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})."
)
if model_type == DEEPSEEK_V2:
if FLASH_ATTENTION:
......
......@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
......@@ -264,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -296,12 +297,10 @@ class FlashCohereAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -313,7 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -388,7 +387,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -402,7 +401,7 @@ class FlashCohereLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -454,7 +453,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
......@@ -477,7 +476,7 @@ class FlashCohereModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -518,7 +517,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -531,7 +530,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,
......@@ -309,7 +310,7 @@ class DbrxAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
......@@ -335,12 +336,10 @@ class DbrxAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -352,7 +351,7 @@ class DbrxAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -389,7 +388,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)
......@@ -403,7 +402,7 @@ class DbrxNormAttentionNorm(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -622,7 +621,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
):
# Self Attention
......@@ -635,7 +634,7 @@ class DbrxLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -679,7 +678,7 @@ class DbrxModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
......@@ -701,7 +700,7 @@ class DbrxModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -734,7 +733,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -747,7 +746,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
......@@ -29,8 +29,8 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers.attention.common import Seqlen
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
......@@ -298,7 +298,7 @@ class DeepseekV2Attention(torch.nn.Module):
kv_cache: Tuple[torch.Tensor, torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
):
if self.q_lora_rank is None:
......@@ -363,12 +363,10 @@ class DeepseekV2Attention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
key,
value,
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
......@@ -380,7 +378,7 @@ class DeepseekV2Attention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
......@@ -666,7 +664,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
......@@ -680,7 +678,7 @@ class DeepseekV2Layer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -729,7 +727,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
......@@ -751,7 +749,7 @@ class DeepseekV2Model(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
......@@ -781,7 +779,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
......@@ -794,7 +792,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
)
if lm_head_indices is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment