use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::{
    ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_client::v3::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as the entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
    /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}
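
// A sketch of typical usage (illustrative values, not part of the original
// file): `Queue` is cheap to clone, and every clone feeds the same background
// task through `queue_sender`.
//
//     let queue = Queue::new(false, 1, None, 0, 16);
//     queue.append(entry);
//     if let Some((entries, batch, span)) = queue.next_batch(None, None, 32, 32).await {
//         // dispatch `batch` to the shards; keep `entries` keyed by request id
//     }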

impl Queue {
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            window_size,
            speculate,
            max_batch_total_tokens,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    /// Append an entry to the queue
    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    // Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(
        requires_padding,
        block_size,
        window_size,
        speculate,
        max_batch_total_tokens,
    );

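    // Single-owner event loop: all queue mutations happen here, so `State`
    // needs no locking. The loop exits once every `Queue` clone (and with
    // them every sender) has been dropped.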
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::increment_gauge!("tgi_queue_size", 1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => {
                let next_batch = state
                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                    .instrument(span)
                    .await;
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
            }
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
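        // A block allocator is only created when the model does not require
        // padding, i.e. for paged-attention backends.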
        let block_allocator = (!requires_padding)
            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));

        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            block_size,
            window_size,
            speculate,
            block_allocator,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    // Get the next batch
    async fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("No queue");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
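        // e.g. with block_size = 16, a budget of 100 tokens is rounded up to
        // ((100 + 15) / 16) * 16 = 112 tokens (7 full blocks)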

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(&Span::current());

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;
        let mut max_blocks = 0;

        // Pop entries starting from the front of the queue
        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
                tracing::debug!("Dropping entry");
                continue;
            }

            let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
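                    // e.g. three 20-token prompts followed by a 50-token prompt
                    // are counted as 4 * 50 = 200 prefill tokens after padding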
                    max_input_length = max_input_length.max(entry.request.input_length);
                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;

                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;

                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
                    None
                }
                Some(block_allocator) => {
                    prefill_tokens += entry.request.input_length;
                    let max_new_tokens = match self.window_size {
                        None => entry.request.stopping_parameters.max_new_tokens,
                        Some(window_size) => min(
                            window_size.saturating_sub(entry.request.input_length),
                            entry.request.stopping_parameters.max_new_tokens,
                        ),
                    };
                    decode_tokens += max_new_tokens;

                    if prefill_tokens > prefill_token_budget
                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
                    {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }

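                    // Slots needed for this entry; the trailing `- 1` presumably
                    // exists because the last generated token is returned without
                    // being written back to the cache (assumption, not stated here)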
                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;

                    match block_allocator.allocate(tokens).await {
                        None => {
                            // Entry is over budget
                            // Add it back to the front
                            tracing::debug!("Over budget: not enough free blocks");
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                        Some(block_allocation) => {
                            tracing::debug!("Allocation: {block_allocation:?}");
                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                            Some(block_allocation)
                        }
                    }
                }
            };

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            let (blocks, slots) = match &block_allocation {
                None => (Vec::new(), Vec::new()),
                Some(block_allocation) => (
                    block_allocation.blocks.clone(),
                    block_allocation.slots.clone(),
                ),
            };

            entry.block_allocation = block_allocation;

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                input_chunks: Some(Input {
                    chunks: entry.request.inputs.clone(),
                }),
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Stop once max_size is reached
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
            max_blocks,
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size", batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

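// A prepared batch: the entries keyed by request id, the batch to send to the
// model server, and the span tracing the batch.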
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

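    // Builds a default entry plus its response receiver; tests keep the
    // receiver alive as `_guard`, since `next_batch` filters entries whose
    // receiver has been dropped.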
    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_length: 0,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        };
        (entry, receiver_tx)
    }

    #[tokio::test]
    async fn test_append() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0, 16);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[tokio::test]
    async fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[tokio::test]
    async fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0, 16);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}