//! Text Generation Inference webserver entrypoint
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
}
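// Every field above doubles as a CLI flag and an environment variable via
// clap's `long, env` attributes (e.g. `--max-batch-size 64` or
// `MAX_BATCH_SIZE=64`). Illustrative invocation, assuming the binary is
// built as `text-generation-router`:
//
//   text-generation-router --port 8080 --tokenizer-name bigscience/bloom-560m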

fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_input_length,
        max_batch_size,
        max_waiting_tokens,
        port,
        master_shard_uds_path,
        tokenizer_name,
        validation_workers,
        json_output,
        otlp_endpoint,
    } = args;

    if validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }

    // Download and instantiate tokenizer
    // This will only be used to validate payloads
    //
    // We need to download it outside of the Tokio runtime
    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None)
        .expect("Failed to download and instantiate tokenizer");

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");

            // Bind on 0.0.0.0 so the server listens on all interfaces
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

            // Run server
            server::run(
                max_concurrent_requests,
                max_input_length,
                max_batch_size,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
            )
            .await;
            Ok(())
        })
}

/// Init logging using the LOG_LEVEL env variable and the json_output flag:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - json_output switches the stdout layer from text to flattened JSON
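///
/// Example (illustrative values): `LOG_LEVEL=debug` together with
/// `--json-output` and `--otlp-endpoint http://localhost:4317` emits
/// flattened JSON logs at debug level and ships spans to a local OTLP
/// collector (4317 is the conventional OTLP gRPC port).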
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
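        // The W3C Trace Context propagator reads and writes `traceparent`
        // headers so traces can be correlated across service boundaries.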
        global::set_text_map_propagator(TraceContextPropagator::new());

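        // `install_batch` spawns the batch span exporter on the Tokio
        // runtime, which is why `init_logging` must run inside `block_on`.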
        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        }
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}