#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
    LOG("\n    %s -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]\n", argv[0]);
    LOG("\n");
}

int main(int argc, char ** argv) {
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
        return 1;
    }

    common_init();

    int is_pp_shared = params.is_pp_shared;

    std::vector<int> n_pp = params.n_pp;
    std::vector<int> n_tg = params.n_tg;
    std::vector<int> n_pl = params.n_pl;

    // init LLM

    llama_backend_init();
    llama_numa_init(params.numa);

    // initialize the model

    llama_model_params model_params = common_model_params_to_llama(params);

    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    llama_context_params ctx_params = common_context_params_to_llama(params);

    // ensure enough sequences are available
    ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end());

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);

    if (ctx == NULL) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
        return 1;
    }

    const int32_t n_kv_max = llama_n_ctx(ctx);

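    // a single batch buffer, large enough to hold one token per KV cache cell, reused for every run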
    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);

    // decode in batches of ctx_params.n_batch tokens
    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

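            // create a view covering the next n_tokens tokens of the batch
            // (fields: n_tokens, token, embd = nullptr, pos, n_seq_id, seq_id, logits)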
            llama_batch batch_view = {
                n_tokens,
                batch.token    + i,
                nullptr,
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.logits   + i,
            };

            const int ret = llama_decode(ctx, batch_view);
            if (ret != 0) {
                LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
                return false;
            }

            llama_synchronize(ctx);
        }

        return true;
    };

    // warm up
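    // decode a few dummy tokens so that one-time initialization does not affect the measurements below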
    {
        for (int i = 0; i < 16; ++i) {
            common_batch_add(batch, 0, i, { 0 }, false);
        }

        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }
    }

    if (!params.batched_bench_output_jsonl) {
        LOG("\n");
        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
        LOG("\n");
        LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
        LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
    }

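    // benchmark every combination of prompt length (PP), generated token count (TG) and number of parallel sequences (PL)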
    for (        int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
        for (    int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
            for (int i_pl = 0; i_pl < (int) n_pl.size(); ++i_pl) {
                const int pp = n_pp[i_pp];
                const int tg = n_tg[i_tg];
                const int pl = n_pl[i_pl];

                const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);

                if (n_ctx_req > n_kv_max) {
                    continue;
                }

                common_batch_clear(batch);

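                // fill the batch with pp prompt tokens - with a shared prompt only sequence 0 is populated, otherwise one copy per sequence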
                for (int i = 0; i < pp; ++i) {
                    for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
                        common_batch_add(batch, 0, i, { j }, false);
                    }
                }
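                // request logits only for the last token of the prompt batch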
                batch.logits[batch.n_tokens - 1] = true;

                const auto t_pp_start = ggml_time_us();

                llama_kv_cache_clear(ctx);

                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
                    LOG_ERR("%s: llama_decode() failed\n", __func__);
                    return 1;
                }

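                // the shared prompt was decoded only into sequence 0 - copy its KV cache entries to the remaining sequences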
                if (is_pp_shared) {
                    for (int32_t i = 1; i < pl; ++i) {
                        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
                    }
                }

                const auto t_pp_end = ggml_time_us();

                const auto t_tg_start = ggml_time_us();

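                // text generation: each iteration decodes one new token for every parallel sequence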
                for (int i = 0; i < tg; ++i) {
                    common_batch_clear(batch);

                    for (int j = 0; j < pl; ++j) {
                        common_batch_add(batch, 0, pp + i, { j }, true);
                    }

                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
                        LOG_ERR("%s: llama_decode() failed\n", __func__);
                        return 1;
                    }
                }

                const auto t_tg_end = ggml_time_us();

                const int32_t n_kv = n_ctx_req;

                const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
                const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
                const float t    = t_pp + t_tg;

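                // throughput: prompt tokens over prompt time, generated tokens over generation time, total processed tokens over total time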
                const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
                const float speed_tg = pl*tg / t_tg;
                const float speed    = n_kv / t;

                if (params.batched_bench_output_jsonl) {
                    LOG(
                        "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"is_pp_shared\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, "
                        "\"pp\": %d, \"tg\": %d, \"pl\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f, \"t\": %f, \"speed\": %f}\n",
                        n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch,
                        pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed
                    );
                } else {
                    LOG("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
                }
            }
        }
    }

    LOG("\n");
    llama_perf_context_print(ctx);

    llama_batch_free(batch);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    LOG("\n\n");

    return 0;
}