"benchmarks/utils.py" did not exist on "0a401b95b7f298b3d029576e1d65d99f06ed1228"
queue.rs 25.5 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
use crate::block_allocator::{BlockAllocation, BlockAllocator};
use crate::client;
use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
OlivierDehaene's avatar
OlivierDehaene committed
5
};
6
use nohash_hasher::{BuildNoHashHasher, IntMap};
7
use std::cmp::{max, min};
8
use std::collections::VecDeque;
Nicolas Patry's avatar
Nicolas Patry committed
9
10
11
12
13
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
    Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
    ValidStoppingParameters,
OlivierDehaene's avatar
OlivierDehaene committed
14
};
OlivierDehaene's avatar
OlivierDehaene committed
15
use tokio::sync::{mpsc, oneshot};
16
use tokio::time::Instant;
17
use tracing::{info_span, instrument, Instrument, Span};
18
19
20
21
22
23
24

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
    /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}

impl Queue {
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            window_size,
            speculate,
            max_batch_total_tokens,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    /// Append an entry to the queue
    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    /// Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}
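
// A minimal sketch of how a caller might drive the queue (illustrative only;
// `build_entry` is a hypothetical stand-in for the router's validation path):
//
//     let queue = Queue::new(false, 1, None, 0, 16);
//     queue.append(build_entry(request));
//     if let Some((entries, batch, _span)) = queue
//         .next_batch(None, Some(32), 4096, 8192)
//         .await
//     {
//         // ship `batch` to the model server; `entries` is keyed by request id
//     }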

// Background task responsible for the queue state
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
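    // Owns the queue state; the loop below runs until every `Queue` handle
    // (and therefore every sender) has been dropped.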
    let mut state = State::new(
        requires_padding,
        block_size,
        window_size,
        speculate,
        max_batch_total_tokens,
    );

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::gauge!("tgi_queue_size").increment(1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => {
                let next_batch = state
                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                    .instrument(span)
                    .await;
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
            }
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
}

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
    ) -> Self {
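        // Models that require padding manage memory on their own, so a block
        // allocator is only created when padding is not required.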
        let block_allocator = (!requires_padding)
            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));

        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            block_size,
            window_size,
            speculate,
            block_allocator,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    /// Get the next batch
    async fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("Queue is empty");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
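        // e.g. with block_size = 16, a budget of 100 rounds up to
        // ((100 + 16 - 1) / 16) * 16 = 112 (integer division).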

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(&Span::current());

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;
        let mut max_blocks = 0;

        // Pop entries starting from the front of the queue
        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
                tracing::debug!("Dropping entry");
                continue;
            }

            let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
                    max_input_length = max_input_length.max(entry.request.input_length);
                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;

                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;

                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
                    None
                }
                Some(block_allocator) => {
                    prefill_tokens += entry.request.input_length;
                    let max_new_tokens = match self.window_size {
                        None => entry.request.stopping_parameters.max_new_tokens,
                        Some(window_size) => min(
                            window_size.saturating_sub(entry.request.input_length),
                            entry.request.stopping_parameters.max_new_tokens,
                        ),
                    };
                    decode_tokens += max_new_tokens;

                    if prefill_tokens > prefill_token_budget
                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
                    {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break;
                    }

                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;
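                    // e.g. input_length = 10, max_new_tokens = 5 and
                    // speculate = 2 request 10 + 5 + 2 - 1 = 16 slots.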

                    match block_allocator.allocate(tokens).await {
                        None => {
                            // Entry is over budget
                            // Add it back to the front
                            tracing::debug!("Over budget: not enough free blocks");
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                        Some(block_allocation) => {
                            tracing::debug!("Allocation: {block_allocation:?}");
                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                            Some(block_allocation)
                        }
                    }
                }
            };

            tracing::debug!("Accepting entry");
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            let (blocks, slots) = match &block_allocation {
                None => (Vec::new(), Vec::new()),
                Some(block_allocation) => (
                    block_allocation.blocks.clone(),
                    block_allocation.slots.clone(),
                ),
            };

            entry.block_allocation = block_allocation;

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                input_chunks: Some(client::Input {
                    chunks: entry
                        .request
                        .inputs
                        .clone()
                        .into_iter()
                        .map(|c| client::InputChunk {
                            chunk: Some(match c {
                                Chunk::Text(text) => client::Chunk::Text(text),
                                Chunk::Image(image) => client::Chunk::Image(client::Image {
                                    data: image.data,
                                    mimetype: image.mimetype,
                                }),
                            }),
                        })
                        .collect(),
                }),
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
                adapter_id: entry.request.adapter_id.clone(),
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);

            // Check if max_size has been reached
            if Some(batch_requests.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch_requests.len() < min_size {
                // Add back entries to the queue in the correct order
                for r in batch_requests.into_iter().rev() {
                    let id = r.id;
                    let entry = batch_entries.remove(&id).unwrap();
                    self.entries.push_front((id, entry));
                }

                return None;
            }
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
            max_blocks,
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

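/// A batch ready to be sent to the model: the entries keyed by request id,
/// the protobuf `Batch`, and the span tracking the batch.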
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

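// Conversions from the router's validated request types into the protobuf
// types shipped to the model server.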
impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info_span;

    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_length: 0,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        };
        (entry, receiver_tx)
    }
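
    // Each test binds the returned receiver (`_guard`) to keep it alive, so
    // `response_tx.is_closed()` stays false and the entry is not filtered out.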

    #[tokio::test]
    async fn test_append() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn test_next_batch_empty() {
        let mut state = State::new(false, 1, None, 0, 16);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[tokio::test]
    async fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[tokio::test]
    async fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, None, 0, 2);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, None, 0, 16);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(false, 1, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, None, 0, 16);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}