Unverified Commit 9af45414 authored by OlivierDehaene, committed by GitHub

feat: add distributed tracing (#62)

parent e520d5b3
@@ -23,6 +23,8 @@ jobs:
toolchain: 1.65.0
override: true
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Loading cache.
uses: actions/cache@v2
id: model_cache
...
This diff is collapsed.
@@ -2,6 +2,7 @@
members = [
"router",
"router/client",
"router/grpc-metadata",
"launcher"
]
...
FROM rust:1.65 as router-builder FROM rust:1.67 as router-builder
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
rm -f $PROTOC_ZIP
WORKDIR /usr/src
@@ -10,7 +16,7 @@ WORKDIR /usr/src/router
RUN cargo install --path .
FROM rust:1.65 as launcher-builder FROM rust:1.67 as launcher-builder
WORKDIR /usr/src
...
@@ -27,6 +27,7 @@ to power LLMs api-inference widgets.
- [Docker](#docker)
- [API Documentation](#api-documentation)
- [A note on Shared Memory](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
- [Local Install](#local-install)
- [CUDA Kernels](#cuda-kernels)
- [Run BLOOM](#run-bloom)
@@ -46,6 +47,7 @@ to power LLMs api-inference widgets.
- Logits warpers (temperature scaling, topk, repetition penalty ...)
- Stop sequences
- Log probabilities
- Distributed tracing with OpenTelemetry
## Officially supported models
@@ -102,6 +104,11 @@ curl 127.0.0.1:8080/generate_stream \
You can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.
The Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).
### Distributed Tracing
`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can enable this feature
by passing the address of an OTLP collector with the `--otlp-endpoint` argument.
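For example, with a collector that accepts OTLP over gRPC listening on `localhost:4317` (a Jaeger all-in-one container is one option), you can point the server at it. The launcher invocation below is only a sketch: the binary name, the `--model-id` flag and the model are illustrative and may differ from your setup, while `--otlp-endpoint` is the flag added by this change.
```shell
# Start a local trace collector that accepts OTLP (here: Jaeger all-in-one)
docker run -d --name jaeger \
  -e COLLECTOR_OTLP_ENABLED=true \
  -p 16686:16686 -p 4317:4317 \
  jaegertracing/all-in-one:latest

# Forward the collector address to the router and the shards
text-generation-launcher --model-id bigscience/bloom-560m \
  --otlp-endpoint http://localhost:4317
```
Router traces are exported under the `service.name` `text-generation-inference.router` (see `init_logging` later in this diff), so they should be easy to find in the collector's UI.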
### A note on Shared Memory (shm)
[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by
@@ -142,6 +149,24 @@ conda create -n text-generation-inference python=3.9
conda activate text-generation-inference
```
You may also need to install Protoc.
On Linux:
```shell
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP
sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc
sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
rm -f $PROTOC_ZIP
```
On MacOS, using Homebrew:
```shell
brew install protobuf
```
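Either way, it is worth checking that the build can find `protoc` on your `PATH`:
```shell
# Prints the installed protoc version; the Rust build scripts need it available
protoc --version
```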
Then run:
```shell
...
@@ -38,7 +38,8 @@ function sample_example(inputs, max_new_tokens, name) {
parameters: {
max_new_tokens: max_new_tokens,
do_sample: true,
top_p: 0.9 top_p: 0.9,
seed: 0
}
});
let params = {
...
...@@ -6,14 +6,14 @@ authors = ["Olivier Dehaene"] ...@@ -6,14 +6,14 @@ authors = ["Olivier Dehaene"]
description = "Text Generation Launcher" description = "Text Generation Launcher"
[dependencies] [dependencies]
clap = { version = "4.0.15", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.3", features = ["termination"] } ctrlc = { version = "3.2.5", features = ["termination"] }
serde_json = "1.0.89" serde_json = "1.0.93"
subprocess = "0.2.9" subprocess = "0.2.9"
tracing = "0.1.37" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json"] } tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies] [dev-dependencies]
float_eq = "1.0.1" float_eq = "1.0.1"
reqwest = { version = "0.11.13", features = ["blocking", "json"] } reqwest = { version = "0.11.14", features = ["blocking", "json"] }
serde = { version = "1.0.150", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }
...@@ -44,6 +44,8 @@ struct Args { ...@@ -44,6 +44,8 @@ struct Args {
master_port: usize, master_port: usize,
#[clap(long, env)] #[clap(long, env)]
json_output: bool, json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
} }
fn main() -> ExitCode { fn main() -> ExitCode {
...@@ -62,6 +64,7 @@ fn main() -> ExitCode { ...@@ -62,6 +64,7 @@ fn main() -> ExitCode {
master_addr, master_addr,
master_port, master_port,
json_output, json_output,
otlp_endpoint,
} = Args::parse(); } = Args::parse();
if json_output { if json_output {
...@@ -99,6 +102,7 @@ fn main() -> ExitCode { ...@@ -99,6 +102,7 @@ fn main() -> ExitCode {
let status_sender = status_sender.clone(); let status_sender = status_sender.clone();
let shutdown = shutdown.clone(); let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone(); let shutdown_sender = shutdown_sender.clone();
let otlp_endpoint = otlp_endpoint.clone();
thread::spawn(move || { thread::spawn(move || {
shard_manager( shard_manager(
model_id, model_id,
...@@ -109,6 +113,7 @@ fn main() -> ExitCode { ...@@ -109,6 +113,7 @@ fn main() -> ExitCode {
num_shard, num_shard,
master_addr, master_addr,
master_port, master_port,
otlp_endpoint,
status_sender, status_sender,
shutdown, shutdown,
shutdown_sender, shutdown_sender,
...@@ -165,7 +170,7 @@ fn main() -> ExitCode { ...@@ -165,7 +170,7 @@ fn main() -> ExitCode {
"--port".to_string(), "--port".to_string(),
port.to_string(), port.to_string(),
"--master-shard-uds-path".to_string(), "--master-shard-uds-path".to_string(),
format!("{}-0", shard_uds_path), format!("{shard_uds_path}-0"),
"--tokenizer-name".to_string(), "--tokenizer-name".to_string(),
model_id, model_id,
]; ];
...@@ -174,6 +179,12 @@ fn main() -> ExitCode { ...@@ -174,6 +179,12 @@ fn main() -> ExitCode {
argv.push("--json-output".to_string()); argv.push("--json-output".to_string());
} }
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
argv.push("--otlp-endpoint".to_string());
argv.push(otlp_endpoint);
}
let mut webserver = match Popen::create( let mut webserver = match Popen::create(
&argv, &argv,
PopenConfig { PopenConfig {
...@@ -264,12 +275,13 @@ fn shard_manager( ...@@ -264,12 +275,13 @@ fn shard_manager(
world_size: usize, world_size: usize,
master_addr: String, master_addr: String,
master_port: usize, master_port: usize,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>, shutdown: Arc<Mutex<bool>>,
_shutdown_sender: mpsc::Sender<()>, _shutdown_sender: mpsc::Sender<()>,
) { ) {
// Get UDS path // Get UDS path
let uds_string = format!("{}-{}", uds_path, rank); let uds_string = format!("{uds_path}-{rank}");
let uds = Path::new(&uds_string); let uds = Path::new(&uds_string);
// Clean previous runs // Clean previous runs
fs::remove_file(uds).unwrap_or_default(); fs::remove_file(uds).unwrap_or_default();
...@@ -286,6 +298,7 @@ fn shard_manager( ...@@ -286,6 +298,7 @@ fn shard_manager(
"--json-output".to_string(), "--json-output".to_string(),
]; ];
// Activate tensor parallelism
if world_size > 1 { if world_size > 1 {
shard_argv.push("--sharded".to_string()); shard_argv.push("--sharded".to_string());
} }
...@@ -294,11 +307,18 @@ fn shard_manager( ...@@ -294,11 +307,18 @@ fn shard_manager(
shard_argv.push("--quantize".to_string()) shard_argv.push("--quantize".to_string())
} }
// Optional model revision
if let Some(revision) = revision { if let Some(revision) = revision {
shard_argv.push("--revision".to_string()); shard_argv.push("--revision".to_string());
shard_argv.push(revision) shard_argv.push(revision)
} }
// OpenTelemetry
if let Some(otlp_endpoint) = otlp_endpoint {
shard_argv.push("--otlp-endpoint".to_string());
shard_argv.push(otlp_endpoint);
}
let mut env = vec![ let mut env = vec![
("RANK".into(), rank.to_string().into()), ("RANK".into(), rank.to_string().into()),
("WORLD_SIZE".into(), world_size.to_string().into()), ("WORLD_SIZE".into(), world_size.to_string().into()),
......
...@@ -15,20 +15,24 @@ path = "src/main.rs" ...@@ -15,20 +15,24 @@ path = "src/main.rs"
[dependencies] [dependencies]
async-stream = "0.3.3" async-stream = "0.3.3"
axum = { version = "0.6.4", features = ["json"] } axum = { version = "0.6.4", features = ["json"] }
axum-tracing-opentelemetry = "0.9.0"
text-generation-client = { path = "client" } text-generation-client = { path = "client" }
clap = { version = "4.0.15", features = ["derive", "env"] } clap = { version = "4.1.4", features = ["derive", "env"] }
futures = "0.3.24" futures = "0.3.26"
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.11.0"
parking_lot = "0.12.1" parking_lot = "0.12.1"
rand = "0.8.5" rand = "0.8.5"
serde = "1.0.145" serde = "1.0.152"
serde_json = "1.0.85" serde_json = "1.0.93"
thiserror = "1.0.37" thiserror = "1.0.38"
tokenizers = "0.13.0" tokenizers = "0.13.2"
tokio = { version = "1.21.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11" tokio-stream = "0.1.11"
tracing = "0.1.36" tracing = "0.1.37"
tracing-subscriber = { version = "0.3.15", features = ["json"] } tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] } utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] } utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
...@@ -5,13 +5,15 @@ edition = "2021" ...@@ -5,13 +5,15 @@ edition = "2021"
[dependencies] [dependencies]
futures = "^0.3" futures = "^0.3"
prost = "^0.9" grpc-metadata = { path = "../grpc-metadata" }
prost = "^0.11"
thiserror = "^1.0" thiserror = "^1.0"
tokio = { version = "^1.21", features = ["sync"] } tokio = { version = "^1.25", features = ["sync"] }
tonic = "^0.6" tonic = "^0.8"
tower = "^0.4" tower = "^0.4"
tracing = "^0.1" tracing = "^0.1"
tracing-error = "^0.2" tracing-error = "^0.2"
[build-dependencies] [build-dependencies]
tonic-build = "0.6.2" tonic-build = "0.8.4"
prost-build = "0.11.6"
@@ -3,13 +3,17 @@ use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/generate.proto");
fs::create_dir("src/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/pb")
.include_file("mod.rs")
.compile(&["../../proto/generate.proto"], &["../../proto"]) .compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e)); .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient;
use crate::pb::generate::v1::*; use crate::pb::generate::v1::*;
use crate::Result; use crate::Result;
use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri}; use tonic::transport::{Channel, Uri};
use tracing::*; use tracing::instrument;
/// Text Generation Inference gRPC client /// Text Generation Inference gRPC client
#[derive(Clone)] #[derive(Clone)]
...@@ -38,12 +39,8 @@ impl Client { ...@@ -38,12 +39,8 @@ impl Client {
/// Returns a list of uris or unix sockets of all shards /// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> { pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}); let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self let response = self.stub.service_discovery(request).await?;
.stub
.service_discovery(request)
.instrument(info_span!("service_discovery"))
.await?;
let urls = response let urls = response
.into_inner() .into_inner()
.urls .urls
...@@ -60,11 +57,8 @@ impl Client { ...@@ -60,11 +57,8 @@ impl Client {
/// Clear the past generations cache /// Clear the past generations cache
#[instrument(skip(self))] #[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> { pub async fn clear_cache(&mut self) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest {}); let request = tonic::Request::new(ClearCacheRequest {}).inject_context();
self.stub self.stub.clear_cache(request).await?;
.clear_cache(request)
.instrument(info_span!("clear_cache"))
.await?;
Ok(()) Ok(())
} }
...@@ -72,15 +66,10 @@ impl Client { ...@@ -72,15 +66,10 @@ impl Client {
/// ///
/// Returns Generation for each request in batch /// Returns Generation for each request in batch
/// and the next cached batch /// and the next cached batch
#[instrument(skip(self))] #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> { pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }); let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self let response = self.stub.prefill(request).await?.into_inner();
.stub
.prefill(request)
.instrument(info_span!("prefill"))
.await?
.into_inner();
Ok((response.generations, response.batch)) Ok((response.generations, response.batch))
} }
...@@ -88,18 +77,13 @@ impl Client { ...@@ -88,18 +77,13 @@ impl Client {
/// ///
/// Returns Generation for each request in batches /// Returns Generation for each request in batches
/// and the next cached batch /// and the next cached batch
#[instrument(skip(self))] #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode( pub async fn decode(
&mut self, &mut self,
batches: Vec<Batch>, batches: Vec<Batch>,
) -> Result<(Vec<Generation>, Option<Batch>)> { ) -> Result<(Vec<Generation>, Option<Batch>)> {
let request = tonic::Request::new(DecodeRequest { batches }); let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self let response = self.stub.decode(request).await?.into_inner();
.stub
.decode(request)
.instrument(info_span!("decode"))
.await?
.into_inner();
Ok((response.generations, response.batch)) Ok((response.generations, response.batch))
} }
} }
...@@ -17,21 +17,25 @@ use tonic::Status; ...@@ -17,21 +17,25 @@ use tonic::Status;
#[derive(Error, Debug, Clone)] #[derive(Error, Debug, Clone)]
pub enum ClientError { pub enum ClientError {
#[error("Could not connect to Text Generation server: {0:?}")] #[error("Could not connect to Text Generation server: {0}")]
Connection(String), Connection(String),
#[error("Server error: {0:?}")] #[error("Server error: {0}")]
Generation(String), Generation(String),
} }
impl From<Status> for ClientError { impl From<Status> for ClientError {
fn from(err: Status) -> Self { fn from(err: Status) -> Self {
Self::Generation(err.message().to_string()) let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
} }
} }
impl From<transport::Error> for ClientError { impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self { fn from(err: transport::Error) -> Self {
Self::Connection(err.to_string()) let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
} }
} }
......
...@@ -4,6 +4,7 @@ use crate::{Batch, Client, Generation}; ...@@ -4,6 +4,7 @@ use crate::{Batch, Client, Generation};
use futures::future::join_all; use futures::future::join_all;
use futures::future::select_all; use futures::future::select_all;
use tonic::transport::Uri; use tonic::transport::Uri;
use tracing::instrument;
/// Text Generation Inference gRPC multi client /// Text Generation Inference gRPC multi client
pub struct ShardedClient { pub struct ShardedClient {
...@@ -38,6 +39,7 @@ impl ShardedClient { ...@@ -38,6 +39,7 @@ impl ShardedClient {
} }
/// Clear the past generations cache /// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self) -> Result<()> { pub async fn clear_cache(&mut self) -> Result<()> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
...@@ -51,6 +53,7 @@ impl ShardedClient { ...@@ -51,6 +53,7 @@ impl ShardedClient {
/// ///
/// Returns Generation for each request in batch /// Returns Generation for each request in batch
/// and the next cached batch /// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> { pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> {
let futures: Vec<_> = self let futures: Vec<_> = self
.clients .clients
...@@ -66,6 +69,7 @@ impl ShardedClient { ...@@ -66,6 +69,7 @@ impl ShardedClient {
/// ///
/// Returns Generation for each request in batches /// Returns Generation for each request in batches
/// and the next cached batch /// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode( pub async fn decode(
&mut self, &mut self,
batches: Vec<Batch>, batches: Vec<Batch>,
......
[package]
name = "grpc-metadata"
version = "0.1.0"
edition = "2021"
[dependencies]
opentelemetry = "0.18.0"
tonic = "^0.8"
tracing = "^0.1"
tracing-opentelemetry = "0.18.0"
//! A crate to extract and inject an OpenTelemetry context from and to a gRPC request.
//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples
use opentelemetry::global;
use opentelemetry::propagation::{Extractor, Injector};
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Extract context metadata from a gRPC request's metadata
struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap);
impl<'a> Extractor for MetadataExtractor<'a> {
/// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None
fn get(&self, key: &str) -> Option<&str> {
self.0.get(key).and_then(|metadata| metadata.to_str().ok())
}
/// Collect all the keys from the MetadataMap.
fn keys(&self) -> Vec<&str> {
self.0
.keys()
.map(|key| match key {
tonic::metadata::KeyRef::Ascii(v) => v.as_str(),
tonic::metadata::KeyRef::Binary(v) => v.as_str(),
})
.collect::<Vec<_>>()
}
}
/// Inject context in the metadata of a gRPC request.
struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);
impl<'a> Injector for MetadataInjector<'a> {
/// Set a key and value in the MetadataMap. Does nothing if the key or value is not valid
fn set(&mut self, key: &str, value: String) {
if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {
if let Ok(val) = value.parse() {
self.0.insert(key, val);
}
}
}
}
/// Inject the current tracing span's context into a gRPC request's metadata using the globally registered text map propagator.
fn inject(metadata: &mut tonic::metadata::MetadataMap) {
global::get_text_map_propagator(|propagator| {
propagator.inject_context(
&tracing::Span::current().context(),
&mut MetadataInjector(metadata),
)
})
}
pub trait InjectTelemetryContext {
fn inject_context(self) -> Self;
}
impl<T> InjectTelemetryContext for tonic::Request<T> {
fn inject_context(mut self) -> Self {
inject(self.metadata_mut());
self
}
}
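As a usage sketch (the message type below is a stand-in; the real callers wrap the prost-generated requests such as `PrefillRequest`, as shown in `client.rs` above), attaching the current span's context to an outgoing request looks like this:
```rust
use grpc_metadata::InjectTelemetryContext;

// Placeholder payload type, for illustration only.
struct PingRequest;

fn build_traced_request() -> tonic::Request<PingRequest> {
    // Wrap the payload and copy the current tracing span's OpenTelemetry
    // context into the request metadata so the server can continue the trace.
    tonic::Request::new(PingRequest).inject_context()
}
```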
...@@ -13,7 +13,7 @@ use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError}; ...@@ -13,7 +13,7 @@ use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::instrument; use tracing::{info_span, instrument, Instrument, Span};
/// Inference struct /// Inference struct
#[derive(Clone)] #[derive(Clone)]
...@@ -69,13 +69,21 @@ impl Infer { ...@@ -69,13 +69,21 @@ impl Infer {
} }
/// Add a new request to the queue and return a stream of InferStreamResponse /// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip(self))]
pub(crate) async fn generate_stream( pub(crate) async fn generate_stream(
&self, &self,
request: GenerateRequest, request: GenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> { ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore // Limit concurrent requests by acquiring a permit from the semaphore
// This permit will live as long as Entry // This permit will live as long as Entry
let permit = self.clone().limit_concurrent_requests.try_acquire_owned()?; let permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
tracing::error!("{err}");
err
})?;
// Validate request // Validate request
let valid_request = self.validation.validate(request).await?; let valid_request = self.validation.validate(request).await?;
...@@ -87,7 +95,9 @@ impl Infer { ...@@ -87,7 +95,9 @@ impl Infer {
self.queue.append(Entry { self.queue.append(Entry {
request: valid_request, request: valid_request,
response_tx, response_tx,
time: Instant::now(), span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None, batch_time: None,
_permit: permit, _permit: permit,
}); });
...@@ -101,6 +111,7 @@ impl Infer { ...@@ -101,6 +111,7 @@ impl Infer {
} }
/// Add a new request to the queue and return an InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate( pub(crate) async fn generate(
&self, &self,
request: GenerateRequest, request: GenerateRequest,
...@@ -160,7 +171,9 @@ impl Infer { ...@@ -160,7 +171,9 @@ impl Infer {
start, start,
}) })
} else { } else {
Err(InferError::IncompleteGeneration) let err = InferError::IncompleteGeneration;
tracing::error!("{err}");
Err(err)
} }
} }
} }
...@@ -169,7 +182,6 @@ impl Infer { ...@@ -169,7 +182,6 @@ impl Infer {
/// Will be launched in a background Tokio task /// Will be launched in a background Tokio task
/// ///
/// Batches requests and sends them to the inference server /// Batches requests and sends them to the inference server
#[instrument(skip(client, queue, shared))]
async fn batching_task( async fn batching_task(
mut client: ShardedClient, mut client: ShardedClient,
max_batch_size: usize, max_batch_size: usize,
...@@ -188,8 +200,10 @@ async fn batching_task( ...@@ -188,8 +200,10 @@ async fn batching_task(
// Get the next batch from the queue // Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests // This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue // waiting in the queue
while let Some((mut entries, batch)) = queue.next_batch(None, max_batch_size).await { while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
let mut cached_batch = wrap_future(client.prefill(batch), &mut entries).await; let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1; let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until // We loop until we do not receive any cached batch from the inference server (== until
...@@ -210,13 +224,27 @@ async fn batching_task( ...@@ -210,13 +224,27 @@ async fn batching_task(
}; };
// Try to get a new batch // Try to get a new batch
if let Some((mut new_entries, new_batch)) = queue if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_size - batch_size as usize) .next_batch(min_size, max_batch_size - batch_size as usize)
.await .await
{ {
let new_batch_size = new_batch.size;
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span =
info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
// Add relationship
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache // Generate one token for this new batch to have the attention past in cache
let new_cached_batch = let new_cached_batch =
wrap_future(client.prefill(new_batch), &mut new_entries).await; wrap_future(client.prefill(new_batch), &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter // Reset waiting counter
waiting_tokens = 1; waiting_tokens = 1;
// Extend current batch with the new batch // Extend current batch with the new batch
...@@ -226,8 +254,23 @@ async fn batching_task( ...@@ -226,8 +254,23 @@ async fn batching_task(
} }
} }
} }
// Create span for this batch to add context to inference calls
cached_batch = wrap_future(client.decode(batches), &mut entries).await; let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = wrap_future(client.decode(batches), &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1; waiting_tokens += 1;
} }
} }
...@@ -235,6 +278,7 @@ async fn batching_task( ...@@ -235,6 +278,7 @@ async fn batching_task(
} }
/// Wrap a future inside a match statement to handle errors and send the responses to Infer /// Wrap a future inside a match statement to handle errors and send the responses to Infer
#[instrument(skip_all)]
async fn wrap_future( async fn wrap_future(
future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>, future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
entries: &mut IntMap<u64, Entry>, entries: &mut IntMap<u64, Entry>,
...@@ -246,24 +290,31 @@ async fn wrap_future( ...@@ -246,24 +290,31 @@ async fn wrap_future(
} }
// If we have an error, we discard the whole batch // If we have an error, we discard the whole batch
Err(err) => { Err(err) => {
send_error(err, entries); send_errors(err, entries);
None None
} }
} }
} }
/// Send errors to Infer for all `entries` /// Send errors to Infer for all `entries`
fn send_error(error: ClientError, entries: &mut IntMap<u64, Entry>) { #[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| { entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.
entry entry
.response_tx .response_tx
.send(Err(InferError::GenerationError(error.to_string()))) .send(Err(err))
.unwrap_or(()); .unwrap_or(());
}); });
} }
/// Send one or multiple `InferStreamResponse` to Infer for all `entries` /// Send one or multiple `InferStreamResponse` to Infer for all `entries`
#[instrument(skip_all)]
fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) { fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| { generations.into_iter().for_each(|generation| {
// Get entry // Get entry
...@@ -272,6 +323,9 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr ...@@ -272,6 +323,9 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.get(&generation.request_id) .get(&generation.request_id)
.expect("ID not found in entries. This is a bug."); .expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _generation_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
if let Some(prefill_tokens) = generation.prefill_tokens { if let Some(prefill_tokens) = generation.prefill_tokens {
// Send message // Send message
// unwrap_or is valid here as we don't care if the receiver is gone. // unwrap_or is valid here as we don't care if the receiver is gone.
...@@ -302,7 +356,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr ...@@ -302,7 +356,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
.send(Ok(InferStreamResponse::End { .send(Ok(InferStreamResponse::End {
token, token,
generated_text, generated_text,
queued: entry.time, queued: entry.queue_time,
start: entry.batch_time.unwrap(), start: entry.batch_time.unwrap(),
})) }))
.unwrap_or(()); .unwrap_or(());
......
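The batching changes above all follow the same pattern: each batch gets its own root span, and every entry's long-lived span is linked to it with `follows_from`, with a short-lived child span stashed in `temp_span`. A distilled, self-contained sketch of that pattern (using a plain list of spans instead of the real `Entry` map):
```rust
use tracing::{info_span, Span};

/// Create a root span for a batch and link each entry's span to it,
/// mirroring what `batching_task` and `Queue::next_batch` do above.
fn link_entries_to_batch(entry_spans: &[Span]) -> (Span, Vec<Span>) {
    // Root span for the whole batch; `parent: None` detaches it from whatever
    // span happens to be current when the batch is formed.
    let batch_span = info_span!(parent: None, "batch", batch_size = entry_spans.len());

    let infer_spans = entry_spans
        .iter()
        .map(|entry_span| {
            // Child of the entry's long-lived span...
            let infer_span =
                info_span!(parent: entry_span, "infer", batch_size = entry_spans.len());
            // ...with an additional causal (follows_from) link to the batch span.
            infer_span.follows_from(&batch_span);
            infer_span
        })
        .collect();

    (batch_span, infer_spans)
}
```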
/// Text Generation Inference webserver entrypoint
use clap::Parser; use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
/// Text Generation Inference webserver entrypoint
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use text_generation_router::server; use text_generation_router::server;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
/// App Configuration /// App Configuration
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
...@@ -27,6 +36,8 @@ struct Args { ...@@ -27,6 +36,8 @@ struct Args {
validation_workers: usize, validation_workers: usize,
#[clap(long, env)] #[clap(long, env)]
json_output: bool, json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
} }
fn main() -> Result<(), std::io::Error> { fn main() -> Result<(), std::io::Error> {
...@@ -43,14 +54,9 @@ fn main() -> Result<(), std::io::Error> { ...@@ -43,14 +54,9 @@ fn main() -> Result<(), std::io::Error> {
tokenizer_name, tokenizer_name,
validation_workers, validation_workers,
json_output, json_output,
otlp_endpoint,
} = args; } = args;
if json_output {
tracing_subscriber::fmt().json().init();
} else {
tracing_subscriber::fmt().compact().init();
}
if validation_workers == 0 { if validation_workers == 0 {
panic!("validation_workers must be > 0"); panic!("validation_workers must be > 0");
} }
...@@ -67,6 +73,8 @@ fn main() -> Result<(), std::io::Error> { ...@@ -67,6 +73,8 @@ fn main() -> Result<(), std::io::Error> {
.build() .build()
.unwrap() .unwrap()
.block_on(async { .block_on(async {
init_logging(otlp_endpoint, json_output);
// Instantiate sharded client from the master unix socket // Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path) let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await .await
...@@ -96,3 +104,58 @@ fn main() -> Result<(), std::io::Error> { ...@@ -96,3 +104,58 @@ fn main() -> Result<(), std::io::Error> {
Ok(()) Ok(())
}) })
} }
/// Init logging using the LOG_LEVEL env variable and the CLI arguments:
/// - otlp_endpoint is an optional URL to an OpenTelemetry collector
/// - json_output switches the stdout/stderr layer from human-readable text to JSON
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
"text-generation-inference.router",
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
axum_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
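In practice both the verbosity and the export target are then controlled at run time. The flag names come from the `Args` struct above; the binary name and the `text_generation_router` filter target are assumptions based on the crate layout:
```shell
# Human-readable logs, WARN everywhere except DEBUG for the router itself,
# with traces exported to a local OTLP collector
LOG_LEVEL=warn,text_generation_router=debug \
  text-generation-router --otlp-endpoint http://localhost:4317

# JSON logs with the default INFO level and no trace export
text-generation-router --json-output
```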
...@@ -7,6 +7,7 @@ use text_generation_client::{Batch, Request}; ...@@ -7,6 +7,7 @@ use text_generation_client::{Batch, Request};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
/// Queue entry /// Queue entry
#[derive(Debug)] #[derive(Debug)]
...@@ -15,8 +16,12 @@ pub(crate) struct Entry { ...@@ -15,8 +16,12 @@ pub(crate) struct Entry {
pub request: ValidGenerateRequest, pub request: ValidGenerateRequest,
/// Response sender to communicate between the Infer struct and the batching_task /// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>, pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Instant when this entry was created /// Span that will live as long as entry
pub time: Instant, pub span: Span,
/// Temporary span used as a guard when logging inference, wait times...
pub temp_span: Option<Span>,
/// Instant when this entry was queued
pub queue_time: Instant,
/// Instant when this entry was added to a batch /// Instant when this entry was added to a batch
pub batch_time: Option<Instant>, pub batch_time: Option<Instant>,
/// Permit /// Permit
...@@ -42,13 +47,17 @@ impl Queue { ...@@ -42,13 +47,17 @@ impl Queue {
} }
/// Append an entry to the queue /// Append an entry to the queue
#[instrument(skip_all)]
pub(crate) fn append(&self, entry: Entry) { pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state // Send append command to the background task managing the state
// Unwrap is safe here // Unwrap is safe here
self.queue_sender.send(QueueCommand::Append(entry)).unwrap(); self.queue_sender
.send(QueueCommand::Append(entry, Span::current()))
.unwrap();
} }
// Get the next batch // Get the next batch
#[instrument(skip(self))]
pub(crate) async fn next_batch( pub(crate) async fn next_batch(
&self, &self,
min_size: Option<usize>, min_size: Option<usize>,
...@@ -63,6 +72,7 @@ impl Queue { ...@@ -63,6 +72,7 @@ impl Queue {
min_size, min_size,
max_size, max_size,
response_sender, response_sender,
span: Span::current(),
}) })
.unwrap(); .unwrap();
// Await on response channel // Await on response channel
...@@ -77,15 +87,16 @@ async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) { ...@@ -77,15 +87,16 @@ async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
while let Some(cmd) = receiver.recv().await { while let Some(cmd) = receiver.recv().await {
match cmd { match cmd {
QueueCommand::Append(entry) => state.append(entry), QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
QueueCommand::NextBatch { QueueCommand::NextBatch {
min_size, min_size,
max_size, max_size,
response_sender, response_sender,
} => { span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, max_size); let next_batch = state.next_batch(min_size, max_size);
response_sender.send(next_batch).unwrap_or(()); response_sender.send(next_batch).unwrap_or(());
} }),
} }
} }
} }
...@@ -113,7 +124,12 @@ impl State { ...@@ -113,7 +124,12 @@ impl State {
} }
/// Append an entry to the queue /// Append an entry to the queue
fn append(&mut self, entry: Entry) { fn append(&mut self, mut entry: Entry) {
// Create a span that will live as long as the entry is in the queue waiting to be batched
let queue_span = info_span!(parent: &entry.span, "queued");
entry.temp_span = Some(queue_span);
// Push entry in the queue
self.entries.push((self.next_id, entry)); self.entries.push((self.next_id, entry));
self.next_id += 1; self.next_id += 1;
} }
...@@ -133,6 +149,10 @@ impl State { ...@@ -133,6 +149,10 @@ impl State {
let next_batch_size = min(self.entries.len(), max_size); let next_batch_size = min(self.entries.len(), max_size);
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = next_batch_size);
next_batch_span.follows_from(&Span::current());
let mut batch_requests = Vec::with_capacity(next_batch_size); let mut batch_requests = Vec::with_capacity(next_batch_size);
let mut batch_entries = let mut batch_entries =
IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default()); IntMap::with_capacity_and_hasher(next_batch_size, BuildNoHashHasher::default());
...@@ -141,6 +161,14 @@ impl State { ...@@ -141,6 +161,14 @@ impl State {
self.entries self.entries
.drain(..next_batch_size) .drain(..next_batch_size)
.for_each(|(id, mut entry)| { .for_each(|(id, mut entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span =
info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
// Add relationship
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
batch_requests.push(Request { batch_requests.push(Request {
id, id,
inputs: entry.request.inputs.clone(), inputs: entry.request.inputs.clone(),
...@@ -162,19 +190,20 @@ impl State { ...@@ -162,19 +190,20 @@ impl State {
// Increment batch id // Increment batch id
self.next_batch_id += 1; self.next_batch_id += 1;
Some((batch_entries, batch)) Some((batch_entries, batch, next_batch_span))
} }
} }
type NextBatch = (IntMap<u64, Entry>, Batch); type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)] #[derive(Debug)]
enum QueueCommand { enum QueueCommand {
Append(Entry), Append(Entry, Span),
NextBatch { NextBatch {
min_size: Option<usize>, min_size: Option<usize>,
max_size: usize, max_size: usize,
response_sender: oneshot::Sender<Option<NextBatch>>, response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
}, },
} }
...@@ -184,6 +213,7 @@ mod tests { ...@@ -184,6 +213,7 @@ mod tests {
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use tokio::sync::{mpsc, Semaphore}; use tokio::sync::{mpsc, Semaphore};
use tracing::info_span;
fn default_entry() -> Entry { fn default_entry() -> Entry {
let semaphore = Arc::new(Semaphore::new(1)); let semaphore = Arc::new(Semaphore::new(1));
...@@ -208,7 +238,9 @@ mod tests { ...@@ -208,7 +238,9 @@ mod tests {
}, },
}, },
response_tx, response_tx,
time: Instant::now(), span: info_span!("entry"),
temp_span: None,
queue_time: Instant::now(),
batch_time: None, batch_time: None,
_permit: permit, _permit: permit,
} }
...@@ -244,7 +276,7 @@ mod tests { ...@@ -244,7 +276,7 @@ mod tests {
state.append(default_entry()); state.append(default_entry());
state.append(default_entry()); state.append(default_entry());
let (entries, batch) = state.next_batch(None, 2).unwrap(); let (entries, batch, _) = state.next_batch(None, 2).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
...@@ -273,7 +305,7 @@ mod tests { ...@@ -273,7 +305,7 @@ mod tests {
state.append(default_entry()); state.append(default_entry());
state.append(default_entry()); state.append(default_entry());
let (entries, batch) = state.next_batch(None, 1).unwrap(); let (entries, batch, _) = state.next_batch(None, 1).unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
...@@ -285,7 +317,7 @@ mod tests { ...@@ -285,7 +317,7 @@ mod tests {
state.append(default_entry()); state.append(default_entry());
let (entries, batch) = state.next_batch(None, 3).unwrap(); let (entries, batch, _) = state.next_batch(None, 3).unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
...@@ -317,7 +349,7 @@ mod tests { ...@@ -317,7 +349,7 @@ mod tests {
queue.append(default_entry()); queue.append(default_entry());
queue.append(default_entry()); queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 2).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 2).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
...@@ -337,7 +369,7 @@ mod tests { ...@@ -337,7 +369,7 @@ mod tests {
queue.append(default_entry()); queue.append(default_entry());
queue.append(default_entry()); queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 1).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 1).await.unwrap();
assert_eq!(entries.len(), 1); assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0)); assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0); assert_eq!(batch.id, 0);
...@@ -345,7 +377,7 @@ mod tests { ...@@ -345,7 +377,7 @@ mod tests {
queue.append(default_entry()); queue.append(default_entry());
let (entries, batch) = queue.next_batch(None, 3).await.unwrap(); let (entries, batch, _) = queue.next_batch(None, 3).await.unwrap();
assert_eq!(entries.len(), 2); assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1)); assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2)); assert!(entries.contains_key(&2));
......
...@@ -10,6 +10,7 @@ use axum::response::sse::{Event, KeepAlive, Sse}; ...@@ -10,6 +10,7 @@ use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream; use futures::Stream;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
...@@ -18,7 +19,7 @@ use tokenizers::Tokenizer; ...@@ -18,7 +19,7 @@ use tokenizers::Tokenizer;
use tokio::signal; use tokio::signal;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tracing::instrument; use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
...@@ -75,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe ...@@ -75,7 +76,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
queue_time, queue_time,
inference_time, inference_time,
time_per_token, time_per_token,
seed seed,
) )
)] )]
async fn generate( async fn generate(
...@@ -87,10 +88,7 @@ async fn generate( ...@@ -87,10 +88,7 @@ async fn generate(
// Inference // Inference
let details = req.0.parameters.details; let details = req.0.parameters.details;
let response = infer.generate(req.0).await.map_err(|err| { let response = infer.generate(req.0).await?;
tracing::error!("{}", err.to_string());
err
})?;
// Token details // Token details
let details = match details { let details = match details {
...@@ -135,11 +133,11 @@ async fn generate( ...@@ -135,11 +133,11 @@ async fn generate(
); );
// Tracing metadata // Tracing metadata
span.record("total_time", format!("{:?}", total_time)); span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{:?}", validation_time)); span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{:?}", queue_time)); span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{:?}", inference_time)); span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{:?}", time_per_token)); span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", response.generated_text.seed)); span.record("seed", format!("{:?}", response.generated_text.seed));
tracing::info!("Output: {}", response.generated_text.text); tracing::info!("Output: {}", response.generated_text.text);
...@@ -181,7 +179,8 @@ async fn generate( ...@@ -181,7 +179,8 @@ async fn generate(
validation_time, validation_time,
queue_time, queue_time,
inference_time, inference_time,
time_per_token time_per_token,
seed,
) )
)] )]
async fn generate_stream( async fn generate_stream(
...@@ -197,7 +196,7 @@ async fn generate_stream( ...@@ -197,7 +196,7 @@ async fn generate_stream(
let mut error = false; let mut error = false;
let details = req.0.parameters.details; let details = req.0.parameters.details;
match infer.generate_stream(req.0).await { match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
Ok(mut response_stream) => { Ok(mut response_stream) => {
// Server-Sent Event stream // Server-Sent Event stream
while let Some(response) = response_stream.next().await { while let Some(response) = response_stream.next().await {
...@@ -243,13 +242,11 @@ async fn generate_stream( ...@@ -243,13 +242,11 @@ async fn generate_stream(
// Tracing metadata // Tracing metadata
span.record("total_time", format!("{:?}", total_time)); span.record("total_time", format!("{:?}", total_time));
span span.record("validation_time", format!("{:?}", validation_time));
.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time)); span.record("queue_time", format!("{:?}", queue_time));
span span.record("inference_time", format!("{:?}", inference_time));
.record("inference_time", format!("{:?}", inference_time)); span.record("time_per_token", format!("{:?}", time_per_token));
span span.record("seed", format!("{:?}", generated_text.seed));
.record("time_per_token", format!("{:?}", time_per_token));
tracing::info!(parent: &span, "Output: {}", generated_text.text); tracing::info!(parent: &span, "Output: {}", generated_text.text);
// StreamResponse // StreamResponse
...@@ -264,19 +261,17 @@ async fn generate_stream( ...@@ -264,19 +261,17 @@ async fn generate_stream(
} }
} }
} }
// Trace and yield error // yield error
Err(err) => { Err(err) => {
error = true; error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err)) yield Ok(Event::from(err))
} }
} }
} }
}, },
// Trace and yield error // yield error
Err(err) => { Err(err) => {
error = true; error = true;
tracing::error!("{}", err.to_string());
yield Ok(Event::from(err)) yield Ok(Event::from(err))
} }
} }
...@@ -284,7 +279,7 @@ async fn generate_stream( ...@@ -284,7 +279,7 @@ async fn generate_stream(
// Skip if we already sent an error // Skip if we already sent an error
if !end_reached && !error { if !end_reached && !error {
let err = InferError::IncompleteGeneration; let err = InferError::IncompleteGeneration;
tracing::error!("{}", err.to_string()); tracing::error!("{err}");
yield Ok(Event::from(err)) yield Ok(Event::from(err))
} }
}; };
...@@ -355,7 +350,8 @@ pub async fn run( ...@@ -355,7 +350,8 @@ pub async fn run(
.route("/generate_stream", post(generate_stream)) .route("/generate_stream", post(generate_stream))
.route("/", get(health)) .route("/", get(health))
.route("/health", get(health)) .route("/health", get(health))
.layer(Extension(infer)); .layer(Extension(infer))
.layer(opentelemetry_tracing_layer());
// Run server // Run server
axum::Server::bind(&addr) axum::Server::bind(&addr)
...@@ -391,6 +387,7 @@ async fn shutdown_signal() { ...@@ -391,6 +387,7 @@ async fn shutdown_signal() {
} }
tracing::info!("signal received, starting graceful shutdown"); tracing::info!("signal received, starting graceful shutdown");
opentelemetry::global::shutdown_tracer_provider();
} }
impl From<i32> for FinishReason { impl From<i32> for FinishReason {
......