"vscode:/vscode.git/clone" did not exist on "14cb544d56b06b25483c4cf9c817b657acff8604"
queue.rs 22.3 KB
Newer Older
1
2
use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
    ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
}

/// Request Queue
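/// Cloneable handle to a single background task that owns the queue state: each
/// method forwards a `QueueCommand` over an unbounded channel (a small actor
/// pattern), so no lock is needed around the shared state.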
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            window_size,
            speculate,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    // Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}
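
// Usage sketch (illustrative only; the real driver is the batching task):
//
//     let queue = Queue::new(false, 16, None, 0);
//     queue.append(entry);
//     if let Some((entries, batch, span)) = queue
//         .next_batch(None, Some(32), 4096, 8192)
//         .await
//     {
//         // send `batch` to the model, keep `entries` to route the responses
//     }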

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(requires_padding, block_size, window_size, speculate);

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::gauge!("tgi_queue_size").increment(1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => span.in_scope(|| {
                let next_batch =
                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
            }),
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Whether the model is using padding
    requires_padding: bool,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            requires_padding,
            block_size,
            window_size,
            speculate,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    // Get the next batch
    fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("No queue");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        if let Some(max_size) = max_size {
            if max_size == 0 {
                tracing::debug!("No capacity");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
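        // (ceil-to-multiple: e.g. with block_size = 16, a budget of 100 rounds up to 112)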

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(Span::current());

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;

        // Pop entries starting from the front of the queue
        while let Some((id, mut entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
                tracing::debug!("Dropping entry");
                continue;
            }

            if self.requires_padding {
                // We pad to max input length in the Python shards
                // We need to take these padding tokens into the equation
                max_input_length = max_input_length.max(entry.request.input_length);
                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
            } else {
                // pad to block size
                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
                    / self.block_size)
                    * self.block_size;
            }

            if self.requires_padding {
                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
            } else {
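                // A sliding window keeps at most `window_size` tokens of KV cache
                // per request, so only the decode tokens that still fit in the
                // window after the prompt need to be budgeted.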
                let max_new_tokens = match self.window_size {
                    None => entry.request.stopping_parameters.max_new_tokens,
                    Some(window_size) => min(
                        window_size.saturating_sub(entry.request.input_length),
                        entry.request.stopping_parameters.max_new_tokens,
                    ),
                };

                // pad to block size
                decode_tokens +=
                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
            }

            if prefill_tokens > prefill_token_budget
                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
            {
                // Entry is over budget
                // Add it back to the front
                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                self.entries.push_front((id, entry));
                break;
            }

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Check if max_size has been reached
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

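/// A ready batch: the entries keyed by request id, the `Batch` to send to the
/// model, and the tracing span that ties the two together.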
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
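    // `Entry` is boxed, presumably to keep this variant (and thus every channel
    // message) small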
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

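// Conversions from the router's validated request types to the client-side
// types sent to the model server.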
impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tracing::info_span;

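    // Returns the entry together with the response receiver: tests keep the
    // receiver alive as `_guard`, because `next_batch` filters out entries whose
    // receiver was dropped.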
    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_ids: Some(Arc::new(vec![])),
                input_length: 0,
                add_special_tokens: true,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    max_total_new_tokens: 1024,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
        };
        (entry, receiver_tx)
    }

    #[test]
    fn test_append() {
Nicolas Patry's avatar
Nicolas Patry committed
456
        let mut state = State::new(false, 1, None, 0);
457
        let (entry, _guard) = default_entry();
458
459
460
461
462
463
464
465

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[test]
    fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0);

        assert!(state.next_batch(None, None, 1, 1).is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
    }

    #[test]
    fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[test]
    fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[test]
    fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
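        // (each entry needs decode_tokens 1 + speculate 2, which already exceeds 1)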
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}