Commit 366dfe82 authored by jixx's avatar jixx
Browse files

init

parents
Pipeline #1939 canceled with stages
use std::time::{Duration, Instant};
use text_generation_client::v3::{
Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
};
use text_generation_client::{Chunk, ClientError, Input};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
#[derive(Debug, Clone)]
pub(crate) struct Prefill {
pub(crate) latency: Duration,
pub(crate) throughput: f64,
}
#[derive(Debug, Clone)]
pub(crate) struct Decode {
pub(crate) latency: Duration,
pub(crate) token_latency: Duration,
pub(crate) throughput: f64,
}
#[derive(Debug)]
pub(crate) enum Message {
Warmup,
Prefill(Prefill),
Decode(Decode),
EndRun,
EndBatch,
}
/// Benchmarking task
#[allow(clippy::too_many_arguments)]
pub(crate) async fn generation_task(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
mut shutdown_receiver: broadcast::Receiver<()>,
_shutdown_guard_sender: mpsc::Sender<()>,
) {
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
},
_ = shutdown_receiver.recv() => {}
}
}
/// Benchmark prefill/decode
#[allow(clippy::too_many_arguments)]
async fn generate_runs(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
parameters: NextTokenChooserParameters,
mut client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
// Create a dummy sequence
let sequence = create_sequence(sequence_length, tokenizer);
for b in batch_size {
// Warmups on batch size
for _ in 0..warmups {
let (_, decode_batch) = prefill(
sequence.clone(),
sequence_length,
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
let _ = decode(decode_batch, &mut client).await?;
// Send warmup message
run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
}
for _ in 0..n_runs {
let (prefill, decode_batch) = prefill(
sequence.clone(),
sequence_length,
b,
decode_length,
parameters.clone(),
top_n_tokens,
&mut client,
)
.await?;
// Send prefill message
run_sender
.send(Ok(Message::Prefill(prefill)))
.await
.unwrap_or(());
let decode = decode(decode_batch, &mut client).await?;
// Send decode message
run_sender
.send(Ok(Message::Decode(decode)))
.await
.unwrap_or(());
// Send run ended message
run_sender.send(Ok(Message::EndRun)).await.unwrap_or(());
}
// Batch ended
run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
}
Ok(())
}
// Run a prefill step
async fn prefill(
sequence: String,
sequence_length: u32,
batch_size: u32,
decode_length: u32,
parameters: NextTokenChooserParameters,
top_n_tokens: Option<u32>,
client: &mut ShardedClient,
) -> Result<(Prefill, CachedBatch), ClientError> {
// Create requests
let requests = (0..batch_size)
.map(|id| Request {
id: id.into(),
prefill_logprobs: false,
input_chunks: Some(Input {
chunks: vec![Chunk::Text(sequence.clone()).into()],
}),
inputs: sequence.clone(),
truncate: sequence_length,
parameters: Some(parameters.clone()),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length,
stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if a eos token is generated
}),
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
adapter_id: None,
})
.collect();
let batch = Batch {
id: 0,
requests,
size: batch_size,
max_tokens: batch_size * (sequence_length + decode_length),
max_blocks: 0,
};
// Run prefill
let start_time = Instant::now();
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
// Get latency
let latency = start_time.elapsed();
// Compute throughput from latency and batch size
let throughput = batch_size as f64 / latency.as_secs_f64();
// Decode batch cannot be empty
let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
let step = Prefill {
latency,
throughput,
};
Ok((step, decode_batch))
}
/// Run a full decode
async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
let mut decode_length = 0;
let batch_size = batch.size;
let start_time = Instant::now();
// Full decode over decode length
let mut next_batch = Some(batch);
while let Some(batch) = next_batch {
let result = client.decode(vec![batch]).await?;
next_batch = result.1;
decode_length += 1;
}
// Get latency
let latency = start_time.elapsed();
let token_latency = latency / decode_length;
// Compute throughput from latency, batch size and decode length
let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();
let step = Decode {
latency,
token_latency,
throughput,
};
Ok(step)
}
/// Create a dummy sequence of the correct length
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
// Repeat lorem ipsum to cover sequence length
let string_sequence =
LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
// Encode sequence
let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
// Truncate to sequence_length
encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
// Decode
tokenizer.decode(encoding.get_ids(), false).unwrap()
}
mod app;
mod event;
mod generation;
mod table;
mod utils;
use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
use tui::Terminal;
/// Run benchmarking app
#[allow(clippy::too_many_arguments)]
pub async fn run(
tokenizer_name: String,
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
top_k: Option<u32>,
top_p: Option<f32>,
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
watermark: bool,
do_sample: bool,
client: ShardedClient,
) -> Result<(), std::io::Error> {
let parameters = NextTokenChooserParameters {
temperature: temperature.unwrap_or(1.0),
top_k: top_k.unwrap_or(0),
top_p: top_p.unwrap_or(1.0),
typical_p: typical_p.unwrap_or(1.0),
do_sample,
seed: 0,
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
};
// Initialize terminal properties
crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(crossterm::cursor::Hide)?;
// Initialize terminal
let mut terminal = {
let backend = CrosstermBackend::new(io::stdout());
Terminal::new(backend)?
};
// Create message channel between generation_task and app
let (run_sender, run_receiver) = mpsc::channel(8);
// Crossterm event channel
let (event_sender, mut event_receiver) = mpsc::channel(8);
// Shutdown channel to terminate tasks
let (shutdown_sender, _) = broadcast::channel(1);
// Channel to check if tasks terminated
let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
// Create generation task
tokio::spawn(generation::generation_task(
tokenizer,
batch_size.clone(),
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
parameters,
client,
run_sender,
shutdown_sender.subscribe(),
shutdown_guard_sender.clone(),
));
// Create event task
tokio::spawn(event::terminal_event_task(
250,
event_sender,
shutdown_sender.subscribe(),
shutdown_guard_sender.clone(),
));
// Drop our end of shutdown sender
drop(shutdown_guard_sender);
// Create App
let mut app = App::new(
run_receiver,
tokenizer_name.clone(),
sequence_length,
decode_length,
n_runs,
batch_size,
);
while app.running {
// Draw frame
terminal.draw(|frame| app.render(frame))?;
// Await a new event from event handling task
match event_receiver.recv().await {
None => break,
// Update app state
Some(event) => match event {
Event::Tick => app.tick(),
Event::Key(key_event) => app.handle_key_event(key_event),
_ => {}
},
}
}
// Ask tasks to shutdown
let _ = shutdown_sender.send(());
// Wait for tasks to shutdown
let _ = shutdown_guard_receiver.recv().await;
// Revert terminal to original view
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?;
let parameters_table = table::parameters_table(
tokenizer_name,
sequence_length,
decode_length,
top_n_tokens,
n_runs,
warmups,
temperature,
top_k,
top_p,
typical_p,
repetition_penalty,
frequency_penalty,
watermark,
do_sample,
);
println!("\n{parameters_table}\n");
let latency_table = table::latency_table(&app.data);
println!("\n{latency_table}\n");
let throughput_table = table::throughput_table(&app.data);
println!("\n{throughput_table}\n");
Ok(())
}
/// Text Generation Inference benchmarking tool
///
/// Inspired by the great Oha app: https://github.com/hatoo/oha
/// and: https://github.com/orhun/rust-tui-template
use clap::Parser;
use std::path::Path;
use text_generation_client::v3::ShardedClient;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
/// The name of the tokenizer (as in model_id on the huggingface hub, or local path).
#[clap(short, long, env)]
tokenizer_name: String,
/// The revision to use for the tokenizer if on the hub.
#[clap(default_value = "main", long, env)]
revision: String,
/// The various batch sizes to benchmark for, the idea is to get enough
/// batching to start seeing increased latency, this usually means you're
/// moving from memory bound (usual as BS=1) to compute bound, and this is
/// a sweet spot for the maximum batch size for the model under test
#[clap(short, long)]
batch_size: Option<Vec<u32>>,
/// This is the initial prompt sent to the text-generation-server length
/// in token. Longer prompt will slow down the benchmark. Usually the
/// latency grows somewhat linearly with this for the prefill step.
///
/// Most importantly, the prefill step is usually not the one dominating
/// your runtime, so it's ok to keep it short.
#[clap(default_value = "10", short, long, env)]
sequence_length: u32,
/// This is how many tokens will be generated by the server and averaged out
/// to give the `decode` latency. This is the *critical* number you want to optimize for
/// LLM spend most of their time doing decoding.
///
/// Decode latency is usually quite stable.
#[clap(default_value = "8", short, long, env)]
decode_length: u32,
///How many runs should we average from
#[clap(default_value = "10", short, long, env)]
runs: usize,
/// Number of warmup cycles
#[clap(default_value = "1", short, long, env)]
warmups: usize,
/// The location of the grpc socket. This benchmark tool bypasses the router
/// completely and directly talks to the gRPC processes
#[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
master_shard_uds_path: String,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
temperature: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_k: Option<u32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_p: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
typical_p: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
repetition_penalty: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
frequency_penalty: Option<f32>,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
watermark: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
do_sample: bool,
/// Generation parameter in case you want to specifically test/debug particular
/// decoding strategies, for full doc refer to the `text-generation-server`
#[clap(long, env)]
top_n_tokens: Option<u32>,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
init_logging();
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
tokenizer_name,
revision,
batch_size,
sequence_length,
decode_length,
runs,
warmups,
temperature,
top_k,
top_p,
typical_p,
repetition_penalty,
frequency_penalty,
watermark,
do_sample,
master_shard_uds_path,
top_n_tokens,
} = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
// Tokenizer instance
// This will only be used to validate payloads
tracing::info!("Loading tokenizer");
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
tracing::info!("Found local tokenizer");
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
tracing::info!("Downloading tokenizer");
// Parse Huggingface hub token
let auth_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
let params = FromPretrainedParameters {
revision,
auth_token,
..Default::default()
};
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
};
tracing::info!("Tokenizer loaded");
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
// Instantiate sharded client from the master unix socket
tracing::info!("Connect to model server");
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
// Run app
text_generation_benchmark::run(
tokenizer_name,
tokenizer,
batch_size,
sequence_length,
decode_length,
top_n_tokens,
runs,
warmups,
temperature,
top_k,
top_p,
typical_p,
repetition_penalty,
frequency_penalty,
watermark,
do_sample,
sharded_client,
)
.await
.unwrap();
});
Ok(())
}
/// Init logging using LOG_LEVEL
fn init_logging() {
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(fmt_layer)
.init();
}
use crate::app::Data;
use tabled::settings::Merge;
use tabled::{builder::Builder, settings::Style, Table};
#[allow(clippy::too_many_arguments)]
pub(crate) fn parameters_table(
tokenizer_name: String,
sequence_length: u32,
decode_length: u32,
top_n_tokens: Option<u32>,
n_runs: usize,
warmups: usize,
temperature: Option<f32>,
top_k: Option<u32>,
top_p: Option<f32>,
typical_p: Option<f32>,
repetition_penalty: Option<f32>,
frequency_penalty: Option<f32>,
watermark: bool,
do_sample: bool,
) -> Table {
let mut builder = Builder::default();
builder.set_header(["Parameter", "Value"]);
builder.push_record(["Model", &tokenizer_name]);
builder.push_record(["Sequence Length", &sequence_length.to_string()]);
builder.push_record(["Decode Length", &decode_length.to_string()]);
builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
builder.push_record(["N Runs", &n_runs.to_string()]);
builder.push_record(["Warmups", &warmups.to_string()]);
builder.push_record(["Temperature", &format!("{temperature:?}")]);
builder.push_record(["Top K", &format!("{top_k:?}")]);
builder.push_record(["Top P", &format!("{top_p:?}")]);
builder.push_record(["Typical P", &format!("{typical_p:?}")]);
builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
builder.push_record(["Watermark", &watermark.to_string()]);
builder.push_record(["Do Sample", &do_sample.to_string()]);
let mut table = builder.build();
table.with(Style::markdown());
table
}
pub(crate) fn latency_table(data: &Data) -> Table {
let mut builder = Builder::default();
builder.set_header([
"Step",
"Batch Size",
"Average",
"Lowest",
"Highest",
"p50",
"p90",
"p99",
]);
add_latencies(
&mut builder,
"Prefill",
&data.batch_size,
&data.prefill_latencies,
);
add_latencies(
&mut builder,
"Decode (token)",
&data.batch_size,
&data.decode_token_latencies,
);
add_latencies(
&mut builder,
"Decode (total)",
&data.batch_size,
&data.decode_latencies,
);
let mut table = builder.build();
table.with(Style::markdown()).with(Merge::vertical());
table
}
pub(crate) fn throughput_table(data: &Data) -> Table {
let mut builder = Builder::default();
builder.set_header(["Step", "Batch Size", "Average", "Lowest", "Highest"]);
add_throuhgputs(
&mut builder,
"Prefill",
&data.batch_size,
&data.prefill_throughputs,
);
add_throuhgputs(
&mut builder,
"Decode",
&data.batch_size,
&data.decode_throughputs,
);
let mut table = builder.build();
table.with(Style::markdown()).with(Merge::vertical());
table
}
fn add_latencies(
builder: &mut Builder,
step: &'static str,
batch_size: &[u32],
batch_latencies: &[Vec<f64>],
) {
for (i, b) in batch_size.iter().enumerate() {
let latencies = &batch_latencies[i];
let (avg, min, max) = avg_min_max(latencies);
let row = [
step,
&b.to_string(),
&format_value(avg, "ms"),
&format_value(min, "ms"),
&format_value(max, "ms"),
&format_value(px(latencies, 50), "ms"),
&format_value(px(latencies, 90), "ms"),
&format_value(px(latencies, 99), "ms"),
];
builder.push_record(row);
}
}
fn add_throuhgputs(
builder: &mut Builder,
step: &'static str,
batch_size: &[u32],
batch_throughputs: &[Vec<f64>],
) {
for (i, b) in batch_size.iter().enumerate() {
let throughputs = &batch_throughputs[i];
let (avg, min, max) = avg_min_max(throughputs);
let row = [
step,
&b.to_string(),
&format_value(avg, "tokens/secs"),
&format_value(min, "tokens/secs"),
&format_value(max, "tokens/secs"),
];
builder.push_record(row);
}
}
fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
let average = data.iter().sum::<f64>() / data.len() as f64;
let min = data
.iter()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&f64::NAN);
let max = data
.iter()
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&f64::NAN);
(average, *min, *max)
}
fn px(data: &[f64], p: u32) -> f64 {
let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
*data.get(i).unwrap_or(&f64::NAN)
}
fn format_value(value: f64, unit: &'static str) -> String {
format!("{:.2} {unit}", value)
}
/// MIT License
//
// Copyright (c) 2020 hatoo
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
use std::collections::BTreeMap;
pub(crate) fn histogram(values: &[f64], bins: usize) -> Vec<(f64, usize)> {
assert!(bins >= 2);
let mut bucket: Vec<usize> = vec![0; bins];
let min = values.iter().collect::<average::Min>().min();
let max = values.iter().collect::<average::Max>().max();
let step = (max - min) / (bins - 1) as f64;
for &v in values {
let i = std::cmp::min(((v - min) / step).ceil() as usize, bins - 1);
bucket[i] += 1;
}
bucket
.into_iter()
.enumerate()
.map(|(i, v)| (min + step * i as f64, v))
.collect()
}
pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f64> {
pecents
.iter()
.map(|&p| {
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
(format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
})
.collect()
}
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
transformers
safetensors
unit-tests:
python -m pytest --cov=text_generation tests
install:
pip install pip --upgrade
pip install -e .
# Text Generation
The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
`text-generation-inference` instance running on
[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.
## Get Started
### Install
```shell
pip install text-generation
```
### Inference API Usage
```python
from text_generation import InferenceAPIClient
client = InferenceAPIClient("bigscience/bloomz")
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import InferenceAPIAsyncClient
client = InferenceAPIAsyncClient("bigscience/bloomz")
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
Check all currently deployed models on the Huggingface Inference API with `Text Generation` support:
```python
from text_generation.inference_api import deployed_models
print(deployed_models())
```
### Hugging Face Inference Endpoint usage
```python
from text_generation import Client
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = Client(endpoint_url)
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import AsyncClient
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
if not response.token.special:
text += response.token.text
print(text)
# ' Rayleigh scattering'
```
### Types
```python
# enum for grammar type
class GrammarType(Enum):
Json = "json"
Regex = "regex"
# Grammar type and value
class Grammar:
# Grammar type
type: GrammarType
# Grammar value
value: Union[str, dict]
class Parameters:
# Activate logits sampling
do_sample: bool
# Maximum number of generated tokens
max_new_tokens: int
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float]
# The parameter for frequency penalty. 1.0 means no penalty
# Penalize new tokens based on their existing frequency in the text so far,
# decreasing the model's likelihood to repeat the same line verbatim.
frequency_penalty: Optional[float]
# Whether to prepend the prompt to the generated text
return_full_text: bool
# Stop generating tokens if a member of `stop_sequences` is generated
stop: List[str]
# Random sampling seed
seed: Optional[int]
# The value used to module the logits distribution.
temperature: Optional[float]
# The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_k: Optional[int]
# If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
# higher are kept for generation.
top_p: Optional[float]
# truncate inputs tokens to the given size
truncate: Optional[int]
# Typical Decoding mass
# See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
typical_p: Optional[float]
# Generate best_of sequences and return the one if the highest token logprobs
best_of: Optional[int]
# Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
watermark: bool
# Get generation details
details: bool
# Get decoder input token logprobs and ids
decoder_input_details: bool
# Return the N most likely tokens at each step
top_n_tokens: Optional[int]
# grammar to use for generation
grammar: Optional[Grammar]
class Request:
# Prompt
inputs: str
# Generation parameters
parameters: Optional[Parameters]
# Whether to stream output tokens
stream: bool
# Decoder input tokens
class InputToken:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
# Optional since the logprob of the first token cannot be computed
logprob: Optional[float]
# Generated tokens
class Token:
# Token ID from the model tokenizer
id: int
# Token text
text: str
# Logprob
logprob: Optional[float]
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
# Generation finish reason
class FinishReason(Enum):
# number of generated tokens == `max_new_tokens`
Length = "length"
# the model generated its end of sequence token
EndOfSequenceToken = "eos_token"
# the model generated a text included in `stop_sequences`
StopSequence = "stop_sequence"
# Additional sequences when using the `best_of` parameter
class BestOfSequence:
# Generated text
generated_text: str
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# `generate` details
class Details:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# Decoder input tokens, empty if decoder_input_details is False
prefill: List[InputToken]
# Generated tokens
tokens: List[Token]
# Most likely tokens
top_tokens: Optional[List[List[Token]]]
# Additional sequences when using the `best_of` parameter
best_of_sequences: Optional[List[BestOfSequence]]
# `generate` return value
class Response:
# Generated text
generated_text: str
# Generation details
details: Details
# `generate_stream` details
class StreamDetails:
# Generation finish reason
finish_reason: FinishReason
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
seed: Optional[int]
# `generate_stream` return value
class StreamResponse:
# Generated token
token: Token
# Most likely tokens
top_tokens: Optional[List[Token]]
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str]
# Generation details
# Only available when the generation is finished
details: Optional[StreamDetails]
# Inference API currently deployed model
class DeployedModel:
model_id: str
sha: str
```
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.7.0"
DEPRECATION_WARNING = (
"`text_generation` clients are deprecated and will be removed in the near future. "
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
)
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment