/// Text Generation Inference benchmarking tool
///
/// Inspired by the great Oha app: https://github.com/hatoo/oha
/// and: https://github.com/orhun/rust-tui-template
use clap::Parser;
use std::path::Path;
use text_generation_client::ShardedClient;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// The name of the tokenizer (as in model_id on the huggingface hub, or local path).
    #[clap(short, long, env)]
    tokenizer_name: String,

    /// The revision to use for the tokenizer if on the hub.
    #[clap(default_value = "main", long, env)]
    revision: String,

    /// The various batch sizes to benchmark. The idea is to get enough
    /// batching to start seeing increased latency: this usually means you're
    /// moving from memory bound (typical at BS=1) to compute bound, which is
    /// a sweet spot for the maximum batch size of the model under test.
    /// Pass the flag multiple times to benchmark several batch sizes (see the
    /// example invocation after this struct).
    #[clap(short, long)]
    batch_size: Option<Vec<u32>>,

    /// The length, in tokens, of the initial prompt sent to the
    /// text-generation-server. A longer prompt will slow down the benchmark;
    /// prefill latency usually grows somewhat linearly with this value.
    ///
    /// Most importantly, the prefill step is usually not the one dominating
    /// your runtime, so it's ok to keep it short.
    #[clap(default_value = "10", short, long, env)]
    sequence_length: u32,

    /// How many tokens the server will generate, averaged out to give the
    /// `decode` latency. This is the *critical* number to optimize for, as
    /// LLMs spend most of their time decoding.
    ///
    /// Decode latency is usually quite stable.
    #[clap(default_value = "8", short, long, env)]
    decode_length: u32,

    /// How many runs should we average over
    #[clap(default_value = "10", short, long, env)]
    runs: usize,

    /// Number of warmup cycles
    #[clap(default_value = "1", short, long, env)]
    warmups: usize,

    /// The location of the gRPC socket. This benchmark tool bypasses the router
    /// completely and talks directly to the gRPC processes.
    #[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
    master_shard_uds_path: String,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    temperature: Option<f32>,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    top_k: Option<u32>,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    top_p: Option<f32>,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    typical_p: Option<f32>,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    repetition_penalty: Option<f32>,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    frequency_penalty: Option<f32>,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    watermark: bool,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    do_sample: bool,

    /// Generation parameter in case you want to specifically test/debug particular
    /// decoding strategies; for the full documentation, refer to the `text-generation-server`.
    #[clap(long, env)]
    top_n_tokens: Option<u32>,
}
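
// Example invocation (an illustrative sketch, not part of the original source):
// the binary name `text-generation-benchmark` and the model id are assumptions;
// each flag mirrors a field of `Args` via clap, and `--batch-size` may be
// repeated to fill the `Vec<u32>`.
//
//   text-generation-benchmark \
//     --tokenizer-name bigscience/bloom-560m \
//     --batch-size 1 --batch-size 2 --batch-size 4 \
//     --sequence-length 10 \
//     --decode-length 8 \
//     --master-shard-uds-path /tmp/text-generation-server-0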

fn main() -> Result<(), Box<dyn std::error::Error>> {
    init_logging();

    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        tokenizer_name,
        revision,
        batch_size,
        sequence_length,
        decode_length,
        runs,
        warmups,
        temperature,
        top_k,
        top_p,
        typical_p,
        repetition_penalty,
        frequency_penalty,
        watermark,
        do_sample,
        master_shard_uds_path,
        top_n_tokens,
    } = args;

    let batch_size = batch_size.unwrap_or_else(|| vec![1, 2, 4, 8, 16, 32]);

    // Tokenizer instance
    // This will only be used to validate payloads
    tracing::info!("Loading tokenizer");
    let local_path = Path::new(&tokenizer_name);
    let tokenizer =
        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
        {
            // Load local tokenizer
            tracing::info!("Found local tokenizer");
            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
        } else {
            tracing::info!("Downloading tokenizer");

            // Parse Huggingface hub token
            let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

            // Download and instantiate tokenizer
            // We need to download it outside of the Tokio runtime
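            // (likely because `Tokenizer::from_pretrained` performs blocking
            // network I/O, which should not run on the async runtime's threads)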
            let params = FromPretrainedParameters {
                revision,
                auth_token,
                ..Default::default()
            };
            Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
        };
    tracing::info!("Tokenizer loaded");

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            // Instantiate sharded client from the master unix socket
            tracing::info!("Connect to model server");
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .expect("Unable to clear cache");
            tracing::info!("Connected");

            // Run app
            text_generation_benchmark::run(
                tokenizer_name,
                tokenizer,
                batch_size,
                sequence_length,
                decode_length,
                top_n_tokens,
                runs,
                warmups,
                temperature,
                top_k,
                top_p,
                typical_p,
                repetition_penalty,
                frequency_penalty,
                watermark,
                do_sample,
                sharded_client,
            )
            .await
            .unwrap();
        });
    Ok(())
}

/// Init logging using LOG_LEVEL
fn init_logging() {
    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    // Filter events with LOG_LEVEL
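    // e.g. `LOG_LEVEL=debug` or `LOG_LEVEL=info,text_generation_client=debug`
    // (standard tracing_subscriber EnvFilter directive syntax); defaults to "info"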
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .init();
}