use crate::block_allocator::{BlockAllocation, BlockAllocator};
use crate::client;
use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
    Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
    ValidStoppingParameters,
};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
    /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
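    /// Create a queue and spawn its background task.
    ///
    /// A minimal usage sketch (inside a Tokio runtime; the argument values
    /// mirror the tests below and are illustrative only):
    ///
    /// ```ignore
    /// let queue = Queue::new(false, 1, None, 0, 16);
    /// queue.append(entry);
    /// let next_batch = queue.next_batch(None, None, 8, 16).await;
    /// ```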
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            window_size,
            speculate,
            max_batch_total_tokens,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    /// Append an entry to the queue
    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    // Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(
        requires_padding,
        block_size,
        window_size,
        speculate,
        max_batch_total_tokens,
    );

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::gauge!("tgi_queue_size").increment(1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => {
                let next_batch = state
                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                    .instrument(span)
                    .await;
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
            }
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
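        // A block allocator is only created when the model does not require
        // padding; for padded models, next_batch accounts for padding via
        // max_input_length instead.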
        let block_allocator = (!requires_padding)
            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));

        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            block_size,
            window_size,
            speculate,
            block_allocator,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    // Get the next batch
    async fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("No queue");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        if let Some(max_size) = max_size {
            if max_size == 0 {
                tracing::debug!("No capacity");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
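        // e.g. with block_size = 16, a budget of 100 is rounded up to
        // ((100 + 15) / 16) * 16 = 112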
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(&Span::current());

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;
        let mut max_blocks = 0;

        // Pop entries starting from the front of the queue
        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
                tracing::debug!("Dropping entry");
                continue;
            }

            let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
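                    // e.g. with two requests already in the batch and a running
                    // max_input_length of 30, adding this entry makes the padded
                    // prefill cost (2 + 1) * 30 = 90 tokens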
                    max_input_length = max_input_length.max(entry.request.input_length);
                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;

                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;

                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
                    None
                }
                Some(block_allocator) => {
                    prefill_tokens += entry.request.input_length;
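                    // With a sliding window, the KV cache holds at most
                    // window_size tokens, so only the new tokens that still fit
                    // in the window beyond the prompt are counted here.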
                    let max_new_tokens = match self.window_size {
                        None => entry.request.stopping_parameters.max_new_tokens,
                        Some(window_size) => min(
                            window_size.saturating_sub(entry.request.input_length),
                            entry.request.stopping_parameters.max_new_tokens,
                        ),
                    };
                    decode_tokens += max_new_tokens;

                    if prefill_tokens > prefill_token_budget
                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
                    {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break;
                    }

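                    // Reserve KV slots for the prompt, every possible new token
                    // and the speculative tokens; the final generated token is
                    // never fed back through the model, hence (presumably) the `- 1`.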
                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;

                    match block_allocator.allocate(tokens).await {
                        None => {
                            // Entry is over budget
                            // Add it back to the front
                            tracing::debug!("Over budget: not enough free blocks");
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                        Some(block_allocation) => {
                            tracing::debug!("Allocation: {block_allocation:?}");
                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                            Some(block_allocation)
                        }
                    }
                }
            };

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            let (blocks, slots) = match &block_allocation {
                None => (Vec::new(), Vec::new()),
                Some(block_allocation) => (
                    block_allocation.blocks.clone(),
                    block_allocation.slots.clone(),
                ),
            };

            entry.block_allocation = block_allocation;

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                input_chunks: Some(client::Input {
                    chunks: entry
                        .request
                        .inputs
                        .clone()
                        .into_iter()
                        .map(|c| client::InputChunk {
                            chunk: Some(match c {
                                Chunk::Text(text) => client::Chunk::Text(text),
                                Chunk::Image(image) => client::Chunk::Image(client::Image {
                                    data: image.data,
                                    mimetype: image.mimetype,
                                }),
                            }),
                        })
                        .collect(),
                }),
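                // The flattened string form is kept alongside the structured
                // chunks, presumably for shards that only understand plain text.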
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
                adapter_id: entry.request.adapter_id.clone(),
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Check if we reached max_size
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
            max_blocks,
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

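/// A batch ready to be sent to the model: the accepted entries keyed by
/// request id, the protobuf `Batch`, and the tracing span for its lifetime.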
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

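// Conversions from the router's validated request types to the protobuf
// messages sent to the model shards.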
impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

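    /// Build a minimal valid entry for the tests below. The returned receiver
    /// must be kept alive: dropping it makes the queue treat the request as
    /// cancelled and filter it out.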
    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_length: 0,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        };
        (entry, receiver_tx)
    }

    #[tokio::test]
    async fn test_append() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0, 16);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[tokio::test]
    async fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[tokio::test]
    async fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0, 16);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
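        // (each entry adds max_new_tokens = 1 to decode_tokens, and the batch
        // reserves speculate = 2 extra tokens, so even one entry needs 3)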
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}