/// Text Generation Inference webserver entrypoint
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    #[clap(default_value = "1512", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
}
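
// Illustrative invocation (the binary name and values below are assumptions; the
// flag names follow from the clap field names above, e.g. `max_batch_size` becomes
// `--max-batch-size` and, via the `env` attribute, the MAX_BATCH_SIZE env var):
//
//   text-generation-router --port 8080 --max-batch-size 16 --json-output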

fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        max_batch_size,
        max_waiting_tokens,
        port,
        master_shard_uds_path,
        tokenizer_name,
        validation_workers,
        json_output,
        otlp_endpoint,
    } = args;

    if validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }

    // Download and instantiate tokenizer
    // This will only be used to validate payloads
    //
    // We need to download it outside of the Tokio runtime
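    // (`Tokenizer::from_pretrained` fetches the tokenizer files from the Hugging Face
    // Hub, so this call needs network access; the `unwrap` panics if the download or
    // parsing fails.)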
    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();

    // Launch Tokio runtime
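    // (`block_on` drives the whole server future to completion; the async block's
    // `Result` becomes main's return value.)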
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");

            // Bind on 0.0.0.0 (all interfaces), not just localhost
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

            // Run server
            server::run(
                max_concurrent_requests,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                max_batch_size,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
            )
            .await;
            Ok(())
        })
}

/// Init logging using the LOG_LEVEL env variable and the CLI flags:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - json_output switches the stdout layer from human-readable text to flattened JSON
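///
/// Illustrative example (the binary name here is an assumption, not taken from this file):
/// `LOG_LEVEL=text_generation_router=debug ./text-generation-router --json-output`
/// enables debug-level router logs with JSON output; LOG_LEVEL accepts any
/// tracing_subscriber EnvFilter directive, including per-target filters like this one.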
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}