Commit f16f2f5a authored by Olivier Dehaene, committed by OlivierDehaene

v0.1.0

parent 92c1ecd0
/// Text Generation Inference Webserver
mod batcher;
mod db;
pub mod server;
mod validation;

use batcher::Batcher;
use db::{Db, Entry};
use serde::{Deserialize, Serialize};
use validation::Validation;
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateParameters {
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_top_k")]
    pub top_k: i32,
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    #[serde(default = "default_do_sample")]
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
    pub max_new_tokens: u32,
}

fn default_temperature() -> f32 {
    1.0
}

fn default_top_k() -> i32 {
    0
}

fn default_top_p() -> f32 {
    1.0
}

fn default_do_sample() -> bool {
    false
}

fn default_max_new_tokens() -> u32 {
    20
}

fn default_parameters() -> GenerateParameters {
    GenerateParameters {
        temperature: default_temperature(),
        top_k: default_top_k(),
        top_p: default_top_p(),
        do_sample: default_do_sample(),
        max_new_tokens: default_max_new_tokens(),
    }
}

#[derive(Clone, Debug, Deserialize)]
pub(crate) struct GenerateRequest {
    pub inputs: String,
    #[serde(default = "default_parameters")]
    pub parameters: GenerateParameters,
}

#[derive(Serialize)]
pub(crate) struct GeneratedText {
    pub generated_text: String,
}

pub(crate) type GenerateResponse = Vec<GeneratedText>;
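For illustration only (host, port, and prompt are placeholders): the /generate route accepts a JSON body matching GenerateRequest, every field under parameters falls back to its #[serde(default = ...)] value when omitted, and GenerateResponse serializes as an array of objects with a single generated_text field.

# Illustrative request against a locally running router (default port 3000)
curl -X POST http://localhost:3000/generate \
    -H 'Content-Type: application/json' \
    -d '{
          "inputs": "Hello, I am",
          "parameters": {"temperature": 0.9, "top_k": 50, "top_p": 0.95, "do_sample": true, "max_new_tokens": 20}
        }'
# Expected response shape: [{"generated_text": "..."}]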
/// Text Generation Inference webserver entrypoint
use bloom_inference_client::ShardedClient;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use text_generation_router::server;
use tokenizers::Tokenizer;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    #[clap(default_value = "5", long, env)]
    max_waiting_time: u64,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
}
fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_input_length,
        max_batch_size,
        max_waiting_time,
        port,
        master_shard_uds_path,
        tokenizer_name,
        validation_workers,
    } = args;
    if validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }
    let max_waiting_time = Duration::from_secs(max_waiting_time);

    // Download and instantiate tokenizer
    // This will only be used to validate payloads
    //
    // We need to download it outside of the Tokio runtime
    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
@@ -39,18 +63,32 @@ fn main() -> Result<(), std::io::Error> {
        .block_on(async {
            tracing_subscriber::fmt::init();

            // Instantiate sharded client from the master unix socket
            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");

            // Binds on localhost
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

            // Run server
            server::run(
                max_concurrent_requests,
                max_input_length,
                max_batch_size,
                max_waiting_time,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
            )
            .await;
            Ok(())
        })
}
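As a rough sketch (values are illustrative, not part of the commit), the router can be configured through the clap flags above or through the matching environment variables that clap derives from the field names:

# Hypothetical invocation; flag names follow the Args fields above
text-generation-router \
    --max-concurrent-requests 128 \
    --max-input-length 1000 \
    --max-batch-size 32 \
    --max-waiting-time 5 \
    --port 3000 \
    --master-shard-uds-path /tmp/bloom-inference-0 \
    --tokenizer-name bigscience/bloom \
    --validation-workers 2

# Equivalent configuration via environment variables (clap `env` attribute),
# e.g. MAX_BATCH_SIZE=32 PORT=3000 text-generation-router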
use crate::{
    Batcher, GenerateParameters, GenerateRequest, GenerateResponse, GeneratedText, Validation,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::routing::{get, post};
use axum::{Json, Router};
use bloom_inference_client::ShardedClient;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::sync::Semaphore;
use tokio::time::Instant;
use tracing::instrument;

// Server shared state
#[derive(Clone)]
struct ServerState {
    validation: Validation,
    batcher: Batcher,
    limit_concurrent_requests: Arc<Semaphore>,
}
/// Health check method
#[instrument(skip(state), fields(time, time_per_token))]
async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    // be a bit too slow for a health check.
    // What we should do instead is check if the gRPC channels are still healthy.

    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Model is overloaded".to_string(),
        )
    })?;

    // Send a small inference request
    state
        .batcher
        .infer(
@@ -82,23 +58,35 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, String)> {
    Ok(())
}
/// Generate method
#[instrument(skip(state), fields(time, time_per_token))]
async fn generate(
    state: Extension<ServerState>,
    req: Json<GenerateRequest>,
) -> Result<Json<GenerateResponse>, (StatusCode, String)> {
    let start = Instant::now();

    // Limit concurrent requests by acquiring a permit from the semaphore
    let _permit = state.limit_concurrent_requests.try_acquire().map_err(|_| {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Model is overloaded".to_string(),
        )
    })?;

    // Validate request
    let (input_length, validated_request) = state
        .validation
        // FIXME: can't we get rid of the cloning here??
        .validate(GenerateRequest {
            inputs: req.inputs.clone(),
            parameters: req.parameters.clone(),
        })
        .await?;

    // Inference
    let generated_text = state.batcher.infer(input_length, validated_request).await?;

    // Tracing metadata
    tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
    tracing::Span::current().record(
        "time_per_token",
@@ -106,31 +94,71 @@ async fn generate(
    );
    tracing::info!("response: {}", generated_text);

    // Send response
    let response = vec![GeneratedText { generated_text }];
    Ok(Json(response))
}
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    max_concurrent_requests: usize,
    max_input_length: usize,
    max_batch_size: usize,
    max_waiting_time: Duration,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
) {
    // Create state
    let batcher = Batcher::new(client, max_batch_size, max_waiting_time);
    let validation = Validation::new(validation_workers, tokenizer, max_input_length);
    let shared_state = ServerState {
        validation,
        batcher,
        limit_concurrent_requests: Arc::new(Semaphore::new(max_concurrent_requests)),
    };

    // Create router
    let app = Router::new()
        .route("/generate", post(generate))
        .layer(Extension(shared_state.clone()))
        .route("/health", get(health))
        .layer(Extension(shared_state.clone()));

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}
/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
}
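For context, the handler above waits for Ctrl+C and, on Unix, SIGTERM, so a deployment can stop the webserver gracefully with a plain signal (illustrative command, assuming the binary runs as text-generation-router):

# Trigger the graceful shutdown path above
kill -TERM "$(pgrep -f text-generation-router)"
# The router logs "signal received, starting graceful shutdown" and lets
# in-flight requests finish before the process exits.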
/// Payload validation logic
use crate::GenerateRequest;
use axum::http::StatusCode;
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::{
    DecoderWrapper, ModelWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
    TokenizerImpl,
};
use tokio::sync::{mpsc, oneshot};

/// Validation
#[derive(Debug, Clone)]
pub struct Validation {
    /// Channel to communicate with the background validation task
    sender: mpsc::Sender<ValidationRequest>,
}

impl Validation {
    pub(crate) fn new(workers: usize, tokenizer: Tokenizer, max_input_length: usize) -> Self {
        // Create channel
        let (validation_sender, validation_receiver) = mpsc::channel(128);

        // Launch background validation task
        tokio::spawn(validation_task(
            workers,
            tokenizer,
            max_input_length,
            validation_receiver,
        ));

        Self {
            sender: validation_sender,
        }
    }

    /// Validate a payload and get the number of tokens in the input
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<(usize, GenerateRequest), ValidationError> {
        // Create response channel
        let (sender, receiver) = oneshot::channel();
        // Send request to the background validation task
        // Unwrap is safe here
        self.sender.send((request, sender)).await.unwrap();
        // Await on response channel
        // Unwrap is safe here
        receiver.await.unwrap()
    }
}

/// Validation task
/// Load balance the validation requests between multiple validation workers
async fn validation_task(
    workers: usize,
    tokenizer: Tokenizer,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    let mut workers_senders = Vec::with_capacity(workers);

    // Create workers
    for _ in 0..workers {
        let tokenizer_clone = tokenizer.clone();
        // Create channel to communicate with worker
        let (worker_sender, worker_receiver) = mpsc::channel(workers);
        workers_senders.push(worker_sender);

        // Spawn worker
        tokio::task::spawn_blocking(move || {
            validation_worker(tokenizer_clone, max_input_length, worker_receiver)
        });
    }

    loop {
        // Load balance requests between workers
        for sender in workers_senders.iter() {
            if let Some(validation_request) = receiver.recv().await {
                sender.send(validation_request).await.unwrap();
            } else {
                return;
            }
        }
    }
}

/// Check the parameters inside the payload and get the number of tokens inside the input using
/// the tokenizer
fn validation_worker(
    tokenizer: TokenizerImpl<
        ModelWrapper,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >,
    max_input_length: usize,
    mut receiver: mpsc::Receiver<ValidationRequest>,
) {
    // Loop over requests
    while let Some((request, response_tx)) = receiver.blocking_recv() {
        if request.parameters.temperature <= 0.0 {
            response_tx
                .send(Err(ValidationError::Temperature))
@@ -78,10 +121,11 @@ fn validation_worker(
            continue;
        }

        // Get the number of tokens in the input
        let inputs = tokenizer.encode(request.inputs.clone(), false).unwrap();
        let input_length = inputs.len();

        if input_length > max_input_length {
            response_tx
                .send(Err(ValidationError::InputLength(input_length)))
                .unwrap_or(());
@@ -91,3 +135,28 @@ fn validation_worker(
        response_tx.send(Ok((input_length, request))).unwrap_or(());
    }
}
type ValidationRequest = (
    GenerateRequest,
    oneshot::Sender<Result<(usize, GenerateRequest), ValidationError>>,
);

#[derive(Error, Debug)]
pub enum ValidationError {
    #[error("Temperature must be strictly positive")]
    Temperature,
    #[error("Top p must be > 0.0 and <= 1.0")]
    TopP,
    #[error("Top k must be strictly positive")]
    TopK,
    #[error("Max New Tokens must be < 512")]
    MaxNewTokens,
    #[error("Inputs must have less than 1000 tokens. Given: {0}")]
    InputLength(usize),
}

impl From<ValidationError> for (StatusCode, String) {
    fn from(err: ValidationError) -> Self {
        (StatusCode::BAD_REQUEST, err.to_string())
    }
}
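Because ValidationError converts into a 400 Bad Request, an out-of-range parameter is rejected before it ever reaches the model. A hypothetical exchange (host, port, and prompt are placeholders):

# Hypothetical request with an invalid temperature
curl -i -X POST http://localhost:3000/generate \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "Hello", "parameters": {"temperature": -1.0}}'
# Expected: HTTP/1.1 400 Bad Request with the body "Temperature must be strictly positive"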
#!/usr/bin/env bash
server_cmd="bloom-inference-server launcher $MODEL_NAME --num-gpus $NUM_GPUS --shard-directory $MODEL_BASE_PATH"
# Run in background
$server_cmd 2>&1 > /dev/null &
# Check if server is running by checking if the unix socket is created
FILE=/tmp/bloom-inference-0
while :
do
    if test -S "$FILE"; then
        echo "Text Generation Python gRPC server started"
        break
    else
        echo "Waiting for Text Generation Python gRPC server to start"
        sleep 5
    fi
done
sleep 1
# Run in background
text-generation-router &
# Wait for any process to exit
wait -n
# Exit with status of process that exited first
exit $?
# Byte-compiled / optimized / DLL files
__pycache__/
bloom_inference/__pycache__/
bloom_inference/pb/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
@@ -4,17 +4,28 @@ gen-server:
	find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
	touch bloom_inference/pb/__init__.py

install-transformers:
	# Install specific version of transformers
	rm transformers || true
	wget https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
	unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
	rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
	mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
	cd transformers && python setup.py install

install-torch:
	# Install specific version of torch
	pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir

pip-install:
	pip install grpcio-tools
	make gen-server
	make install-torch
	make install-transformers
	pip install .

install:
	poetry install
	make gen-server
	make install-torch
	make install-transformers
import os

import typer

from pathlib import Path
from typing import Optional

from bloom_inference import prepare_weights, server

app = typer.Typer()


@app.command()
def serve(
    model_name: str,
    sharded: bool = False,
    shard_directory: Optional[Path] = None,
    uds_path: Path = "/tmp/bloom-inference",
):
    if sharded:
        assert (
            shard_directory is not None
        ), "shard_directory must be set when sharded is True"
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
        assert (
            os.getenv("MASTER_ADDR", None) is not None
        ), "MASTER_ADDR must be set when sharded is True"
        assert (
            os.getenv("MASTER_PORT", None) is not None
        ), "MASTER_PORT must be set when sharded is True"

    server.serve(model_name, sharded, uds_path, shard_directory)


@app.command()
def prepare_weights(
    model_name: str,
    shard_directory: Path,
    cache_directory: Path,
    num_shard: int = 1,
):
    prepare_weights.prepare_weights(
        model_name, cache_directory, shard_directory, num_shard
    )


if __name__ == "__main__":
...
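The asserts above mean a sharded launch expects the usual torch.distributed rendezvous variables in the environment of every shard process. A hypothetical two-shard launch could look like this (model name, paths, and port are placeholders; each rank gets its own RANK value):

# Hypothetical sharded launch: one process per GPU, RANK differs per process
MASTER_ADDR=localhost MASTER_PORT=29500 WORLD_SIZE=2 RANK=0 \
    bloom-inference-server serve bigscience/bloom --sharded --shard-directory /data/shards &
MASTER_ADDR=localhost MASTER_PORT=29500 WORLD_SIZE=2 RANK=1 \
    bloom-inference-server serve bigscience/bloom --sharded --shard-directory /data/shards &
wait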
@@ -24,6 +24,7 @@ torch.manual_seed(0)
class Batch:
    batch_id: int
    requests: List[generate_pb2.Request]
    all_input_lengths: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
@@ -46,12 +47,12 @@ class Batch:
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        all_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            all_input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
@@ -63,17 +64,12 @@ class Batch:
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

        input_ids = tokenizer(inputs, return_tensors="pt", padding=True).to(device)
        all_input_ids = input_ids["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
@@ -91,6 +87,7 @@ class Batch:
        # Batch attributes
        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        requests = []
        all_input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []
@@ -100,6 +97,7 @@ class Batch:
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            all_input_lengths.extend(batch.all_input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)
@@ -198,6 +196,7 @@ class Batch:
        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
@@ -227,7 +226,10 @@ class BLOOM:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_name)
            .eval()
            .to(self.device)
            .to(dtype)
        )
        self.num_heads = self.model.base_model.num_heads
@@ -253,6 +255,7 @@ class BLOOM:
        # New input_ids for next forward
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_all_input_lengths = []
        next_batch_size = 0
        next_batch_max_sequence_length = 0
@@ -263,6 +266,7 @@ class BLOOM:
        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.all_input_lengths,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
@@ -272,6 +276,7 @@ class BLOOM:
        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
@@ -302,8 +307,10 @@ class BLOOM:
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                new_input_length = input_length + 1
                next_all_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
@@ -350,6 +357,7 @@ class BLOOM:
        next_batch = Batch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            all_input_lengths=next_all_input_lengths,
            input_ids=next_batch_input_ids,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
@@ -378,7 +386,10 @@ class BLOOMSharded(BLOOM):
        if self.master:
            # TODO @thomasw21 do some caching
            shard_state_dict_paths = prepare_weights(
                model_name,
                shard_directory / "cache",
                shard_directory,
                tp_world_size=self.world_size,
            )
            shard_state_dict_paths = [
                str(path.absolute()) for path in shard_state_dict_paths
@@ -443,6 +454,7 @@ class BLOOMSharded(BLOOM):
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits_shard = outputs.logits[:, -1, :].contiguous()
        batch_size, vocab_shard_size = logits_shard.shape
...
*.py
*.py-e
@@ -14,15 +14,15 @@ from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status
def match_suffix(text, suffix):
    return text[-len(suffix) :] == suffix


def http_get(
    url: str,
    temp_file: BinaryIO,
    *,
    timeout=10.0,
    max_retries=0,
):
    """
    Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
@@ -54,7 +54,9 @@ def cache_download_url(url: str, root_dir: Path):
    return filename


def prepare_weights(
    model_name: str, cache_path: Path, save_path: Path, tp_world_size: int
):
    save_paths = [
        save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
@@ -68,6 +70,7 @@ def prepare_weights(
    if model_name == "bigscience/bloom-560m":
        url = hf_hub_url(model_name, filename="pytorch_model.bin")
        cache_download_url(url, cache_path)

    elif model_name == "bigscience/bloom":
        url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
        index_path = cache_download_url(url, cache_path)
@@ -75,10 +78,14 @@ def prepare_weights(
            index = json.load(f)

        # Get unique file names
        weight_files = list(
            set([filename for filename in index["weight_map"].values()])
        )
        urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]

        Parallel(n_jobs=5)(
            delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")
@@ -91,14 +98,14 @@ def prepare_weights(
    for state_name in keys:
        state = state_dict[state_name]
        if any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.query_key_value.weight",
                "self_attention.query_key_value.bias",
                "mlp.dense_h_to_4h.weight",
                "mlp.dense_h_to_4h.bias",
                "word_embeddings.weight",
            ]
        ):
            output_size = state.shape[0]
            assert output_size % tp_world_size == 0
@@ -107,7 +114,9 @@ def prepare_weights(
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                shards_state_dicts[tp_rank][
                    "transformer." + state_name
                ] = shard.detach().clone()
        elif match_suffix(state_name, "lm_head.weight"):
            output_size = state.shape[0]
@@ -120,11 +129,11 @@ def prepare_weights(
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.weight",
                "mlp.dense_4h_to_h.weight",
            ]
        ):
            input_size = state.shape[1]
            assert input_size % tp_world_size == 0
@@ -132,23 +141,31 @@ def prepare_weights(
            sharded_weights = torch.split(state, block_size, dim=1)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                shards_state_dicts[tp_rank][
                    "transformer." + state_name
                ] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.bias",
                "mlp.dense_4h_to_h.bias",
            ]
        ):
            shards_state_dicts[0][
                "transformer." + state_name
            ] = state.detach().clone()
            for tp_rank in range(1, tp_world_size):
                shards_state_dicts[tp_rank][
                    "transformer." + state_name
                ] = torch.zeros_like(state)
        else:
            # We duplicate parameters across tp ranks
            for tp_rank in range(tp_world_size):
                shards_state_dicts[tp_rank][
                    "transformer." + state_name
                ] = state.detach().clone()

        del state_dict[state_name]  # delete key from state_dict
        del state  # delete tensor
@@ -156,7 +173,7 @@ def prepare_weights(
    # we save state_dict
    for tp_rank, (save_path, shard_state_dict) in enumerate(
        zip(save_paths, shards_state_dicts)
    ):
        save_paths.append(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
@@ -166,17 +183,3 @@ def prepare_weights(
        torch.save(shard_state_dict, save_path)

    return save_paths
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--model-name", required=True, type=str)
parser.add_argument("--cache-path", required=True, type=str)
parser.add_argument("--save-path", required=True, type=str)
parser.add_argument("--world-size", required=True, type=int)
args = parser.parse_args()
prepare_weights(args.model_name, Path(args.cache_path), Path(args.save_path), args.world_size)
@@ -64,70 +64,31 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            batch=next_batch.to_pb() if next_batch else None,
        )
async def GenerateUntilFinished(self, request, context):
batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
generated_texts = []
while not generated_texts:
generated_texts, next_batch = self.model.generate_token(batch)
batch = next_batch
self.cache.set(next_batch)
return generate_pb2.GenerateUntilFinishedResponse(
generated_texts=[
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None,
)
async def GenerateUntilFinishedWithCache(self, request, context):
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")
batches = []
for batch_pb in request.batches:
batch = self.cache.pop(batch_pb.id)
if batch is None:
raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
batches.append(batch)
if len(batches) > 1:
batch = Batch.concatenate(batches)
else:
batch = batches[0]
generated_texts = []
while not generated_texts:
generated_texts, next_batch = self.model.generate_token(batch)
batch = next_batch
self.cache.set(next_batch)
return generate_pb2.GenerateUntilFinishedWithCacheResponse(
generated_texts=[
generated_text.to_pb() for generated_text in generated_texts
],
batch=next_batch.to_pb() if next_batch else None,
)
def serve(
    model_name: str,
    sharded: bool,
    uds_path: Path,
    shard_directory: Optional[Path] = None,
):
    async def serve_inner(
        model_name: str,
        sharded: bool = False,
        shard_directory: Optional[Path] = None,
    ):
        unix_socket_template = "unix://{}-{}"
        if sharded:
            if shard_directory is None:
                raise ValueError("shard_directory must be set when sharded is True")

            model = BLOOMSharded(model_name, shard_directory)
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(model.world_size)
            ]
            local_url = server_urls[model.rank]
        else:
            model = BLOOM(model_name)
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        server = aio.server()
@@ -142,6 +103,10 @@ def serve(
        server.add_insecure_port(local_url)
        await server.start()
        print("Server started at {}".format(local_url))
        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            print("Signal received. Shutting down")
            await server.stop(0)

    asyncio.run(serve_inner(model_name, sharded, shard_directory))
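With the new unix://{}-{} template, socket paths are derived from uds_path plus the shard rank, so the default /tmp/bloom-inference yields /tmp/bloom-inference-0, /tmp/bloom-inference-1, and so on; the rank-0 socket lines up with the router's default --master-shard-uds-path /tmp/bloom-inference-0. An illustrative check:

# After the shards are up, the unix sockets should exist (illustrative)
ls -l /tmp/bloom-inference-*
# srwxrwxr-x ... /tmp/bloom-inference-0
# srwxrwxr-x ... /tmp/bloom-inference-1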
@@ -82,7 +82,6 @@ def initialize_torch_distributed():
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
        init_method="tcp://localhost:6000",
    )
    return torch.distributed.distributed_c10d._get_default_group(), rank, world_size
...
@@ -205,7 +205,7 @@ python-versions = ">=3.7"
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "a4eef5f52e8d046aa883082c865b0865047f611a3240b18250487d4b6e831496"

[metadata.files]
accelerate = [
...
@@ -11,7 +11,6 @@ bloom-inference-server = 'bloom_inference.cli:app'
python = "^3.9"
protobuf = "^4.21.7"
grpcio = "^1.49.1"
torch = "^1.12.1"
typer = "^0.6.1"
grpcio-reflection = "^1.49.1"
accelerate = "^0.12.0"
...