Unverified commit 9af45414 authored by OlivierDehaene, committed by GitHub

feat: add distributed tracing (#62)

parent e520d5b3
......@@ -23,6 +23,8 @@ jobs:
toolchain: 1.65.0
override: true
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Loading cache.
uses: actions/cache@v2
id: model_cache
......
This diff is collapsed.
......@@ -2,6 +2,7 @@
members = [
"router",
"router/client",
"router/grpc-metadata",
"launcher"
]
......
FROM rust:1.65 as router-builder
FROM rust:1.67 as router-builder
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
WORKDIR /usr/src
......@@ -10,7 +16,7 @@ WORKDIR /usr/src/router
RUN cargo install --path .
FROM rust:1.65 as launcher-builder
FROM rust:1.67 as launcher-builder
WORKDIR /usr/src
......
......@@ -27,6 +27,7 @@ to power LLMs api-inference widgets.
- [Docker](#docker)
- [API Documentation](#api-documentation)
- [A note on Shared Memory](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
- [Local Install](#local-install)
- [CUDA Kernels](#cuda-kernels)
- [Run BLOOM](#run-bloom)
......@@ -46,6 +47,7 @@ to power LLMs api-inference widgets.
- Logits warpers (temperature scaling, topk, repetition penalty ...)
- Stop sequences
- Log probabilities
- Distributed tracing with OpenTelemetry
## Officially supported models
......@@ -102,6 +104,11 @@ curl 127.0.0.1:8080/generate_stream \
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can enable it by
passing the address of an OTLP collector via the `--otlp-endpoint` argument.
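For example, assuming a collector listening on the default OTLP gRPC port, you could start the launcher with
`--otlp-endpoint http://localhost:4317` (the address here is only illustrative); the launcher forwards the flag to
both the router and the model shards.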
### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
......@@ -142,6 +149,24 @@ conda create -n text-generation-inference python=3.9
conda activate text-generation-inference
```
You may also need to install Protoc.
On Linux:
```shell
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
rm -f $PROTOC_ZIP
```
On macOS, using Homebrew:
```shell
brew install protobuf
```
Then run:
```shell
......
......@@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
parameters: {
max_new_tokens: max_new_tokens,
do_sample: true,
top_p: 0.9
top_p: 0.9,
seed: 0
}
});
let params = {
......
......@@ -6,14 +6,14 @@ authors = ["Olivier Dehaene"]
description = "Text Generation Launcher"
[dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] }
ctrlc = { version = "3.2.3", features = ["termination"] }
serde_json = "1.0.89"
clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.5", features = ["termination"] }
serde_json = "1.0.93"
subprocess = "0.2.9"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] }
serde = { version = "1.0.150", features = ["derive"] }
reqwest = { version = "0.11.14", features = ["blocking", "json"] }
serde = { version = "1.0.152", features = ["derive"] }
......@@ -44,6 +44,8 @@ struct Args {
master_port: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
}
fn main() -> ExitCode {
......@@ -62,6 +64,7 @@ fn main() -> ExitCode {
master_addr,
master_port,
json_output,
otlp_endpoint,
} = Args::parse();
if json_output {
......@@ -99,6 +102,7 @@ fn main() -> ExitCode {
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = otlp_endpoint.clone();
thread::spawn(move || {
shard_manager(
model_id,
......@@ -109,6 +113,7 @@ fn main() -> ExitCode {
num_shard,
master_addr,
master_port,
otlp_endpoint,
status_sender,
shutdown,
shutdown_sender,
......@@ -165,7 +170,7 @@ fn main() -> ExitCode {
"--port".to_string(),
port.to_string(),
"--master-shard-uds-path".to_string(),
format!("{}-0", shard_uds_path),
format!("{shard_uds_path}-0"),
"--tokenizer-name".to_string(),
model_id,
];
......@@ -174,6 +179,12 @@ fn main() -> ExitCode {
argv.push("--json-output".to_string());
}
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
argv.push("--otlp-endpoint".to_string());
argv.push(otlp_endpoint);
}
let mut webserver = match Popen::create(
&argv,
PopenConfig {
......@@ -264,12 +275,13 @@ fn shard_manager(
world_size: usize,
master_addr: String,
master_port: usize,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>,
_shutdown_sender: mpsc::Sender<()>,
) {
// Get UDS path
let uds_string = format!("{}-{}", uds_path, rank);
let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string);
// Clean previous runs
fs::remove_file(uds).unwrap_or_default();
......@@ -286,6 +298,7 @@ fn shard_manager(
"--json-output".to_string(),
];
// Activate tensor parallelism
if world_size > 1 {
shard_argv.push("--sharded".to_string());
}
......@@ -294,11 +307,18 @@ fn shard_manager(
shard_argv.push("--quantize".to_string())
}
// Model optional revision
if let Some(revision) = revision {
shard_argv.push("--revision".to_string());
shard_argv.push(revision)
}
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_argv.push("--otlp-endpoint".to_string());
shard_argv.push(otlp_endpoint);
}
let mut env = vec![
("RANK".into(), rank.to_string().into()),
("WORLD_SIZE".into(), world_size.to_string().into()),
......
......@@ -15,20 +15,24 @@ path = "src/main.rs"
[dependencies]
async-stream = "0.3.3"
axum = { version = "0.6.4", features = ["json"] }
axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] }
futures = "0.3.24"
clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.26"
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1"
rand = "0.8.5"
serde = "1.0.145"
serde_json = "1.0.85"
thiserror = "1.0.37"
tokenizers = "0.13.0"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
serde = "1.0.152"
serde_json = "1.0.93"
thiserror = "1.0.38"
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["json"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
......@@ -5,13 +5,15 @@ edition = "2021"
[dependencies]
futures = "^0.3"
prost = "^0.9"
grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.11"
thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] }
tonic = "^0.6"
tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.8"
tower = "^0.4"
tracing = "^0.1"
tracing-error = "^0.2"
[build-dependencies]
tonic-build = "0.6.2"
tonic-build = "0.8.4"
prost-build = "0.11.6"
......@@ -3,13 +3,17 @@ use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/generate.proto");
fs::create_dir("src/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/pb")
.include_file("mod.rs")
.compile(&["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e));
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}
......@@ -2,8 +2,9 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*;
use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
use tracing::*;
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Clone)]
......@@ -38,12 +39,8 @@ impl Client {
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {});
let response = self
.stub
.service_discovery(request)
.instrument(info_span!("service_discovery"))
.await?;
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await?;
let urls = response
.into_inner()
.urls
......@@ -60,11 +57,8 @@ impl Client {
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {});
self.stub
.clear_cache(request)
.instrument(info_span!("clear_cache"))
.await?;
let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
......@@ -72,15 +66,10 @@ impl Client {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip(self))]
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) });
let response = self
.stub
.prefill(request)
.instrument(info_span!("prefill"))
.await?
.into_inner();
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
}
......@@ -88,18 +77,13 @@ impl Client {
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip(self))]
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(DecodeRequest { batches });
let response = self
.stub
.decode(request)
.instrument(info_span!("decode"))
.await?
.into_inner();
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
}
}
......@@ -17,21 +17,25 @@ use tonic::Status;
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0:?}")]
#[error("Could not connect to Text Generation server: {0}")]
Connection(String),
#[error("Server error: {0:?}")]
#[error("Server error: {0}")]
Generation(String),
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
Self::Generation(err.message().to_string())
let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
}
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
Self::Connection(err.to_string())
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
......
......@@ -4,6 +4,7 @@ use crate::{Batch, Client, Generation};
use futures::future::join_all;
use futures::future::select_all;
use tonic::transport::Uri;
use tracing::instrument;
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
......@@ -38,6 +39,7 @@ impl ShardedClient {
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> {
let futures: Vec<_> = self
.clients
......@@ -51,6 +53,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self
.clients
......@@ -66,6 +69,7 @@ impl ShardedClient {
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<Batch>,
......
[package]
name = "grpc-metadata"
version = "0.1.0"
edition = "2021"
[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.8"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"
//! A crate to extract and inject an OpenTelemetry context from and to a gRPC request.
//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples
use opentelemetry::global;
use opentelemetry::propagation::{Extractor, Injector};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Extract context metadata from a gRPC request's metadata
struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap);
impl<'a> Extractor for MetadataExtractor<'a> {
/// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|metadata| metadata.to_str().ok())
}
/// Collect all the keys from the MetadataMap.
fn keys(&self) -> Vec<&str> {
self.0
.keys()
.map(|key| match key {
tonic::metadata::KeyRef::Ascii(v) => v.as_str(),
tonic::metadata::KeyRef::Binary(v) => v.as_str(),
})
.collect::<Vec<_>>()
}
}
/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
impl<'a> Injector for MetadataInjector<'a> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value is not a valid input
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = value.parse() {
self.0.insert(key, val);
}
}
}
}
/// Inject the current span's context into a gRPC request's metadata using the globally registered propagator.
fn inject(metadata: &mut tonic::metadata::MetadataMap) {
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&tracing::Span::current().context(),
&mut MetadataInjector(metadata),
)
})
}
pub trait InjectTelemetryContext {
fn inject_context(self) -> Self;
}
impl<T> InjectTelemetryContext for tonic::Request<T> {
fn inject_context(mut self) -> Self {
inject(self.metadata_mut());
self
}
}
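Only the inject path is exercised in this excerpt (the router acts as a gRPC client); `MetadataExtractor` is declared above but its entry point is not shown here. As a rough sketch only, assuming the same `opentelemetry` 0.18 and `tracing-opentelemetry` 0.18 APIs used by `inject`, a server-side counterpart could look like this:

```rust
use opentelemetry::global;
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// Hypothetical helper (not part of this diff): pull the propagated context out of
/// incoming gRPC metadata and attach it to the current tracing span.
fn extract(metadata: &tonic::metadata::MetadataMap) {
    // Ask the globally registered propagator to read its headers
    // (e.g. `traceparent`/`tracestate` for the W3C trace-context propagator).
    let parent_cx = global::get_text_map_propagator(|propagator| {
        propagator.extract(&MetadataExtractor(metadata))
    });
    // Link the current span to the remote parent context.
    tracing::Span::current().set_parent(parent_cx);
}
```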
......@@ -13,7 +13,7 @@ use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
use tracing::{info_span, instrument, Instrument, Span};
/// Inference struct
#[derive(Clone)]
......@@ -69,13 +69,21 @@ impl Infer {
}
/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip(self))]
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
// This permit will live as long as Entry
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?;
let permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
tracing::error!("{err}");
err
})?;
// Validate request
let valid_request = self.validation.validate(request).await?;
......@@ -87,7 +95,9 @@ impl Infer {
self.queue.append(Entry {
request: valid_request,
response_tx,
time: Instant::now(),
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,
});
......@@ -101,6 +111,7 @@ impl Infer {
}
/// Add a new request to the queue and return an InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate(
&self,
request: GenerateRequest,
......@@ -160,7 +171,9 @@ impl Infer {
start,
})
} else {
Err(InferError::IncompleteGeneration)
let err = InferError::IncompleteGeneration;
tracing::error!("{err}");
Err(err)
}
}
}
......@@ -169,7 +182,6 @@ impl Infer {
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[instrument(skip(client, queue, shared))]
async fn batching_task(
mut client: ShardedClient,
max_batch_size: usize,
......@@ -188,8 +200,10 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await;
while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
......@@ -210,13 +224,27 @@ async fn batching_task(
};
// Try to get a new batch
if let Some((mut new_entries, new_batch)) = queue
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize)
.await
{
let new_batch_size = new_batch.size;
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
// Add relationship
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries).await;
wrap_future(client.prefill(new_batch), &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
......@@ -226,8 +254,23 @@ async fn batching_task(
}
}
}
cached_batch = wrap_future(client.decode(batches), &mut entries).await;
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = wrap_future(client.decode(batches), &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
}
......@@ -235,6 +278,7 @@ async fn batching_task(
}
/// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip_all)]
async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>,
......@@ -246,24 +290,31 @@ async fn wrap_future(
}
// If we have an error, we discard the whole batch
Err(err) => {
send_error(err, entries);
send_errors(err, entries);
None
}
}
}
/// Send errors to Infer for all `entries`
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("temp_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(InferError::GenerationError(error.to_string())))
.send(Err(err))
.unwrap_or(());
});
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
#[instrument(skip_all)]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
// Get entry
......@@ -272,6 +323,9 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.get(&generation.request_id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("temp_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message
// unwrap_or is valid here as we don't care if the receiver is gone.
......@@ -302,7 +356,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.send(Ok(InferStreamResponse::End {
token,
generated_text,
queued: entry.time,
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))
.unwrap_or(());
......
/// Text Generation Inference webserver entrypoint
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// Text Generation Inference webserver entrypoint
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
......@@ -27,6 +36,8 @@ struct Args {
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
}
fn main() -> Result<(), std::io::Error> {
......@@ -43,14 +54,9 @@ fn main() -> Result<(), std::io::Error> {
tokenizer_name,
validation_workers,
json_output,
otlp_endpoint,
} = args;
if json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
if validation_workers == 0 {
panic!("validation_workers must be > 0");
}
......@@ -67,6 +73,8 @@ fn main() -> Result<(), std::io::Error> {
.build()
.unwrap()
.block_on(async {
init_logging(otlp_endpoint, json_output);
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
......@@ -96,3 +104,58 @@ fn main() -> Result<(), std::io::Error> {
Ok(())
})
}
/// Init logging:
/// - otlp_endpoint is an optional URL to an OpenTelemetry (OTLP) collector
/// - the LOG_LEVEL env variable sets the log filter: TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
/// - json_output switches the stdout/stderr layer to JSON formatting
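/// Illustrative example: LOG_LEVEL=info,text_generation_router=debug uses the standard EnvFilter directive syntax to enable debug logs for the router crate only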
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
"text-generation-inference.router",
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
axum_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
......@@ -7,6 +7,7 @@ use text_generation_client::{Batch, Request};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
/// Queue entry
#[derive(Debug)]
......@@ -15,8 +16,12 @@ pub(crate) struct Entry {
pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Instant when this entry was created
pub time: Instant,
/// Span that will live as long as entry
pub span: Span,
/// Temporary span used as a guard when logging inference, wait times...
pub temp_span: Option<Span>,
/// Instant when this entry was queued
pub queue_time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
/// Permit
......@@ -42,13 +47,17 @@ impl Queue {
}
/// Append an entry to the queue
#[instrument(skip_all)]
pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state
// Unwrap is safe here
self.queue_sender.send(QueueCommand::Append(entry)).unwrap();
self.queue_sender
.send(QueueCommand::Append(entry, Span::current()))
.unwrap();
}
// Get the next batch
#[instrument(skip(self))]
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
......@@ -63,6 +72,7 @@ impl Queue {
min_size,
max_size,
response_sender,
span: Span::current(),
})
.unwrap();
// Await on response channel
......@@ -77,15 +87,16 @@ async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
while let Some(cmd) = receiver.recv().await {
match cmd {
QueueCommand::Append(entry) => state.append(entry),
QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::NextBatch {
min_size,
max_size,
response_sender,
} => {
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, max_size);
response_sender.send(next_batch).unwrap_or(());
}
}),
}
}
}
......@@ -113,7 +124,12 @@ impl State {
}
/// Append an entry to the queue
fn append(&mut self, entry: Entry) {
fn append(&mut self, mut entry: Entry) {
// Create a span that will live as long as the entry is in the queue waiting to be batched
let queue_span = info_span!(parent: &entry.span, "queued");
entry.temp_span = Some(queue_span);
// Push entry in the queue
self.entries.push((self.next_id, entry));
self.next_id += 1;
}
......@@ -133,6 +149,10 @@ impl State {
let next_batch_size = min(self.entries.len(), max_size);
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
next_batch_span.follows_from(&Span::current());
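// A batch aggregates several independent requests, so the batch span is created without a
// parent (parent: None above) and is only linked to the current span via follows_from.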
let mut batch_requests = Vec::with_capacity(next_batch_size);
let mut batch_entries =
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
......@@ -141,6 +161,14 @@ impl State {
self.entries
.drain(..next_batch_size)
.for_each(|(id, mut entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
batch_requests.push(Request {
id,
inputs: entry.request.inputs.clone(),
......@@ -162,19 +190,20 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
Some((batch_entries, batch))
Some((batch_entries, batch, next_batch_span))
}
}
type NextBatch = (IntMap<u64, Entry>, Batch);
type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)]
enum QueueCommand {
Append(Entry),
Append(Entry, Span),
NextBatch {
min_size: Option<usize>,
max_size: usize,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
},
}
......@@ -184,6 +213,7 @@ mod tests {
use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::{mpsc, Semaphore};
use tracing::info_span;
fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1));
......@@ -208,7 +238,9 @@ mod tests {
},
},
response_tx,
time: Instant::now(),
span: info_span!("entry"),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
_permit: permit,
}
......@@ -244,7 +276,7 @@ mod tests {
state.append(default_entry());
state.append(default_entry());
let (entries, batch) = state.next_batch(None, 2).unwrap();
let (entries, batch, _) = state.next_batch(None, 2).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
......@@ -273,7 +305,7 @@ mod tests {
state.append(default_entry());
state.append(default_entry());
let (entries, batch) = state.next_batch(None, 1).unwrap();
let (entries, batch, _) = state.next_batch(None, 1).unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
......@@ -285,7 +317,7 @@ mod tests {
state.append(default_entry());
let (entries, batch) = state.next_batch(None, 3).unwrap();
let (entries, batch, _) = state.next_batch(None, 3).unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
......@@ -317,7 +349,7 @@ mod tests {
queue.append(default_entry());
queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 2).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
......@@ -337,7 +369,7 @@ mod tests {
queue.append(default_entry());
queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 1).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
......@@ -345,7 +377,7 @@ mod tests {
queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 3).await.unwrap();
let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
......
......@@ -10,6 +10,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use std::convert::Infallible;
use std::net::SocketAddr;
......@@ -18,7 +19,7 @@ use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::instrument;
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
......@@ -75,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
queue_time,
inference_time,
time_per_token,
seed
seed,
)
)]
async fn generate(
......@@ -87,10 +88,7 @@ async fn generate(
// Inference
let details = req.0.parameters.details;
let response = infer.generate(req.0).await.map_err(|err| {
tracing::error!("{}", err.to_string());
err
})?;
let response = infer.generate(req.0).await?;
// Token details
let details = match details {
......@@ -135,11 +133,11 @@ async fn generate(
);
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span.record("inference_time", format!("{:?}", inference_time));
span.record("time_per_token", format!("{:?}", time_per_token));
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.generated_text.text);
......@@ -181,7 +179,8 @@ async fn generate(
validation_time,
queue_time,
inference_time,
time_per_token
time_per_token,
seed,
)
)]
async fn generate_stream(
......@@ -197,7 +196,7 @@ async fn generate_stream(
let mut error = false;
let details = req.0.parameters.details;
match infer.generate_stream(req.0).await {
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => {
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
......@@ -243,13 +242,11 @@ async fn generate_stream(
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span
.record("validation_time", format!("{:?}", validation_time));
span.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span
.record("inference_time", format!("{:?}", inference_time));
span
.record("time_per_token", format!("{:?}", time_per_token));
span.record("inference_time", format!("{:?}", inference_time));
span.record("time_per_token", format!("{:?}", time_per_token));
span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// StreamResponse
......@@ -264,19 +261,17 @@ async fn generate_stream(
}
}
}
// Trace and yield error
// yield error
Err(err) => {
error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
}
}
},
// Trace and yield error
// yield error
Err(err) => {
error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err))
}
}
......@@ -284,7 +279,7 @@ async fn generate_stream(
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
tracing::error!("{}", err.to_string());
tracing::error!("{err}");
yield Ok(Event::from(err))
}
};
......@@ -355,7 +350,8 @@ pub async fn run(
.route("/generate_stream", post(generate_stream))
.route("/", get(health))
.route("/health", get(health))
.layer(Extension(infer));
.layer(Extension(infer))
.layer(opentelemetry_tracing_layer());
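// Note: the OpenTelemetry tracing layer above is expected to create one span per HTTP request
// and pick up any incoming trace context from the request headers, so upstream traces continue
// through the router.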
// Run server
axum::Server::bind(&addr)
......@@ -391,6 +387,7 @@ async fn shutdown_signal() {
}
tracing::info!("signal received, starting graceful shutdown");
opentelemetry::global::shutdown_tracer_provider();
}
impl From<i32> for FinishReason {
......