use crate::infer::{InferError, InferStreamResponse};
use crate::validation::{
    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::v3::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::{ChunksToString, Input};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            window_size,
            speculate,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    // Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}
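// A minimal usage sketch of the API above (illustrative only: the argument
// values are made up, and the real wiring lives in the `Infer` struct and
// the batching task that own the other side of these channels):
//
//     let queue = Queue::new(false, 16, None, 0);
//     queue.append(entry);
//     if let Some((entries, batch, span)) = queue
//         .next_batch(None, Some(32), 4096, 8192)
//         .await
//     {
//         // forward `batch` to the model shards, keep `entries` to
//         // route the streamed responses back
//     }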

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(requires_padding, block_size, window_size, speculate);

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::increment_gauge!("tgi_queue_size", 1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => span.in_scope(|| {
                let next_batch =
                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
            }),
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Whether the model is using padding
    requires_padding: bool,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            requires_padding,
            block_size,
            window_size,
            speculate,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    // Get the next batch
    fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("No queue");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
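        // e.g. with block_size = 16, a budget of 100 rounds up to
        // ((100 + 16 - 1) / 16) * 16 = 112 (integer division).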

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(&Span::current());

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;

        // Pop entries starting from the front of the queue
        while let Some((id, mut entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
                tracing::debug!("Dropping entry");
                continue;
            }

            if self.requires_padding {
                // We pad to max input length in the Python shards
                // We need to take these padding tokens into account
                max_input_length = max_input_length.max(entry.request.input_length);
                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
            } else {
                // pad to block size
                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
                    / self.block_size)
                    * self.block_size;
            }
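            // (Illustration: in padding mode, 3 requests with
            // max_input_length = 50 count as 3 * 50 = 150 prefill tokens;
            // in paged mode, input_length = 10 with block_size = 16
            // rounds up to one 16-token block.)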

            if self.requires_padding {
                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
            } else {
                let max_new_tokens = match self.window_size {
                    None => entry.request.stopping_parameters.max_new_tokens,
                    Some(window_size) => min(
                        window_size.saturating_sub(entry.request.input_length),
                        entry.request.stopping_parameters.max_new_tokens,
                    ),
                };
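                // e.g. window_size = 1024, input_length = 1000,
                // max_new_tokens = 100 -> min(1024 - 1000, 100) = 24
                // decode tokens budgeted for this entry.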

                // pad to block size
                decode_tokens +=
                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
            }

            if prefill_tokens > prefill_token_budget
                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
            {
                // Entry is over budget
                // Add it back to the front
                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                self.entries.push_front((id, entry));
                break;
            }
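            // (Illustration: with speculate = 2, an entry needing 1 prefill
            // and 1 decode token must also fit 2 speculative tokens,
            // i.e. 1 + 1 + 2 = 4 <= token_budget.)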

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                inputs: entry.request.inputs.chunks_to_string(),
                input_chunks: Some(Input {
                    chunks: entry.request.inputs.clone(),
                }),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Check if max_size is reached
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }
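        // (The `rev()` above restores FIFO order: entries popped as
        // 0, 1, 2 are pushed back front-first as 2, 1, 0, so the queue
        // reads [0, 1, 2] again.)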

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size", batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_length: 0,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
        };
        (entry, receiver_tx)
    }

    #[test]
    fn test_append() {
        let mut state = State::new(false, 1, None, 0);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[test]
    fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0);

        assert!(state.next_batch(None, None, 1, 1).is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
    }

    #[test]
    fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[test]
    fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[test]
    fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}