queue.rs

use crate::infer::{InferError, InferStreamResponse};
use crate::validation::{
    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::v2::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
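    /// Build a queue and spawn its background task.
    /// - `requires_padding`: whether the model pads a batch to its longest input
    /// - `block_size`: paged-attention block granularity used to round token counts
    /// - `window_size`: sliding-window attention limit, if any
    /// - `speculate`: extra tokens per decoding step from speculative decoding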
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            window_size,
            speculate,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    // Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(requires_padding, block_size, window_size, speculate);

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::gauge!("tgi_queue_size").increment(1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => span.in_scope(|| {
                let next_batch =
                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
            }),
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Whether the model is using padding
    requires_padding: bool,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
    ) -> Self {
        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            requires_padding,
            block_size,
            window_size,
            speculate,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    // Get the next batch
    fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("No queue");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
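        // (round up: e.g. block_size = 16 turns a budget of 100 into 112, i.e. 7 blocks)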
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(&Span::current());

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;

        // Pop entries starting from the front of the queue
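        // Each candidate adds its prefill cost and a worst-case decode reservation;
        // stop as soon as either budget would be exceeded.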
        while let Some((id, mut entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
                tracing::debug!("Dropping entry");
                continue;
            }

            if self.requires_padding {
                // We pad to max input length in the Python shards
                // We need to take these padding tokens into account
                max_input_length = max_input_length.max(entry.request.input_length);
                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
            } else {
                // pad to block size
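                // (e.g. input_length = 10 with block_size = 16 reserves one full 16-token block)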
                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
                    / self.block_size)
                    * self.block_size;
            }

            if self.requires_padding {
                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
            } else {
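                // With sliding-window attention only `window_size` tokens stay in the
                // KV cache, so cap the reservation: e.g. window_size = 4096 and
                // input_length = 4000 leaves room for at most 96 new tokens.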
                let max_new_tokens = match self.window_size {
                    None => entry.request.stopping_parameters.max_new_tokens,
                    Some(window_size) => min(
                        window_size.saturating_sub(entry.request.input_length),
                        entry.request.stopping_parameters.max_new_tokens,
                    ),
                };

                // pad to block size
                decode_tokens +=
                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
            }

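            // Speculative decoding can emit up to `speculate` extra tokens per step,
            // so room for them is reserved out of the same token budget.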
            if prefill_tokens > prefill_token_budget
                || (prefill_tokens + decode_tokens + self.speculate) > token_budget
            {
                // Entry is over budget
                // Add it back to the front
                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                self.entries.push_front((id, entry));
                break;
            }

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Check if max_size
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

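/// A cut batch: its entries keyed by request id, the protobuf `Batch`, and the batch span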
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

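// Conversions from the router's validated parameter types to the protobuf types
// sent to the model server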
impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

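// A minimal sketch of how a caller is expected to drive the queue (hypothetical
// driver; the real batching loop lives elsewhere in the router):
//
//     let queue = Queue::new(false, 1, None, 0);
//     queue.append(entry);
//     if let Some((entries, batch, _span)) = queue.next_batch(None, Some(32), 4096, 8192).await {
//         // ship `batch` to the model server; `entries` maps request id -> Entry
//     }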
#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

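    // Builds a minimal entry plus the response receiver; tests keep the receiver
    // alive as a guard so `response_tx.is_closed()` stays false.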
    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_length: 0,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
        };
        (entry, receiver_tx)
    }

    #[test]
    fn test_append() {
        let mut state = State::new(false, 1, None, 0);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[test]
    fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0);

        assert!(state.next_batch(None, None, 1, 1).is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
    }

    #[test]
    fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[test]
    fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[test]
    fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
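        // (even a single entry needs 1 decode token plus 2 speculated tokens = 3 > 1)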
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0);
        let (entry, _) = default_entry();
        queue.append(entry);

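        // The receiver was bound to `_` and dropped immediately, so the entry is
        // filtered out and no batch is formed.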
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}