"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "f0f84975f4a1798654e669fe016644fd9f185118"
main.rs 14.1 KB
Newer Older
/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::{server, HubModelInfo};
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "1024", long, env)]
    max_input_length: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    #[clap(default_value = "1.2", long, env)]
    waiting_served_ratio: f32,
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
    master_shard_uds_path: String,
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    #[clap(long, env)]
    revision: Option<String>,
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env)]
    ngrok: bool,
    #[clap(long, env)]
    ngrok_authtoken: Option<String>,
    #[clap(long, env)]
    ngrok_edge: Option<String>,
}
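
// Illustrative usage (the binary name `text-generation-router` is assumed from
// the crate; adjust to your build). Thanks to clap's `env` attribute, every
// flag can also be set through the matching SCREAMING_SNAKE_CASE environment
// variable:
//   text-generation-router --tokenizer-name bigscience/bloom --port 3000
//   MAX_CONCURRENT_REQUESTS=256 text-generation-router --json-output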

fn main() -> Result<(), RouterError> {
    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        hostname,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
        ngrok,
        ngrok_authtoken,
        ngrok_edge,
    } = args;

    // Validate args
    if max_input_length >= max_total_tokens {
        return Err(RouterError::ArgumentValidation(
            "`max_input_length` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_length as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }
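
    // Taken together, these checks enforce the token-budget ordering:
    //   max_input_length < max_total_tokens <= max_batch_total_tokens
    //   max_input_length <= max_batch_prefill_tokens <= max_batch_total_tokens
    // (the two `max_batch_total_tokens` bounds only apply when that flag is set).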

    // CORS allowed origins
    // Map inside the Option, parse each String into a HeaderValue,
    // and finally convert the collection into an AllowOrigin list
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
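
    // The flag may be repeated, one origin per occurrence, e.g. (illustrative):
    //   --cors-allow-origin https://foo.example --cors-allow-origin https://bar.example
    // An origin that is not a valid HeaderValue will panic via the `unwrap()` above.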

    // Read the Hugging Face Hub token from the environment
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone().unwrap_or("main".to_string()),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };
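
    // If neither path produced a tokenizer, `tokenizer` is `None` and the
    // router still starts; the warning emitted below explains that Rust-side
    // input length validation and truncation are then disabled.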

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?
        .block_on(async {
            init_logging(otlp_endpoint, json_output);

            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, revision, authorization_token)
                    .await
                    .unwrap_or_else(|| {
                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
                        HubModelInfo {
                            model_id: tokenizer_name.to_string(),
                            sha: None,
                            pipeline_tag: None,
                        }
                    }),
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .map_err(RouterError::Connection)?;
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .map_err(RouterError::Cache)?;
            // Get info from the shard
            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;

            // Warmup model
            tracing::info!("Warming up model");
            let max_supported_batch_total_tokens = match sharded_client
                .warmup(max_input_length as u32, max_batch_prefill_tokens)
                .await
                .map_err(RouterError::Warmup)?
            {
                // Older models do not support automatic max-batch-total-tokens
                None => {
                    let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
                        16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
                    );
                    tracing::warn!("Model does not support automatic max batch total tokens");
                    max_batch_total_tokens
                }
                // Flash attention models return their max supported total tokens
                Some(max_supported_batch_total_tokens) => {
                    // Warn if the user supplied their own max-batch-total-tokens, as we will ignore it
                    if max_batch_total_tokens.is_some() {
                        tracing::warn!(
                            "`--max-batch-total-tokens` is deprecated for Flash \
                        Attention models."
                        );
                        tracing::warn!(
                            "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                        );
                    }
                    max_supported_batch_total_tokens
                }
            };
            tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
            tracing::info!("Connected");

            let addr = match hostname.parse() {
                Ok(ip) => SocketAddr::new(ip, port),
                Err(_) => {
                    tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
                }
            };
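
            // Note: `hostname.parse()` only accepts IP literals such as
            // "127.0.0.1" or "::"; names like "localhost" are not resolved and
            // trigger the 0.0.0.0 fallback above.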

            // Run server
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_prefill_tokens,
                max_supported_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
                ngrok,
                ngrok_authtoken,
                ngrok_edge,
            )
            .await?;
            Ok(())
        })
}

/// Init logging using the LOG_LEVEL env variable and CLI arguments:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - json_output switches the stdout layer from text to flattened JSON
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
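
// LOG_LEVEL accepts the standard tracing_subscriber `EnvFilter` directive
// syntax, so per-target levels work too, e.g. (illustrative):
//   LOG_LEVEL=warn,text_generation_router=debug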

/// Get model info from the Hugging Face Hub
pub async fn get_model_info(
    model_id: &str,
    revision: Option<String>,
    token: Option<String>,
) -> Option<HubModelInfo> {
    let revision = match revision {
        None => {
            tracing::warn!("`--revision` is not set");
            tracing::warn!("We strongly advise setting it to a known supported commit.");
            "main".to_string()
        }
        Some(revision) => revision,
    };

    let client = reqwest::Client::new();
    // Poor man's urlencode
    let revision = revision.replace('/', "%2F");
    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
    let mut builder = client.get(url).timeout(Duration::from_secs(5));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let response = builder.send().await.ok()?;

    if response.status().is_success() {
        let hub_model_info: HubModelInfo =
            serde_json::from_str(&response.text().await.ok()?).ok()?;
        if let Some(sha) = &hub_model_info.sha {
            tracing::info!(
                "Serving revision {sha} of model {}",
                hub_model_info.model_id
            );
        }
        Some(hub_model_info)
    } else {
        None
    }
}
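
// Any connection error, non-success status, or JSON decoding failure above
// yields `None`; the caller in `main` then falls back to a minimal
// `HubModelInfo` with only the model id set.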

#[derive(Debug, Error)]
enum RouterError {
    #[error("Argument validation error: {0}")]
    ArgumentValidation(String),
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    #[error("Unable to clear the Python model shards cache: {0}")]
    Cache(ClientError),
    #[error("Unable to get the Python model shards info: {0}")]
    Info(ClientError),
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
    #[error("Axum webserver failed: {0}")]
    Axum(#[from] axum::BoxError),
}