// main.rs
/// Text Generation Inference benchmarking tool
///
/// Inspired by the great Oha app: https://github.com/hatoo/oha
/// and: https://github.com/orhun/rust-tui-template
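///
/// Example invocation (binary name assumed from the crate name; check Cargo.toml):
///   text-generation-benchmark --tokenizer-name bigscience/bloom-560m --batch-size 1 --batch-size 8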
use clap::Parser;
use std::path::Path;
use text_generation_client::ShardedClient;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(short, long, env)]
    tokenizer_name: String,
    #[clap(default_value = "main", long, env)]
    revision: String,
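    /// Batch size(s) to benchmark; the flag may be repeated (e.g. -b 1 -b 8)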
    #[clap(short, long)]
    batch_size: Option<Vec<u32>>,
    #[clap(default_value = "10", short, long, env)]
    sequence_length: u32,
    #[clap(default_value = "8", short, long, env)]
    decode_length: u32,
    #[clap(default_value = "10", short, long, env)]
    runs: usize,
    #[clap(default_value = "1", short, long, env)]
    warmups: usize,
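
    // Optional sampling parameters, passed through to the benchmark runner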
    #[clap(long, env)]
    temperature: Option<f32>,
    #[clap(long, env)]
    top_k: Option<u32>,
    #[clap(long, env)]
    top_p: Option<f32>,
    #[clap(long, env)]
    typical_p: Option<f32>,
    #[clap(long, env)]
    repetition_penalty: Option<f32>,
    #[clap(long, env)]
    watermark: bool,
    #[clap(long, env)]
    do_sample: bool,
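    /// Path to the master shard's unix domain socket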
    #[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
    master_shard_uds_path: String,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    init_logging();

    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        tokenizer_name,
        revision,
        batch_size,
        sequence_length,
        decode_length,
        runs,
        warmups,
        temperature,
        top_k,
        top_p,
        typical_p,
        repetition_penalty,
        watermark,
        do_sample,
        master_shard_uds_path,
    } = args;

    let batch_size = batch_size.unwrap_or_else(|| vec![1, 2, 4, 8, 16, 32]);

    // Tokenizer instance
    // This will only be used to validate payloads
    tracing::info!("Loading tokenizer");
    let local_path = Path::new(&tokenizer_name);
    let tokenizer =
        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
        {
            // Load local tokenizer
            tracing::info!("Found local tokenizer");
            Tokenizer::from_file(local_path.join("tokenizer.json"))
                .expect("Failed to load local tokenizer")
        } else {
            tracing::info!("Downloading tokenizer");

            // Read the Hugging Face Hub token from the env, if set (needed for gated/private models)
            let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

            // Download and instantiate tokenizer
            // We need to download it outside of the Tokio runtime
            let params = FromPretrainedParameters {
                revision,
                auth_token,
                ..Default::default()
            };
            Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params))
                .expect("Failed to download tokenizer")
        };
    tracing::info!("Tokenizer loaded");

    // Launch Tokio runtime
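    // (built only now: the tokenizer download above must happen outside the runtime)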
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .expect("Failed to build Tokio runtime")
        .block_on(async {
            // Instantiate sharded client from the master unix socket
            tracing::info!("Connect to model server");
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");

            // Run app
            text_generation_benchmark::run(
                tokenizer_name,
                tokenizer,
                batch_size,
                sequence_length,
                decode_length,
                runs,
                warmups,
                temperature,
                top_k,
                top_p,
                typical_p,
                repetition_penalty,
                watermark,
                do_sample,
                sharded_client,
            )
            .await
            .expect("Benchmark failed");
        });
    Ok(())
}

/// Init logging using the LOG_LEVEL env var (e.g. LOG_LEVEL=debug); defaults to "info"
fn init_logging() {
    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .init();
}