main.rs 2.79 KB
Newer Older
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
1
//! Text Generation Inference webserver entrypoint
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
2
use bloom_inference_client::ShardedClient;
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
3
use clap::Parser;
Olivier Dehaene's avatar
Olivier Dehaene committed
4
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
Olivier Dehaene's avatar
Olivier Dehaene committed
5
use text_generation_router::server;
Olivier Dehaene's avatar
Olivier Dehaene committed
6
use tokenizers::Tokenizer;
Olivier Dehaene's avatar
Olivier Dehaene committed
7
8
9
10
11

/// App Configuration
//
// Every field is a command-line option that can also be supplied through an
// environment variable (`long` + `env`), and falls back to the `default_value`
// shown in its attribute. Field notes below are plain `//` comments on purpose:
// clap's derive turns `///` doc comments into help text, which would change the
// `--help` output.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Forwarded as-is to `server::run` by `main`.
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    // Forwarded as-is to `server::run` by `main`.
    #[clap(default_value = "1000", long, env)]
    max_input_length: usize,
    // Forwarded as-is to `server::run` by `main`.
    #[clap(default_value = "32", long, env)]
    max_batch_size: usize,
    // Forwarded as-is to `server::run` by `main`.
    #[clap(default_value = "20", long, env)]
    max_waiting_tokens: usize,
    // TCP port of the socket address `main` hands to the server; the only
    // option that also has a short flag.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    // Path of the unix domain socket `main` uses to connect to the master shard.
    #[clap(default_value = "/tmp/bloom-inference-0", long, env)]
    master_shard_uds_path: String,
    // Identifier passed to `Tokenizer::from_pretrained` in `main`.
    #[clap(default_value = "bigscience/bloom", long, env)]
    tokenizer_name: String,
    // `main` requires this to be greater than zero before forwarding it to
    // `server::run`.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
}
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
29

Olivier Dehaene's avatar
Olivier Dehaene committed
30
fn main() -> Result<(), std::io::Error> {
Olivier Dehaene's avatar
Olivier Dehaene committed
31
32
    // Get args
    let args = Args::parse();
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
33
    // Pattern match configuration
Olivier Dehaene's avatar
Olivier Dehaene committed
34
    let Args {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
35
36
        max_concurrent_requests,
        max_input_length,
Olivier Dehaene's avatar
Olivier Dehaene committed
37
        max_batch_size,
38
        max_waiting_tokens,
Olivier Dehaene's avatar
Olivier Dehaene committed
39
        port,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
40
        master_shard_uds_path,
Olivier Dehaene's avatar
Olivier Dehaene committed
41
        tokenizer_name,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
42
        validation_workers,
Olivier Dehaene's avatar
Olivier Dehaene committed
43
44
    } = args;

45
46
    tracing_subscriber::fmt().compact().with_ansi(false).init();

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
47
48
49
50
51
52
53
54
    if validation_workers == 1 {
        panic!("validation_workers must be > 0");
    }

    // Download and instantiate tokenizer
    // This will only be used to validate payloads
    //
    // We need to download it outside of the Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
55
    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
Olivier Dehaene's avatar
Olivier Dehaene committed
56

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
57
    // Launch Tokio runtime
Olivier Dehaene's avatar
Olivier Dehaene committed
58
59
60
61
62
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
63
64
            // Instantiate sharded client from the master unix socket
            let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
Olivier Dehaene's avatar
Olivier Dehaene committed
65
66
                .await
                .expect("Could not connect to server");
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
67
            // Clear the cache; useful if the webserver rebooted
Olivier Dehaene's avatar
Olivier Dehaene committed
68
69
70
71
72
            sharded_client
                .clear_cache()
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
73

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
74
            // Binds on localhost
Olivier Dehaene's avatar
Olivier Dehaene committed
75
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
76

Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
77
78
79
80
81
            // Run server
            server::run(
                max_concurrent_requests,
                max_input_length,
                max_batch_size,
82
                max_waiting_tokens,
Olivier Dehaene's avatar
v0.1.0  
Olivier Dehaene committed
83
84
85
86
87
88
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
            )
            .await;
Olivier Dehaene's avatar
Olivier Dehaene committed
89
            Ok(())
Olivier Dehaene's avatar
Olivier Dehaene committed
90
        })
Olivier Dehaene's avatar
Init  
Olivier Dehaene committed
91
}