Commit 295831a4 authored by Olivier Dehaene

Init
# BLOOM Inference
A Rust router and a Python gRPC server for BLOOM Inference.
## Install
```shell
cd server
pip install .
```
```shell
cd router
cargo build --release
```
## Run
```shell
python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models
```
```shell
./router/target/release/router
```
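Once both processes are running, the router listens on `127.0.0.1:3000` and exposes a single `POST /generate` route (see `router/src/main.rs`). A request can then be sent with `curl`; the prompt and sampling values below are placeholders, and every field in `parameters` is optional:
```shell
curl 127.0.0.1:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "Hello, my name is", "parameters": {"temperature": 0.9, "top_k": 50, "top_p": 0.9, "do_sample": true, "max_new_tokens": 20}}'
```
On success the router answers with a JSON body of the form `{"generated_text": "..."}`; on error it returns HTTP 500.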
## TODO:
- [ ] Improve model download
  - Store "shardable" layers separately and layer by layer
- [ ] Add batching args to router CLI
- [ ] Add docstrings + comments everywhere as the codebase is fairly complicated
- [ ] Add tests
- [ ] Add shutdown logic in router and server
- [ ] Improve multi-processing logic in server
- [ ] Improve error handling everywhere
- [ ] Improve past key layer indexing?
syntax = "proto3";
package generate.v1;
service TextGeneration {
/// Service discovery
rpc ServiceDiscovery(Empty) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache(Empty) returns (Empty);
/// Generate tokens for a batch without cache
rpc Generate(Batch) returns (Response);
/// Generate tokens for a batch with cache
rpc GenerateWithCache(BatchCached) returns (Response);
}
message ServiceDiscoveryResponse {
repeated string urls = 1;
}
message LogitsWarperParameters {
float temperature = 1;
uint32 top_k = 2;
float top_p = 3;
bool do_sample = 4;
}
message Request {
/// Request ID
uint64 id = 1;
/// The generation context
string inputs = 2;
/// Logits Warper Parameters
LogitsWarperParameters parameters = 3;
/// Stopping criteria
uint32 max_new_tokens = 4;
}
message Batch {
/// Batch ID
uint64 id = 1;
/// Individual requests
repeated Request requests = 2;
}
message BatchCached {
/// Batch ID
uint64 id = 1;
/// Request ids within cache
repeated uint64 request_ids = 2;
/// Cache IDs
repeated uint64 batch_cached_ids = 3;
/// Batch size (sum of all batch sizes)
uint32 total_batch_size = 4;
/// Max sequence length
uint32 max_sequence_length = 5;
}
message FinishedGeneration {
/// ID of the original request
uint64 id = 1;
/// Output
string output = 2;
}
message CacheEntry {
/// Cache ID; same as batch ID
uint64 id = 1;
/// Requests present in cache entry
repeated uint64 request_ids = 2;
/// Sequence length
uint32 sequence_length = 3;
}
message Response {
/// Finished requests (optional)
repeated FinishedGeneration finished = 1;
/// Cache entry (optional)
optional CacheEntry cache_entry = 2;
}
// Represent an empty message.
message Empty {}
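For poking at a running shard directly, a generic gRPC CLI can call these methods over the unix socket the Python server listens on. The command below is only a sketch: it assumes `grpcurl` is installed, the repository root as working directory, and a shard listening on `/tmp/bloom-inference-0` (the socket the router connects to in `router/src/main.rs`):
```shell
# List the registered shard urls via the ServiceDiscovery rpc
grpcurl -plaintext -unix \
    -import-path proto -proto generate.proto \
    -d '{}' /tmp/bloom-inference-0 generate.v1.TextGeneration/ServiceDiscovery
```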
[package]
name = "bloom-inference"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bloom-inference-client = { path = "client" }
futures = "0.3.24"
parking_lot = "0.12.1"
poem = "1.3.45"
serde = "1.0.145"
serde_json = "1.0.85"
tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt-multi-thread", "parking_lot", "sync"] }
tracing = "0.1.36"
tracing-subscriber = "0.3.15"
[workspace]
members = [
"client",
]
[profile.release]
debug = 1
incremental = true
lto = "off"
[package]
name = "bloom-inference-client"
version = "0.1.0"
edition = "2021"
[dependencies]
futures = "0.3.24"
#grpc-error-details = { path = "../../grpc-error-details" }
#grpc-metadata = { path = "../../grpc-metadata" }
prost = "^0.9"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["sync"] }
tonic = "^0.6"
tower = "^0.4"
tracing = "^0.1"
tracing-error = "^0.2"
[build-dependencies]
tonic-build = "0.6.2"
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
fs::create_dir("src/pb").unwrap_or(());
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/pb")
.include_file("mod.rs")
.compile(&["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e));
Ok(())
}
use crate::pb::generate::v1::text_generation_client::TextGenerationClient;
use crate::pb::generate::v1::*;
use crate::Result;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tower::timeout::Timeout;
use tracing::*;
/// BLOOM Inference gRPC client
#[derive(Clone)]
pub struct Client {
stub: TextGenerationClient<Timeout<Channel>>,
}
impl Client {
/// Returns a client connected to the given url. Requests exceeding timeout will fail.
pub async fn connect(uri: Uri, timeout: Duration) -> Self {
let channel = Channel::builder(uri)
.connect()
.await
.expect("Transport error");
let timeout_channel = Timeout::new(channel, timeout);
Self {
stub: TextGenerationClient::new(timeout_channel),
}
}
/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
pub async fn connect_uds(path: String, timeout: Duration) -> Self {
let channel = Channel::from_shared(format!("http://[::]:50051"))
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await
.expect("Transport error");
let timeout_channel = Timeout::new(channel, timeout);
Self {
stub: TextGenerationClient::new(timeout_channel),
}
}
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(Empty {});
let response = self
.stub
.service_discovery(request)
.instrument(info_span!("service_discovery"))
.await?;
let urls = response
.into_inner()
.urls
.into_iter()
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(Empty {});
self.stub
.clear_cache(request)
.instrument(info_span!("clear_cache"))
.await?;
Ok(())
}
#[instrument(skip(self))]
pub async fn generate(
&mut self,
request: Batch,
) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
let request = tonic::Request::new(request);
let response = self
.stub
.generate(request)
.instrument(info_span!("generate"))
.await?
.into_inner();
Ok((response.finished, response.cache_entry))
}
#[instrument(skip(self))]
pub async fn generate_with_cache(
&mut self,
request: BatchCached,
) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
let request = tonic::Request::new(request);
let response = self
.stub
.generate_with_cache(request)
.instrument(info_span!("generate_with_cache"))
.await?
.into_inner();
Ok((response.finished, response.cache_entry))
}
}
//! BLOOM Inference gRPC client library
mod client;
mod pb;
mod sharded_client;
pub use client::Client;
pub use pb::generate::v1::{
Batch, BatchCached, CacheEntry, FinishedGeneration, LogitsWarperParameters, Request,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
pub use tonic::transport::Uri;
use tonic::Status;
#[derive(Error, Debug, Clone)]
#[error("Text generation client error: {msg:?}")]
pub struct ClientError {
msg: String,
// source: Status,
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
Self {
msg: err.to_string(),
// source: err,
}
}
}
pub type Result<T> = std::result::Result<T, ClientError>;
*.rs
use crate::Result;
use crate::{Batch, BatchCached, CacheEntry, Client, FinishedGeneration};
use futures::future::join_all;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc};
use tonic::transport::Uri;
#[derive(Clone, Debug)]
enum Command {
Generate(
Batch,
mpsc::Sender<Result<(Vec<FinishedGeneration>, Option<CacheEntry>)>>,
),
GenerateWithCache(
BatchCached,
mpsc::Sender<Result<(Vec<FinishedGeneration>, Option<CacheEntry>)>>,
),
ClearCache(mpsc::Sender<Result<()>>),
}
async fn client_task(mut client: Client, mut request_subscriber: broadcast::Receiver<Command>) {
while let Ok(message) = request_subscriber.recv().await {
match message {
Command::Generate(batch, response_tx) => {
let result = client.generate(batch).await;
response_tx.try_send(result).unwrap_or(());
}
Command::GenerateWithCache(batch_cached, response_tx) => {
let result = client.generate_with_cache(batch_cached).await;
response_tx.try_send(result).unwrap_or(());
}
Command::ClearCache(response_tx) => {
let result = client.clear_cache().await;
response_tx.try_send(result).unwrap_or(());
}
};
}
}
pub struct ShardedClient {
request_tx: broadcast::Sender<Command>,
}
impl ShardedClient {
fn new(mut clients: Vec<Client>) -> Self {
let (request_tx, _) = broadcast::channel(1);
for client in clients.drain(..) {
let request_subscriber = request_tx.subscribe();
tokio::spawn(client_task(client, request_subscriber));
}
Self { request_tx }
}
async fn from_master_client(mut master_client: Client) -> Self {
let uris = master_client.service_discovery().await.unwrap();
let futures = uris
.into_iter()
.map(|path| Client::connect_uds(path, Duration::from_secs(5)));
let clients = join_all(futures).await;
Self::new(clients)
}
/// Returns a client connected to the given url. Requests exceeding timeout will fail.
pub async fn connect(uri: Uri, timeout: Duration) -> Self {
let master_client = Client::connect(uri, timeout).await;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket. Requests exceeding timeout will fail.
pub async fn connect_uds(path: String, timeout: Duration) -> Self {
let master_client = Client::connect_uds(path, timeout).await;
Self::from_master_client(master_client).await
}
pub async fn generate(
&self,
batch: Batch,
) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::Generate(batch, response_tx))
.unwrap();
response_rx.recv().await.unwrap()
}
pub async fn generate_with_cache(
&self,
batch_cached: BatchCached,
) -> Result<(Vec<FinishedGeneration>, Option<CacheEntry>)> {
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::GenerateWithCache(batch_cached, response_tx))
.unwrap();
response_rx.recv().await.unwrap()
}
pub async fn clear_cache(&self) -> Result<()> {
let (response_tx, mut response_rx) = mpsc::channel(1);
self.request_tx
.send(Command::ClearCache(response_tx))
.unwrap();
response_rx.recv().await.unwrap()
}
}
//! This code is massively inspired by Tokio mini-redis
use crate::GenerateRequest;
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::sync::oneshot::Sender;
#[derive(Debug, Clone)]
pub(crate) struct Db {
pub shared: Arc<Shared>,
}
#[derive(Debug)]
pub struct Shared {
state: RwLock<State>,
}
#[derive(Debug)]
struct State {
entries: BTreeMap<u64, (Request, Sender<Result<String, ClientError>>)>,
/// Identifier to assign to the next incoming request
next_id: u64,
/// Identifier to assign to the next batch
next_batch_id: u64,
/// Id of the first request in `entries` that has not been put in a batch yet
next_batch_start_id: u64,
}
impl Db {
pub(crate) fn new() -> Self {
let shared = Arc::new(Shared {
state: RwLock::new(State {
entries: BTreeMap::new(),
next_id: 0,
next_batch_id: 0,
next_batch_start_id: 0,
}),
});
Self { shared }
}
pub(crate) fn append(&self, request: GenerateRequest, sender: Sender<Result<String, ClientError>>) {
let mut state = self.shared.state.write();
let id = state.next_id;
state.next_id += 1;
let parameters = Some(LogitsWarperParameters {
temperature: request.parameters.temperature,
top_k: request.parameters.top_k,
top_p: request.parameters.top_p,
do_sample: request.parameters.do_sample,
});
let request = Request {
id,
inputs: request.inputs,
parameters,
max_new_tokens: request.parameters.max_new_tokens,
};
state.entries.insert(id, (request, sender));
}
pub(crate) fn remove(&self, id: &u64) -> Option<(Request, Sender<Result<String, ClientError>>)> {
let mut state = self.shared.state.write();
state.entries.remove(id)
}
pub(crate) fn len(&self) -> usize {
let state = self.shared.state.read();
state.entries.len()
}
fn next_requests(&self, max_size: usize) -> Option<(u64, Vec<Request>)> {
let state = self.shared.state.read();
let requests: Vec<Request> = state
.entries
.range(state.next_batch_start_id..)
.take(max_size)
.map(|(_, (request, _))| request.clone())
.collect();
if requests.is_empty() {
None
} else {
let last_id = requests.last().unwrap().id;
Some((last_id, requests))
}
}
pub(crate) fn next_batch(&self, max_size: usize) -> Option<Batch> {
if let Some((last_id, requests)) = self.next_requests(max_size) {
let mut state = self.shared.state.write();
let batch = Batch {
id: state.next_batch_id,
requests,
};
state.next_batch_start_id = last_id + 1;
state.next_batch_id += 1;
return Some(batch);
}
None
}
pub(crate) fn next_batch_minimum_size(
&self,
min_size: usize,
max_size: usize,
) -> Option<Batch> {
if let Some((last_id, requests)) = self.next_requests(max_size) {
if requests.len() >= min_size {
let mut state = self.shared.state.write();
let batch = Batch {
id: state.next_batch_id,
requests,
};
state.next_batch_start_id = last_id + 1;
state.next_batch_id += 1;
return Some(batch);
}
}
None
}
}
use crate::{Db, GenerateRequest};
use bloom_inference_client::{Batch, BatchCached, CacheEntry, ClientError, FinishedGeneration, ShardedClient};
use std::sync::Arc;
use tokio::sync::{oneshot, Notify};
/// Maximum number of requests tracked in the Db before new requests are rejected
const MAX_LENGTH: usize = 128;
pub struct InferError {}
#[derive(Clone)]
pub(crate) struct Infer {
db: Db,
shared: Arc<Shared>,
}
struct Shared {
batching_task: Notify,
}
impl Infer {
pub(crate) fn new(client: ShardedClient) -> Self {
let db = Db::new();
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
tokio::spawn(batching_task(client, db.clone(), shared.clone()));
Self { db, shared }
}
pub(crate) async fn infer(&self, request: GenerateRequest) -> Result<String, InferError> {
if self.db.len() > MAX_LENGTH {
return Err(InferError {});
}
let (request_tx, request_rx) = oneshot::channel();
self.db.append(request, request_tx);
self.shared.batching_task.notify_waiters();
match request_rx.await.unwrap() {
Ok(output) => Ok(output),
Err(_) => Err(InferError {})
}
}
}
/// Background task that builds batches from the waiting requests in the Db
/// and drives the sharded model server
async fn batching_task(client: ShardedClient, db: Db, shared: Arc<Shared>) {
loop {
// Wait until at least one new request has been appended to the Db
shared.batching_task.notified().await;
// Build a first batch of at most 32 requests and run the initial forward pass
if let Some(batch) = db.next_batch(32) {
let mut cache_entry = infer_batch(batch, &client, &db).await;
// Keep generating with the cached past keys/values until no request is left
loop {
if let Some(entry) = cache_entry {
let mut batch_cached_ids = vec![entry.id];
let mut total_batch_size = entry.request_ids.len();
let mut max_sequence_length = entry.sequence_length;
let mut request_ids = entry.request_ids;
// If the running batch has shrunk enough, try to fold in a new batch of
// at least 16 (and at most 48) waiting requests
if total_batch_size <= 16 {
if let Some(batch) = db.next_batch_minimum_size(16, 48) {
let other_cache_entry = infer_batch(batch, &client, &db).await;
if let Some(entry) = other_cache_entry {
batch_cached_ids.push(entry.id);
total_batch_size += entry.request_ids.len();
max_sequence_length =
max_sequence_length.max(entry.sequence_length);
request_ids.extend(entry.request_ids.into_iter());
}
}
}
// Concatenate the cached batches and run one more generation step
let batch_cached = BatchCached {
id: entry.id,
batch_cached_ids,
total_batch_size: total_batch_size as u32,
max_sequence_length,
request_ids,
};
cache_entry = infer_batch_cached(batch_cached, &client, &db).await;
} else {
break;
}
}
}
}
}
async fn infer_batch_cached(batch: BatchCached, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
match client.generate_with_cache(batch.clone()).await {
Ok((finished, cache_entry)) => {
send_finished(finished, db);
cache_entry
}
Err(err) => {
println!("{:?}", err);
send_error(err, batch.request_ids, &db);
None
}
}
}
async fn infer_batch(batch: Batch, client: &ShardedClient, db: &Db) -> Option<CacheEntry> {
match client.generate(batch.clone()).await {
Ok((finished, cache_entry)) => {
send_finished(finished, db);
cache_entry
}
Err(err) => {
println!("{:?}", err);
send_error(err, batch.requests.into_iter().map(|req| req.id).collect(), &db);
None
}
}
}
fn send_error(error: ClientError, request_ids: Vec<u64>, db: &Db) {
request_ids.into_iter().for_each(|id| {
let (_, response_tx) = db.remove(&id).unwrap();
response_tx.send(Err(error.clone())).unwrap_or(());
});
}
fn send_finished(finished: Vec<FinishedGeneration>, db: &Db) {
finished.into_iter().for_each(|output| {
let (_, response_tx) = db.remove(&output.id).unwrap();
response_tx.send(Ok(output.output)).unwrap_or(());
});
}
use tokio::time::Instant;
use poem;
use poem::middleware::AddData;
use poem::web::Data;
use poem::{handler, listener::TcpListener, post, web::Json, EndpointExt, Result, Route, Server};
use bloom_inference_client::ShardedClient;
use serde::Deserialize;
use std::time::Duration;
use poem::http::StatusCode;
use tracing::instrument;
mod db;
use db::Db;
mod infer;
use infer::Infer;
#[derive(Clone, Debug, Deserialize)]
struct GenerateParameters {
#[serde(default = "default_temperature")]
temperature: f32,
#[serde(default = "default_top_k")]
top_k: u32,
#[serde(default = "default_top_p")]
top_p: f32,
#[serde(default = "default_do_sample")]
do_sample: bool,
#[serde(default = "default_max_new_tokens")]
max_new_tokens: u32,
}
fn default_temperature() -> f32 {
1.0
}
fn default_top_k() -> u32 {
0
}
fn default_top_p() -> f32 {
1.0
}
fn default_do_sample() -> bool {
false
}
fn default_max_new_tokens() -> u32 {
20
}
#[derive(Clone, Debug, Deserialize)]
struct GenerateRequest {
inputs: String,
#[serde(default = "default_parameters")]
parameters: GenerateParameters,
}
fn default_parameters() -> GenerateParameters {
GenerateParameters {
temperature: default_temperature(),
top_k: default_top_k(),
top_p: default_top_p(),
do_sample: default_do_sample(),
max_new_tokens: default_max_new_tokens(),
}
}
#[handler]
#[instrument(skip(infer), fields(time, time_per_token))]
async fn generate(
infer: Data<&Infer>,
req: Json<GenerateRequest>,
) -> Result<Json<serde_json::Value>> {
let start = Instant::now();
let output = infer
.infer(GenerateRequest {
inputs: req.inputs.clone(),
parameters: req.parameters.clone(),
})
.await;
match output {
Ok(generated_text) => {
tracing::Span::current().record("time", format!("{:?}", start.elapsed()));
tracing::Span::current().record("time_per_token", format!("{:?}", start.elapsed() / req.parameters.max_new_tokens));
tracing::info!("response: {}", generated_text);
Ok(Json(serde_json::json!({
"generated_text": generated_text,
})))
}
Err(_) => {
Err(poem::Error::from_status(StatusCode::INTERNAL_SERVER_ERROR))
}
}
}
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
tracing_subscriber::fmt::init();
let sharded_client =
ShardedClient::connect_uds("/tmp/bloom-inference-0".to_string(), Duration::from_secs(5))
.await;
sharded_client
.clear_cache()
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
let infer = Infer::new(sharded_client);
let app = Route::new()
.at("/generate", post(generate))
.with(AddData::new(infer));
Server::new(TcpListener::bind("127.0.0.1:3000"))
.run(app)
.await
}
gen-server:
mkdir bloom_inference/pb || true
python -m grpc_tools.protoc -I../proto --python_out=bloom_inference/pb --grpc_python_out=bloom_inference/pb ../proto/generate.proto
find bloom_inference/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch bloom_inference/pb/__init__.py
unit-tests:
python -m pytest --cov=bloom_inference tests
unit-tests-reporting:
python -m pytest --junitxml=report.xml --cov=bloom_inference tests
pip-install:
pip install grpcio-tools
make gen-server
pip install .
install:
poetry install
make gen-server
# BLOOM Inference Python gRPC Server
A Python gRPC server for BLOOM Inference
## Local Install (with poetry)
```shell
make install
```
## Local Install (with pip)
```shell
make pip-install
```
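Once installed, the server can be started as in the top-level README; the sketch below simply reuses that example's arguments (run from the repository root, the model name, GPU count, and shard directory are just those example values):
```shell
python server/bloom_inference/main.py bigscience/bloom --num-gpus 8 --shard-directory /dev/shm/models
```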