"vscode:/vscode.git/clone" did not exist on "3e3fe72299980f53262880e24e372ed7d785093c"
#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <random>
#include <set>
#include <string>
#include <vector>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

// Per-branch state of the draft tree used by main().
// One instance exists for every parallel draft sequence (tree branch).
struct seq_draft {
    bool active   = false; // branch is still a candidate during verification against the target model
    bool drafting = false; // branch is still being extended with new draft tokens
    bool skip     = false; // set on a branch created by a split so it is not sampled again in the same draft step

    int i_batch_dft = 0;          // index into the draft batch of the last token added for this branch
    std::vector<int> i_batch_tgt; // for each drafted token, its index into the target batch

    std::vector<llama_token> tokens;                   // drafted tokens of this branch
    std::vector<std::vector<llama_token_data>> dists;  // draft candidate distribution saved for each drafted token

    // sampler used to draft this branch; created, cloned and freed by main()
    struct common_sampler * smpl = nullptr;
};

int main(int argc, char ** argv) {
xuxzh1's avatar
update  
xuxzh1 committed
33
34
35
36
    common_params params;

    // needed to get candidate probs even for temp <= 0.0
    params.sampling.n_probs = 128;
xuxzh1's avatar
init  
xuxzh1 committed
37

xuxzh1's avatar
update  
xuxzh1 committed
38
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
xuxzh1's avatar
init  
xuxzh1 committed
39
40
41
        return 1;
    }

xuxzh1's avatar
update  
xuxzh1 committed
42
43
44
45
46
47
48
49
50
    if (params.n_predict < -1) {
        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
        return 1;
    }

    common_init();

    if (params.speculative.model.empty()) {
        LOG_ERR("%s: --model-draft is required\n", __func__);
xuxzh1's avatar
init  
xuxzh1 committed
51
52
53
54
55
56
57
        return 1;
    }

    // max number of parallel drafting sequences (i.e. tree branches)
    const int n_seq_dft = params.n_parallel;

    // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
xuxzh1's avatar
update  
xuxzh1 committed
58
    const float p_draft_split = params.speculative.p_split;
xuxzh1's avatar
init  
xuxzh1 committed
59

xuxzh1's avatar
update  
xuxzh1 committed
60
    std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
xuxzh1's avatar
init  
xuxzh1 committed
61
62
63
64
65
66
67
68
69
70
71
72
73
    std::uniform_real_distribution<> u_dist;

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model_tgt = NULL;
    llama_model * model_dft = NULL;

    llama_context * ctx_tgt = NULL;
    llama_context * ctx_dft = NULL;

    // load the target model
xuxzh1's avatar
update  
xuxzh1 committed
74
    common_init_result llama_init_tgt = common_init_from_params(params);
xuxzh1's avatar
init  
xuxzh1 committed
75
76
77
78
    model_tgt = llama_init_tgt.model;
    ctx_tgt = llama_init_tgt.context;

    // load the draft model
xuxzh1's avatar
update  
xuxzh1 committed
79
80
81
82
    params.model = params.speculative.model;
    params.n_gpu_layers = params.speculative.n_gpu_layers;
    if (params.speculative.cpuparams.n_threads > 0) {
        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
xuxzh1's avatar
init  
xuxzh1 committed
83
    }
xuxzh1's avatar
update  
xuxzh1 committed
84
85
86

    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
    common_init_result llama_init_dft = common_init_from_params(params);
xuxzh1's avatar
init  
xuxzh1 committed
87
88
89
90
    model_dft = llama_init_dft.model;
    ctx_dft = llama_init_dft.context;

    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
xuxzh1's avatar
update  
xuxzh1 committed
91
    LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
xuxzh1's avatar
init  
xuxzh1 committed
92
93

    const bool vocab_type_dft = llama_vocab_type(model_dft);
xuxzh1's avatar
update  
xuxzh1 committed
94
    LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
xuxzh1's avatar
init  
xuxzh1 committed
95
96

    if (vocab_type_tgt != vocab_type_dft) {
xuxzh1's avatar
update  
xuxzh1 committed
97
98
        LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
        LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
xuxzh1's avatar
init  
xuxzh1 committed
99
100
101
102
103
104
105
106
107
        return 1;
    }

    if (
        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
    ) {
xuxzh1's avatar
update  
xuxzh1 committed
108
        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
xuxzh1's avatar
init  
xuxzh1 committed
109
110
111
112
113
114
115
116
117
118
119
        return 1;
    }

    {
        const int n_vocab_tgt = llama_n_vocab(model_tgt);
        const int n_vocab_dft = llama_n_vocab(model_dft);
        const int vocab_diff  = n_vocab_tgt > n_vocab_dft
            ? n_vocab_tgt - n_vocab_dft
            : n_vocab_dft - n_vocab_tgt;

        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
xuxzh1's avatar
update  
xuxzh1 committed
120
121
            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
            LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
xuxzh1's avatar
init  
xuxzh1 committed
122
123
124
125
126
127
128
129
                    n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
            return 1;
        }

        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
            const char * token_text_dft = llama_token_get_text(model_dft, i);
            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
xuxzh1's avatar
update  
xuxzh1 committed
130
131
132
133
                LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
                LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
                        common_token_to_piece(ctx_tgt, i).c_str(),
                        common_token_to_piece(ctx_dft, i).c_str());
xuxzh1's avatar
init  
xuxzh1 committed
134
135
136
137
138
139
140
141
                return 1;
            }
        }
    }


    // Tokenize the prompt
    std::vector<llama_token> inp;
xuxzh1's avatar
update  
xuxzh1 committed
142
    inp = common_tokenize(ctx_tgt, params.prompt, true, true);
xuxzh1's avatar
init  
xuxzh1 committed
143
144
145
146
147

    const int max_context_size     = llama_n_ctx(ctx_tgt);
    const int max_tokens_list_size = max_context_size - 4;

    if ((int) inp.size() > max_tokens_list_size) {
xuxzh1's avatar
update  
xuxzh1 committed
148
        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
xuxzh1's avatar
init  
xuxzh1 committed
149
150
151
        return 1;
    }

xuxzh1's avatar
update  
xuxzh1 committed
152
    LOG("\n\n");
xuxzh1's avatar
init  
xuxzh1 committed
153
154

    for (auto id : inp) {
xuxzh1's avatar
update  
xuxzh1 committed
155
        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
xuxzh1's avatar
init  
xuxzh1 committed
156
157
158
159
160
161
162
    }

    const int n_input = inp.size();

    const auto t_enc_start = ggml_time_us();

    // eval the prompt with both models
xuxzh1's avatar
update  
xuxzh1 committed
163
164
165
    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(),           1));
    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
xuxzh1's avatar
init  
xuxzh1 committed
166
167
168
169
170
171
172

    const auto t_enc_end = ggml_time_us();

    // the 2 models should have the same vocab
    //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));

    // how many tokens to draft each time
xuxzh1's avatar
update  
xuxzh1 committed
173
    int n_draft = params.speculative.n_max;
xuxzh1's avatar
init  
xuxzh1 committed
174
175
176
177
178
179
180
181
182
183
184

    int n_predict = 0;
    int n_drafted = 0;
    int n_accept  = 0;

    int n_past_tgt = inp.size();
    int n_past_dft = inp.size();

    // used to determine end of generation
    bool has_eos = false;

xuxzh1's avatar
update  
xuxzh1 committed
185
186
    // target model sampling context (reuse the llama_context's sampling instance)
    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
xuxzh1's avatar
init  
xuxzh1 committed
187
188
189
190
191

    // draft sequence data
    std::vector<seq_draft> drafts(n_seq_dft);

    for (int s = 0; s < n_seq_dft; ++s) {
xuxzh1's avatar
update  
xuxzh1 committed
192
193
        // allocate llama_sampler for each draft sequence
        drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
xuxzh1's avatar
init  
xuxzh1 committed
194
195
    }

xuxzh1's avatar
update  
xuxzh1 committed
196
197
    llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
xuxzh1's avatar
init  
xuxzh1 committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216

    const auto t_dec_start = ggml_time_us();

    // sample from the last token of the prompt
    drafts[0].i_batch_tgt.resize(1);
    drafts[0].i_batch_tgt[0] = 0;

    while (true) {
        std::set<int> active_seqs = {};

        // print current draft sequences
        for (int s = 0; s < n_seq_dft; ++s) {
            if (!drafts[s].active) {
                continue;
            }

            active_seqs.insert(s);
            const auto & tokens = drafts[s].tokens;

xuxzh1's avatar
update  
xuxzh1 committed
217
            LOG_DBG("draft %d: %s\n", s, string_from(ctx_dft, tokens).c_str());
xuxzh1's avatar
init  
xuxzh1 committed
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
        }

        int i_dft  = 0;
        int s_keep = 0;

        llama_token token_id;
        std::string token_str;

        // loop until we fail to accept a drafted token or we run out of drafted tokens
        while (true) {

            // check if the target token matches any of the drafts
            // for stochastic sampling, attempt to match the token with the drafted tokens
            {
                bool accept = false;
xuxzh1's avatar
update  
xuxzh1 committed
233
                if (params.sampling.temp > 0) {
xuxzh1's avatar
init  
xuxzh1 committed
234
                    // stochastic verification
xuxzh1's avatar
update  
xuxzh1 committed
235
                    common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
xuxzh1's avatar
init  
xuxzh1 committed
236

xuxzh1's avatar
update  
xuxzh1 committed
237
                    auto & dist_tgt = *common_sampler_get_candidates(smpl);
xuxzh1's avatar
init  
xuxzh1 committed
238

xuxzh1's avatar
update  
xuxzh1 committed
239
240
                    float p_tgt = 0.0f;
                    float p_dft = 0.0f;
xuxzh1's avatar
init  
xuxzh1 committed
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258

                    while (active_seqs.size() > 0) {
                        // randomly select a sequence to verify from active sequences
                        std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
                        int s = *std::next(active_seqs.begin(), u_int_dist(rng));
                        if (i_dft >= (int) drafts[s].tokens.size()) {
                            drafts[s].active = false;
                            active_seqs.erase(s);
                            continue;
                        }
                        if (accept) {
                            // if we already accepted a token, we can skip the rest
                            if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
                                drafts[s].active = false;
                                active_seqs.erase(s);
                            }
                            continue;
                        }
xuxzh1's avatar
update  
xuxzh1 committed
259
260

                        LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
xuxzh1's avatar
init  
xuxzh1 committed
261
                        float r = u_dist(rng);
xuxzh1's avatar
update  
xuxzh1 committed
262
263
264
265
                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };

                        //GGML_ASSERT(dist_tgt.size <= dist_dft.size);

xuxzh1's avatar
init  
xuxzh1 committed
266
267
268
269
                        // acquire the token probabilities assigned by the draft and target models
                        for (size_t i = 0; i < dist_tgt.size; i++) {
                            if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
                                p_tgt = dist_tgt.data[i].p;
xuxzh1's avatar
update  
xuxzh1 committed
270
                                break;
xuxzh1's avatar
init  
xuxzh1 committed
271
                            }
xuxzh1's avatar
update  
xuxzh1 committed
272
273
                        }
                        for (size_t i = 0; i < dist_dft.size; i++) {
xuxzh1's avatar
init  
xuxzh1 committed
274
275
276
277
278
                            if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
                                p_dft = dist_dft.data[i].p;
                                break;
                            }
                        }
xuxzh1's avatar
update  
xuxzh1 committed
279
                        LOG_DBG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
xuxzh1's avatar
init  
xuxzh1 committed
280
281
282
283
                        if (r <= p_tgt / p_dft) {
                            s_keep = s;
                            accept = true;
                            token_id = drafts[s].tokens[i_dft];
xuxzh1's avatar
update  
xuxzh1 committed
284
285
                            token_str = common_token_to_piece(ctx_tgt, token_id);
                            common_sampler_accept(smpl, token_id, true);
xuxzh1's avatar
init  
xuxzh1 committed
286

xuxzh1's avatar
update  
xuxzh1 committed
287
                            LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
xuxzh1's avatar
init  
xuxzh1 committed
288
289
                            break;
                        } else {
xuxzh1's avatar
update  
xuxzh1 committed
290
                            LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
xuxzh1's avatar
init  
xuxzh1 committed
291
292
293
294
295
296
297
298
299
300
301
302
303
304
                            drafts[s].active = false;

                            // calculate residual probability
                            GGML_ASSERT(dist_tgt.sorted);
                            GGML_ASSERT(dist_dft.sorted);

                            // sort dist by id
                            std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
                                return a.id < b.id;
                            });
                            std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
                                return a.id < b.id;
                            });

xuxzh1's avatar
update  
xuxzh1 committed
305
306
                            float sum_probs = 0.0f;

xuxzh1's avatar
init  
xuxzh1 committed
307
                            for (size_t i = 0; i < dist_tgt.size; i++) {
xuxzh1's avatar
update  
xuxzh1 committed
308
309
310
311
312
313
                                if (i < dist_dft.size) {
                                    dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
                                } else {
                                    dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
                                }

xuxzh1's avatar
init  
xuxzh1 committed
314
315
                                sum_probs += dist_tgt.data[i].p;
                            }
xuxzh1's avatar
update  
xuxzh1 committed
316

xuxzh1's avatar
init  
xuxzh1 committed
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
                            for (size_t i = 0; i < dist_tgt.size; i++) {
                                dist_tgt.data[i].p /= sum_probs;
                            }

                            // sort dist_tgt by p desc
                            std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
                                return a.p > b.p;
                            });
                        }

                        active_seqs.erase(s);
                        for(int i = 0; i < n_seq_dft; i++) {
                            if (i == s) {
                                continue;
                            }
                            if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
                                // synchronize active status for sequences with the same drafted token
                                drafts[i].active = drafts[i].active && accept;
                                if (!drafts[i].active) {
                                    active_seqs.erase(s);
                                }
                            }
                        }
                    }

                    if (!accept) {
                        // all drafted tokens were rejected
                        // sample from the target model
xuxzh1's avatar
update  
xuxzh1 committed
345
346
347
348
349
350
351
352
353
                        LOG_DBG("all drafted tokens were rejected, sampling from residual distribution\n");
                        std::vector<float> probs(dist_tgt.size);
                        for (size_t i = 0; i < dist_tgt.size; ++i) {
                            probs[i] = dist_tgt.data[i].p;
                        }

                        std::discrete_distribution<> dist(probs.begin(), probs.end());

                        const int idx = dist(rng);
xuxzh1's avatar
init  
xuxzh1 committed
354

xuxzh1's avatar
update  
xuxzh1 committed
355
356
357
358
                        token_id = dist_tgt.data[idx].id;
                        common_sampler_accept(smpl, token_id, true);
                        token_str = common_token_to_piece(ctx_tgt, token_id);
                    }
xuxzh1's avatar
init  
xuxzh1 committed
359
360
361
362
                } else {
                    // greedy verification

                    // sample from the target model
xuxzh1's avatar
update  
xuxzh1 committed
363
364
                    LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
                    token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
xuxzh1's avatar
init  
xuxzh1 committed
365

xuxzh1's avatar
update  
xuxzh1 committed
366
                    common_sampler_accept(smpl, token_id, true);
xuxzh1's avatar
init  
xuxzh1 committed
367

xuxzh1's avatar
update  
xuxzh1 committed
368
                    token_str = common_token_to_piece(ctx_tgt, token_id);
xuxzh1's avatar
init  
xuxzh1 committed
369
370
371
372
373
374
375

                    for (int s = 0; s < n_seq_dft; ++s) {
                        if (!drafts[s].active) {
                            continue;
                        }

                        if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
xuxzh1's avatar
update  
xuxzh1 committed
376
                            LOG_DBG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
xuxzh1's avatar
init  
xuxzh1 committed
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397

                            s_keep = s;
                            accept = true;
                        } else {
                            drafts[s].active = false;
                        }
                    }
                }

                if (llama_token_is_eog(model_tgt, token_id)) {
                    has_eos = true;
                }
                ++n_predict;

                if (accept) {
                    ++n_accept;
                    ++n_past_tgt;
                    ++n_past_dft;
                    ++i_dft;
                    if (params.use_color) {
                        // Color token according to its origin sequence
xuxzh1's avatar
update  
xuxzh1 committed
398
                        LOG("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
xuxzh1's avatar
init  
xuxzh1 committed
399
                    } else {
xuxzh1's avatar
update  
xuxzh1 committed
400
                        LOG("%s", token_str.c_str());
xuxzh1's avatar
init  
xuxzh1 committed
401
402
403
                    }
                    continue;
                } else {
xuxzh1's avatar
update  
xuxzh1 committed
404
                    LOG("%s", token_str.c_str());
xuxzh1's avatar
init  
xuxzh1 committed
405
406
407
408
409
410
                    break;
                }
            }
        }

        {
xuxzh1's avatar
update  
xuxzh1 committed
411
            LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
xuxzh1's avatar
init  
xuxzh1 committed
412
413
414

            // TODO: simplify
            {
xuxzh1's avatar
update  
xuxzh1 committed
415
                LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
xuxzh1's avatar
init  
xuxzh1 committed
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437

                llama_kv_cache_seq_keep(ctx_dft, s_keep);
                llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
                llama_kv_cache_seq_keep(ctx_dft, 0);

                llama_kv_cache_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
                llama_kv_cache_seq_keep(ctx_tgt, s_keep);
                llama_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
                llama_kv_cache_seq_keep(ctx_tgt, 0);
            }

            for (int s = 0; s < n_seq_dft; ++s) {
                drafts[s].active = false;
                drafts[s].tokens.clear();
                drafts[s].i_batch_tgt.clear();
                drafts[s].dists.clear();
            }
            // note: will be erased after the speculation phase
            drafts[0].tokens.push_back(token_id);
            drafts[0].dists.push_back(std::vector<llama_token_data>());
            drafts[0].i_batch_tgt.push_back(0);

xuxzh1's avatar
update  
xuxzh1 committed
438
439
            common_batch_clear(batch_dft);
            common_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
xuxzh1's avatar
init  
xuxzh1 committed
440
441

            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
xuxzh1's avatar
update  
xuxzh1 committed
442
            // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
xuxzh1's avatar
init  
xuxzh1 committed
443
444
445
446
447
            llama_decode(ctx_dft, batch_dft);

            ++n_past_dft;
        }

xuxzh1's avatar
update  
xuxzh1 committed
448
        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
xuxzh1's avatar
init  
xuxzh1 committed
449
450
451
            break;
        }

xuxzh1's avatar
update  
xuxzh1 committed
452
453
454
455
        if (drafts[0].smpl) {
            common_sampler_free(drafts[0].smpl);
        }
        drafts[0].smpl = common_sampler_clone(smpl);
xuxzh1's avatar
init  
xuxzh1 committed
456
457
458
459
460
461
462
463
464
465
466
467

        int n_seq_cur  = 1;
        int n_past_cur = n_past_dft;

        for (int s = 0; s < n_seq_dft; ++s) {
            drafts[s].active   = false;
            drafts[s].drafting = false;
        }
        drafts[0].active      = true;
        drafts[0].drafting    = true;
        drafts[0].i_batch_dft = 0;

xuxzh1's avatar
update  
xuxzh1 committed
468
469
        common_batch_clear(batch_tgt);
        common_batch_add  (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
xuxzh1's avatar
init  
xuxzh1 committed
470
471
472
473
474
475
476
477
478
479
480
481
482
483

        // sample n_draft tokens from the draft model using tree-based sampling
        for (int i = 0; i < n_draft; ++i) {
            batch_dft.n_tokens = 0;

            for (int s = 0; s < n_seq_dft; ++s) {
                drafts[s].skip = false;
            }

            for (int s = 0; s < n_seq_dft; ++s) {
                if (!drafts[s].drafting || drafts[s].skip) {
                    continue;
                }

xuxzh1's avatar
update  
xuxzh1 committed
484
                common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
xuxzh1's avatar
init  
xuxzh1 committed
485

xuxzh1's avatar
update  
xuxzh1 committed
486
                const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
xuxzh1's avatar
init  
xuxzh1 committed
487

xuxzh1's avatar
update  
xuxzh1 committed
488
489
490
                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
                    LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
                            k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
xuxzh1's avatar
init  
xuxzh1 committed
491
492
493
494
495
496
                }

                std::vector<int> sa(1, s);

                // attempt to split the branch if the probability is high enough
                for (int f = 1; f < 8; ++f) {
xuxzh1's avatar
update  
xuxzh1 committed
497
498
                    if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
                        LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
xuxzh1's avatar
init  
xuxzh1 committed
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523

                        llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
                        llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);

                        // all previous tokens from this branch are now also part of the new branch
                        for (int t = 0; t < batch_tgt.n_tokens; ++t) {
                            for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
                                if (batch_tgt.seq_id[t][p] == s) {
                                    batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
                                    batch_tgt.n_seq_id[t]++;
                                    break;
                                }
                            }
                        }

                        // copy the draft state
                        drafts[n_seq_cur].active   = true;
                        drafts[n_seq_cur].drafting = true;
                        drafts[n_seq_cur].skip     = true;

                        drafts[n_seq_cur].tokens      = drafts[s].tokens;
                        drafts[n_seq_cur].dists       = drafts[s].dists;
                        drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                        drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

xuxzh1's avatar
update  
xuxzh1 committed
524
525
526
527
                        if (drafts[n_seq_cur].smpl) {
                            common_sampler_free(drafts[n_seq_cur].smpl);
                        }
                        drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl);
xuxzh1's avatar
init  
xuxzh1 committed
528
529
530
531
532
533
534
535
536
537
538

                        sa.push_back(n_seq_cur);

                        n_seq_cur++;
                    } else {
                        break;
                    }
                }

                // add drafted token for each sequence
                for (int is = 0; is < (int) sa.size(); ++is) {
xuxzh1's avatar
update  
xuxzh1 committed
539
                    const llama_token id = cur_p->data[is].id;
xuxzh1's avatar
init  
xuxzh1 committed
540
541
542

                    const int s = sa[is];

xuxzh1's avatar
update  
xuxzh1 committed
543
                    common_sampler_accept(drafts[s].smpl, id, true);
xuxzh1's avatar
init  
xuxzh1 committed
544
545
546

                    drafts[s].tokens.push_back(id);
                    // save cur_p.data into drafts[s].dists
xuxzh1's avatar
update  
xuxzh1 committed
547
                    drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
xuxzh1's avatar
init  
xuxzh1 committed
548
549
550
551

                    // add unique drafted tokens to the target batch
                    drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);

xuxzh1's avatar
update  
xuxzh1 committed
552
                    common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
xuxzh1's avatar
init  
xuxzh1 committed
553
554
555
556

                    // add the token to the batch for batched decoding with the draft model
                    drafts[s].i_batch_dft = batch_dft.n_tokens;

xuxzh1's avatar
update  
xuxzh1 committed
557
                    common_batch_add(batch_dft, id, n_past_cur, { s }, true);
xuxzh1's avatar
init  
xuxzh1 committed
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586

                    if (batch_tgt.n_tokens > n_draft) {
                        drafts[s].drafting = false;
                    }
                }
            }

            // no sequence is drafting anymore
            if (batch_dft.n_tokens == 0) {
                break;
            }

            // evaluate the drafted tokens on the draft model
            llama_decode(ctx_dft, batch_dft);
            ++n_past_cur;
            ++n_drafted;

            if (batch_tgt.n_tokens > n_draft) {
                break;
            }
        }

        // evaluate the target model on the drafted tokens
        {
            llama_kv_cache_seq_keep(ctx_tgt, 0);
            for (int s = 1; s < n_seq_dft; ++s) {
                llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
            }

xuxzh1's avatar
update  
xuxzh1 committed
587
            // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
xuxzh1's avatar
init  
xuxzh1 committed
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
            llama_decode(ctx_tgt, batch_tgt);
            ++n_past_tgt;
        }

        // the first token is always proposed by the target model before the speculation loop so we erase it here
        for (int s = 0; s < n_seq_dft; ++s) {
            if (!drafts[s].active) {
                continue;
            }

            drafts[s].tokens.erase(drafts[s].tokens.begin());
            drafts[s].dists.erase(drafts[s].dists.begin());
        }
    }

    auto t_dec_end = ggml_time_us();

xuxzh1's avatar
update  
xuxzh1 committed
605
    LOG("\n\n");
xuxzh1's avatar
init  
xuxzh1 committed
606

xuxzh1's avatar
update  
xuxzh1 committed
607
608
    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
xuxzh1's avatar
init  
xuxzh1 committed
609

xuxzh1's avatar
update  
xuxzh1 committed
610
611
612
613
614
615
    LOG_INF("\n");
    LOG_INF("n_draft   = %d\n", n_draft);
    LOG_INF("n_predict = %d\n", n_predict);
    LOG_INF("n_drafted = %d\n", n_drafted);
    LOG_INF("n_accept  = %d\n", n_accept);
    LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
xuxzh1's avatar
init  
xuxzh1 committed
616

xuxzh1's avatar
update  
xuxzh1 committed
617
618
619
620
    LOG_INF("\n");
    LOG_INF("draft:\n\n");
    // TODO: print sampling/grammar timings for all drafts
    llama_perf_context_print(ctx_dft);
xuxzh1's avatar
init  
xuxzh1 committed
621

xuxzh1's avatar
update  
xuxzh1 committed
622
623
624
    LOG_INF("\n");
    LOG_INF("target:\n\n");
    common_perf_print(ctx_tgt, smpl);
xuxzh1's avatar
init  
xuxzh1 committed
625

xuxzh1's avatar
update  
xuxzh1 committed
626
    common_sampler_free(smpl);
xuxzh1's avatar
init  
xuxzh1 committed
627
    for (int s = 0; s < n_seq_dft; ++s) {
xuxzh1's avatar
update  
xuxzh1 committed
628
        common_sampler_free(drafts[s].smpl);
xuxzh1's avatar
init  
xuxzh1 committed
629
630
631
632
633
634
635
636
637
638
639
640
    }

    llama_batch_free(batch_dft);

    llama_free(ctx_tgt);
    llama_free_model(model_tgt);

    llama_free(ctx_dft);
    llama_free_model(model_dft);

    llama_backend_free();

xuxzh1's avatar
update  
xuxzh1 committed
641
    LOG("\n\n");
xuxzh1's avatar
init  
xuxzh1 committed
642
643
644

    return 0;
}