Unverified Commit 610bb1f9 authored by OlivierDehaene, committed by GitHub

feat(benchmark): tui based benchmarking tool (#149)

parent 55106ec4
......@@ -5,6 +5,9 @@ members = [
"router/grpc-metadata",
"launcher"
]
exclude = [
"benchmark"
]
[profile.release]
debug = 1
......
......@@ -7,6 +7,9 @@ install-router:
install-launcher:
cd launcher && cargo install --path .
install-benchmark:
cd benchmark && cargo install --path .
install: install-server install-router install-launcher
server-dev:
......
target
\ No newline at end of file
[package]
name = "text-generation-benchmark"
version = "0.1.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Benchmarking tool"
[profile.release]
debug = 1
incremental = true
lto = "off"
panic = "abort"
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-benchmark"
path = "src/main.rs"
[dependencies]
average = "0.13"
clap = { version = "4.1.4", features = ["derive", "env"] }
crossterm = "0.26"
float-ord = "0.3.2"
serde = {version = "1.0.142", features = ["derive"]}
serde_json = "1.0"
text-generation-client = { path = "../router/client" }
thiserror = "1.0.38"
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]}
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
<div align="center">
# Text Generation Inference benchmarking tool
![benchmark](../assets/benchmark.png)
</div>
A lightweight benchmarking tool inspired by [oha](https://github.com/hatoo/oha)
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
## Install
```shell
make install-benchmark
```
## Run
First, start `text-generation-inference`:
```shell
text-generation-launcher --model-id bigscience/bloom-560m
```
Then run the benchmarking tool:
```shell
text-generation-benchmark --tokenizer-name bigscience/bloom-560m
```
\ No newline at end of file
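All of the benchmark's parameters can also be set on the command line (or via environment variables); see `benchmark/src/main.rs` for the defaults. As a rough sketch, a run sweeping a few batch sizes over longer prompts could look like the following (the sequence and decode lengths below are only illustrative; when `--batch-size` is omitted it defaults to 1, 2, 4, 8, 16, 32):
```shell
text-generation-benchmark \
    --tokenizer-name bigscience/bloom-560m \
    --batch-size 1 --batch-size 8 --batch-size 32 \
    --sequence-length 512 \
    --decode-length 64 \
    --runs 10 \
    --warmups 1
```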
[toolchain]
channel = "1.67.0"
components = ["rustfmt", "clippy"]
\ No newline at end of file
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
use crossterm::event;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, mpsc};
/// Events
#[derive(Debug)]
pub(crate) enum Event {
/// Terminal tick.
Tick,
/// Key press.
Key(event::KeyEvent),
/// Terminal resize.
Resize(u16, u16),
}
pub(crate) async fn terminal_event_task(
fps: u32,
event_sender: mpsc::Sender<Event>,
mut shutdown_receiver: broadcast::Receiver<()>,
_shutdown_guard_sender: mpsc::Sender<()>,
) {
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
_ = event_loop(fps, event_sender) => {
},
_ = shutdown_receiver.recv() => {}
}
}
/// Main event loop
async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {
// Frame budget
let per_frame = Duration::from_secs(1) / fps;
// When was last frame executed
let mut last_frame = Instant::now();
loop {
// Sleep to avoid blocking the thread for too long
if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) {
tokio::time::sleep(sleep).await;
}
// Get crossterm event and send a new one over the channel
if event::poll(Duration::from_secs(0)).expect("no events available") {
match event::read().expect("unable to read event") {
event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),
event::Event::Resize(w, h) => {
event_sender.send(Event::Resize(w, h)).await.unwrap_or(())
}
_ => (),
}
}
// Frame budget exceeded
if last_frame.elapsed() >= per_frame {
// Send tick
event_sender.send(Event::Tick).await.unwrap_or(());
// Reset last_frame time
last_frame = Instant::now();
}
}
}
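Both `terminal_event_task` above and `generation_task` below rely on the same shutdown mechanism: a `broadcast` channel asks the tasks to stop, and each task holds a clone of an mpsc "guard" sender that is only dropped when the task returns, so the caller can wait for all tasks to exit by awaiting `recv()` on the guard receiver. A minimal, self-contained sketch of that pattern (illustrative names, not part of the benchmark code):

```rust
use tokio::sync::{broadcast, mpsc};

async fn worker(
    mut shutdown_receiver: broadcast::Receiver<()>,
    _shutdown_guard_sender: mpsc::Sender<()>,
) {
    // Wait for a shutdown message; the guard sender is dropped automatically
    // when this function returns
    let _ = shutdown_receiver.recv().await;
}

fn main() {
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            let (shutdown_sender, _) = broadcast::channel::<()>(1);
            let (guard_sender, mut guard_receiver) = mpsc::channel::<()>(1);

            // Spawn tasks, each holding a clone of the guard sender
            tokio::spawn(worker(shutdown_sender.subscribe(), guard_sender.clone()));
            tokio::spawn(worker(shutdown_sender.subscribe(), guard_sender.clone()));

            // Drop our own clone so only the spawned tasks keep the guard channel open
            drop(guard_sender);

            // Ask the tasks to shut down, then wait until every guard has been
            // dropped: `recv()` returns `None` once all senders are gone
            let _ = shutdown_sender.send(());
            let _ = guard_receiver.recv().await;
        });
}
```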
use std::time::{Duration, Instant};
use text_generation_client::{
Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
#[derive(Debug, Clone)]
pub(crate) struct Prefill {
pub(crate) latency: Duration,
pub(crate) throughput: f64,
}
#[derive(Debug, Clone)]
pub(crate) struct Decode {
pub(crate) latency: Duration,
pub(crate) token_latency: Duration,
pub(crate) throughput: f64,
}
#[derive(Debug)]
pub(crate) enum Message {
Warmup,
Prefill(Prefill),
Decode(Decode),
EndRun,
EndBatch,
}
/// Benchmarking task
#[allow(clippy::too_many_arguments)]
pub(crate) async fn generation_task(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
n_runs: usize,
warmups: usize,
client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
mut shutdown_receiver: broadcast::Receiver<()>,
_shutdown_guard_sender: mpsc::Sender<()>,
) {
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
},
_ = shutdown_receiver.recv() => {}
}
}
/// Benchmark prefill/decode
#[allow(clippy::too_many_arguments)]
async fn generate_runs(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
n_runs: usize,
warmups: usize,
mut client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
// Create a dummy sequence
let sequence = create_sequence(sequence_length, tokenizer);
for b in batch_size {
// Warmups on batch size
for _ in 0..warmups {
let (_, decode_batch) =
prefill(sequence.clone(), b, decode_length, &mut client).await?;
let _ = decode(decode_batch, &mut client).await?;
// Send warmup message
run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
}
for _ in 0..n_runs {
let (prefill, decode_batch) =
prefill(sequence.clone(), b, decode_length, &mut client).await?;
// Send prefill message
run_sender
.send(Ok(Message::Prefill(prefill)))
.await
.unwrap_or(());
let decode = decode(decode_batch, &mut client).await?;
// Send decode message
run_sender
.send(Ok(Message::Decode(decode)))
.await
.unwrap_or(());
// Send run ended message
run_sender.send(Ok(Message::EndRun)).await.unwrap_or(());
}
// Batch ended
run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
}
Ok(())
}
/// Run a prefill step
async fn prefill(
sequence: String,
batch_size: u32,
decode_length: u32,
client: &mut ShardedClient,
) -> Result<(Prefill, Batch), ClientError> {
// Create requests
let requests = (0..batch_size)
.map(|id| Request {
id: id.into(),
inputs: sequence.clone(),
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
watermark: false,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length,
stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if an eos token is generated
}),
})
.collect();
let batch = Batch {
id: 0,
requests,
size: batch_size,
};
// Run prefill
let start_time = Instant::now();
let (_, decode_batch) = client.prefill(batch.clone()).await?;
// Get latency
let latency = start_time.elapsed();
// Compute throughput from latency and batch size
let throughput = batch_size as f64 / latency.as_secs_f64();
// Decode batch cannot be empty
let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
let step = Prefill {
latency,
throughput,
};
Ok((step, decode_batch))
}
/// Run a full decode
async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
let mut decode_length = 0;
let batch_size = batch.size;
let start_time = Instant::now();
// Full decode over decode length
let mut next_batch = Some(batch);
while let Some(batch) = next_batch {
let result = client.decode(vec![batch]).await?;
next_batch = result.1;
decode_length += 1;
}
// Get latency
let latency = start_time.elapsed();
let token_latency = latency / decode_length;
// Compute throughput from latency, batch size and decode length
let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();
let step = Decode {
latency,
token_latency,
throughput,
};
Ok(step)
}
/// Create a dummy sequence of the correct length
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
// Repeat lorem ipsum to cover sequence length
let string_sequence =
LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
// Encode sequence
let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
// Truncate to sequence_length
encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
// Decode
tokenizer
.decode(Vec::from(encoding.get_ids()), false)
.unwrap()
}
mod app;
mod event;
mod generation;
mod utils;
use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
use tui::Terminal;
/// Run benchmarking app
#[allow(clippy::too_many_arguments)]
pub async fn run(
tokenizer_name: String,
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
n_runs: usize,
warmups: usize,
client: ShardedClient,
) -> Result<(), crossterm::ErrorKind> {
// Initialize terminal properties
crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(crossterm::cursor::Hide)?;
// Initialize terminal
let mut terminal = {
let backend = CrosstermBackend::new(io::stdout());
Terminal::new(backend)?
};
// Create message channel between generation_task and app
let (run_sender, run_receiver) = mpsc::channel(8);
// Crossterm event channel
let (event_sender, mut event_receiver) = mpsc::channel(8);
// Shutdown channel to terminate tasks
let (shutdown_sender, _) = broadcast::channel(1);
// Channel to check if tasks terminated
let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
// Create generation task
tokio::spawn(generation::generation_task(
tokenizer,
batch_size.clone(),
sequence_length,
decode_length,
n_runs,
warmups,
client,
run_sender,
shutdown_sender.subscribe(),
shutdown_guard_sender.clone(),
));
// Create event task
tokio::spawn(event::terminal_event_task(
250,
event_sender,
shutdown_sender.subscribe(),
shutdown_guard_sender.clone(),
));
// Drop our end of shutdown sender
drop(shutdown_guard_sender);
// Create App
let mut app = App::new(
run_receiver,
tokenizer_name,
sequence_length,
decode_length,
n_runs,
batch_size,
);
while app.running {
// Draw frame
terminal.draw(|frame| app.render(frame))?;
// Await a new event from event handling task
match event_receiver.recv().await {
None => break,
// Update app state
Some(event) => match event {
Event::Tick => app.tick(),
Event::Key(key_event) => app.handle_key_event(key_event),
_ => {}
},
}
}
// Ask tasks to shutdown
let _ = shutdown_sender.send(());
// Wait for tasks to shutdown
let _ = shutdown_guard_receiver.recv().await;
// Revert terminal to original view
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?;
Ok(())
}
/// Text Generation Inference benchmarking tool
///
/// Inspired by the great Oha app: https://github.com/hatoo/oha
/// and: https://github.com/orhun/rust-tui-template
use clap::Parser;
use std::path::Path;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(short, long, env)]
tokenizer_name: String,
#[clap(short, long)]
batch_size: Option<Vec<u32>>,
#[clap(default_value = "10", short, long, env)]
sequence_length: u32,
#[clap(default_value = "8", short, long, env)]
decode_length: u32,
#[clap(default_value = "10", short, long, env)]
runs: usize,
#[clap(default_value = "1", short, long, env)]
warmups: usize,
#[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
master_shard_uds_path: String,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
tokenizer_name,
batch_size,
sequence_length,
decode_length,
runs,
warmups,
master_shard_uds_path,
} = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
init_logging();
// Tokenizer instance
// This will only be used to validate payloads
tracing::info!("Loading tokenizer");
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
tracing::info!("Found local tokenizer");
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
tracing::info!("Downloading tokenizer");
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
};
tracing::info!("Tokenizer loaded");
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
// Instantiate sharded client from the master unix socket
tracing::info!("Connect to model server");
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
// Run app
text_generation_benchmark::run(
tokenizer_name,
tokenizer,
batch_size,
sequence_length,
decode_length,
runs,
warmups,
sharded_client,
)
.await
.unwrap();
});
Ok(())
}
/// Init logging using LOG_LEVEL
fn init_logging() {
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(fmt_layer)
.init();
}
/// MIT License
//
// Copyright (c) 2020 hatoo
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
use std::collections::BTreeMap;
/// Bucket `values` into `bins` evenly spaced bins (requires `bins >= 2`)
pub(crate) fn histogram(values: &[f64], bins: usize) -> Vec<(f64, usize)> {
assert!(bins >= 2);
let mut bucket: Vec<usize> = vec![0; bins];
let min = values.iter().collect::<average::Min>().min();
let max = values.iter().collect::<average::Max>().max();
let step = (max - min) / (bins - 1) as f64;
for &v in values {
let i = std::cmp::min(((v - min) / step).ceil() as usize, bins - 1);
bucket[i] += 1;
}
bucket
.into_iter()
.enumerate()
.map(|(i, v)| (min + step * i as f64, v))
.collect()
}
/// Compute the requested percentiles of `values` by index
/// (expects `values` to be sorted in ascending order)
pub(crate) fn percentiles(values: &[f64], percents: &[i32]) -> BTreeMap<String, f64> {
    percents
.iter()
.map(|&p| {
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
(format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
})
.collect()
}
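A short sketch of how these helpers behave, written as a unit test that could sit at the bottom of `utils.rs` (the values are hypothetical; `percentiles` looks values up by index, so it expects a sorted slice):

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn histogram_and_percentiles_smoke() {
        // Ten sorted "latency" samples
        let values: Vec<f64> = (1..=10).map(f64::from).collect();

        // Five buckets spanning [min, max]; every sample lands in exactly one
        let hist = histogram(&values, 5);
        assert_eq!(hist.len(), 5);
        assert_eq!(hist.iter().map(|(_, count)| count).sum::<usize>(), values.len());

        // p50 maps to index (0.50 * 10) = 5, i.e. the sixth value
        let p = percentiles(&values, &[50, 90]);
        assert_eq!(p["p50"], 6.0);
        assert_eq!(p["p90"], 10.0);
    }
}
```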
......@@ -53,6 +53,9 @@ message StoppingCriteriaParameters {
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {
......
......@@ -37,7 +37,7 @@ struct Args {
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-0", long, env)]
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
......@@ -76,6 +76,8 @@ fn main() -> Result<(), std::io::Error> {
panic!("validation_workers must be > 0");
}
init_logging(otlp_endpoint, json_output);
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
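The comment above describes a conversion whose code falls outside this hunk; a hypothetical sketch of what such a mapping could look like (assuming `tower-http` with its `cors` feature and the `http` crate; the router's actual implementation may differ):

```rust
use http::HeaderValue;
use tower_http::cors::AllowOrigin;

// Illustrative helper: map an optional list of origin strings into an AllowOrigin
fn to_allow_origin(cors_allow_origin: Option<Vec<String>>) -> Option<AllowOrigin> {
    // Map inside the Option, parse each String into a HeaderValue,
    // then collect the values into an AllowOrigin list
    cors_allow_origin.map(|origins| {
        AllowOrigin::list(
            origins
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().expect("invalid origin")),
        )
    })
}
```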
......@@ -89,17 +91,21 @@ fn main() -> Result<(), std::io::Error> {
// Tokenizer instance
// This will only be used to validate payloads
tracing::info!("Loading tokenizer");
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
tracing::info!("Found local tokenizer");
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
tracing::info!("Downloading tokenizer");
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
};
tracing::info!("Tokenizer loaded");
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
......@@ -107,8 +113,6 @@ fn main() -> Result<(), std::io::Error> {
.build()
.unwrap()
.block_on(async {
init_logging(otlp_endpoint, json_output);
// Get pipeline tag
let model_info = reqwest::get(format!(
"https://huggingface.co/api/models/{tokenizer_name}"
......
......@@ -237,6 +237,7 @@ mod tests {
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
max_new_tokens: 0,
stop_sequences: vec![],
},
......
......@@ -315,6 +315,7 @@ fn validate(
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
ignore_eos_token: false,
};
metrics::histogram!("tgi_request_input_length", input_length as f64);
......
......@@ -18,7 +18,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
......
......@@ -123,20 +123,22 @@ class StoppingCriteria:
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
self.ignore_eos_token = ignore_eos_token
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
if not self.ignore_eos_token and last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
......@@ -156,5 +158,8 @@ class StoppingCriteria:
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
tokenizer.eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,
)