use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::{
    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_client::v3::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
    /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            window_size,
            speculate,
            max_batch_total_tokens,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    /// Append an entry to the queue
    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    /// Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}
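
// A minimal usage sketch (hypothetical values, assuming a flash-attention
// model, hence `requires_padding = false`):
//
//     let queue = Queue::new(false, 16, None, 0, 4096);
//     queue.append(entry);
//     if let Some((entries, batch, _span)) = queue.next_batch(None, Some(32), 1024, 4096).await {
//         // ship `batch` to the model server; keep `entries` to route streamed responses
//     }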

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
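    // All queue state is owned by this single task: `Queue` handles are just
    // clones of the command sender, so no lock is needed around `State`.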
    let mut state = State::new(
        requires_padding,
        block_size,
        window_size,
        speculate,
        max_batch_total_tokens,
    );

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::gauge!("tgi_queue_size").increment(1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => {
                let next_batch = state
                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                    .instrument(span)
                    .await;
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
            }
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
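        // Padded (non-flash) models size tensors to the longest input in the
        // batch, so paged KV accounting does not apply; only models that do
        // not require padding get a block allocator.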
        let block_allocator = (!requires_padding)
            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));

        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            block_size,
            window_size,
            speculate,
            block_allocator,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    /// Get the next batch
    async fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("Queue is empty");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
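        // e.g. with block_size = 16, a budget of 70 becomes
        // ((70 + 15) / 16) * 16 = 80 (integer division makes this a ceiling)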

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(&Span::current());

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;
        let mut max_blocks = 0;

        // Pop entries starting from the front of the queue
        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
                tracing::debug!("Dropping entry");
                continue;
            }

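            // Token accounting differs per backend: without a block allocator
            // (padded models) every request is billed at the batch-wide max
            // input length; with one (paged attention) exact input lengths are
            // used and KV blocks must also be reserved.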
            let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into account
                    max_input_length = max_input_length.max(entry.request.input_length);
                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;

                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;

                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
                    None
                }
                Some(block_allocator) => {
                    prefill_tokens += entry.request.input_length;
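                    // With a sliding window, at most `window_size` tokens of KV
                    // are ever kept, so cap the decode reservation at what
                    // still fits in the window after the prompt.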
                    let max_new_tokens = match self.window_size {
                        None => entry.request.stopping_parameters.max_new_tokens,
                        Some(window_size) => min(
                            window_size.saturating_sub(entry.request.input_length),
                            entry.request.stopping_parameters.max_new_tokens,
                        ),
                    };
                    decode_tokens += max_new_tokens;

                    if prefill_tokens > prefill_token_budget
                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
                    {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }

                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;
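                    // `tokens` counts the KV slots to reserve; the `- 1` is
                    // presumably because the final token is returned to the
                    // client but never fed back through the model.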

                    match block_allocator.allocate(tokens).await {
                        None => {
                            // Entry is over budget
                            // Add it back to the front
                            tracing::debug!("Over budget: not enough free blocks");
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                        Some(block_allocation) => {
                            tracing::debug!("Allocation: {block_allocation:?}");
                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                            Some(block_allocation)
                        }
                    }
                }
            };

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            let (blocks, slots) = match &block_allocation {
                None => (Vec::new(), Vec::new()),
                Some(block_allocation) => (
                    block_allocation.blocks.clone(),
                    block_allocation.slots.clone(),
                ),
            };

            entry.block_allocation = block_allocation;

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                input_chunks: Some(Input {
                    chunks: entry.request.inputs.clone(),
                }),
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
                adapter_id: entry.request.adapter_id.clone(),
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Check if max_size
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
            max_blocks,
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

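/// What `next_batch` hands to the batching task: entries keyed by request id,
/// the protobuf `Batch`, and the span that traces the batch's lifetime.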
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, response_rx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_length: 0,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        };
        (entry, response_rx)
    }

    #[tokio::test]
    async fn test_append() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0, 16);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[tokio::test]
    async fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[tokio::test]
    async fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0, 16);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
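        // (each default entry needs decode 1 + speculate 2 = 3 tokens, so even
        // a single entry exceeds a budget of 1; a budget of 6 fits both)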
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}