use crate::block_allocator::{BlockAllocation, BlockAllocator};
use crate::client;
use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
    Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
    ValidStoppingParameters,
};
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};

/// Queue entry
#[derive(Debug)]
pub(crate) struct Entry {
    /// Request
    pub request: ValidGenerateRequest,
    /// Response sender to communicate between the Infer struct and the batching_task
    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
    /// Span that will live as long as the entry
    pub span: Span,
    /// Temporary span used as a guard when logging inference, wait times...
    pub temp_span: Option<Span>,
    /// Instant when this entry was queued
    pub queue_time: Instant,
    /// Instant when this entry was added to a batch
    pub batch_time: Option<Instant>,
    /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
}

/// Request Queue
#[derive(Debug, Clone)]
pub(crate) struct Queue {
    /// Channel to communicate with the background queue task
    queue_sender: mpsc::UnboundedSender<QueueCommand>,
}
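// `Queue` is a cheap handle: clones share the same unbounded sender, while all mutable
// state lives in the background `queue_task`, so no lock is needed around `State`.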

impl Queue {
    pub(crate) fn new(
        requires_padding: bool,
        block_size: u32,
        prefix_caching: bool,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
        support_chunking: bool,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();

        // Launch background queue task
        tokio::spawn(queue_task(
            requires_padding,
            block_size,
            prefix_caching,
            window_size,
            speculate,
            max_batch_total_tokens,
            support_chunking,
            queue_receiver,
        ));

        Self { queue_sender }
    }

    /// Append an entry to the queue
    #[instrument(skip_all)]
    pub(crate) fn append(&self, entry: Entry) {
        // Send append command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::Append(Box::new(entry), Span::current()))
            .unwrap();
    }

    // Get the next batch
    #[instrument(skip(self))]
    pub(crate) async fn next_batch(
        &self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if prefill_token_budget == 0 || token_budget == 0 {
            return None;
        };

        // Create response channel
        let (response_sender, response_receiver) = oneshot::channel();
        // Send next batch command to the background task managing the state
        // Unwrap is safe here
        self.queue_sender
            .send(QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span: Span::current(),
            })
            .unwrap();
        // Await on response channel
        // Unwrap is safe here
        response_receiver.await.unwrap()
    }
}
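// Illustrative call pattern (a sketch with hypothetical arguments, mirroring the tests
// below): the batching task clones the queue, pushes entries, then asks for a batch.
//
//     let queue = Queue::new(false, 1, false, None, 0, 16, false);
//     queue.append(entry);
//     if let Some((entries, batch, _span)) = queue.next_batch(None, Some(32), 1024, 1024).await {
//         // send `batch` to the model client; `entries` routes streamed tokens back
//     }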

// Background task responsible for the queue state
#[allow(clippy::too_many_arguments)]
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
    prefix_caching: bool,
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
    support_chunking: bool,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(
        requires_padding,
        block_size,
        prefix_caching,
        window_size,
        speculate,
        max_batch_total_tokens,
        support_chunking,
    );

    while let Some(cmd) = receiver.recv().await {
        match cmd {
            QueueCommand::Append(entry, span) => {
                span.in_scope(|| state.append(*entry));
                metrics::gauge!("tgi_queue_size").increment(1.0);
            }
            QueueCommand::NextBatch {
                min_size,
                max_size,
                prefill_token_budget,
                token_budget,
                response_sender,
                span,
            } => {
                let next_batch = state
                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                    .instrument(span)
                    .await;
                response_sender.send(next_batch).unwrap();
                metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
            }
        }
    }
}

/// Queue State
#[derive(Debug)]
struct State {
    /// Queue entries organized in a VecDeque
    entries: VecDeque<(u64, Entry)>,

    /// Id of the next entry
    next_id: u64,

    /// Id of the next batch
    next_batch_id: u64,

    /// Paged Attention block size
    block_size: u32,

    /// Speculation amount
    speculate: u32,

    /// Whether the model allows prefill chunking
    /// If it does, the last request in the batch will be split to exactly match the prefill
    /// token budget
    support_chunking: bool,

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
}
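// Chunking accounting sketch (hypothetical numbers): with a prefill_token_budget of 112
// and 90 prefill tokens already batched, a request whose uncached postfix is 50 tokens
// is admitted with chunk_len = 112 - 90 = 22, landing the batch exactly on budget; only
// that chunk is prefilled now, the rest being completed by subsequent prefill calls.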

impl State {
    fn new(
        requires_padding: bool,
        block_size: u32,
        prefix_caching: bool,
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
        support_chunking: bool,
    ) -> Self {
        let block_allocator = (!requires_padding).then(|| {
            BlockAllocator::new(
                max_batch_total_tokens,
                block_size,
                prefix_caching,
                window_size,
            )
        });

        Self {
            entries: VecDeque::with_capacity(128),
            next_id: 0,
            next_batch_id: 0,
            block_size,
            speculate,
            support_chunking,
            block_allocator,
        }
    }

    /// Append an entry to the queue
    fn append(&mut self, mut entry: Entry) {
        // Create a span that will live as long as the entry is in the queue waiting to be batched
        let queue_span = info_span!(parent: &entry.span, "queued");
        entry.temp_span = Some(queue_span);

        // Push entry in the queue
        self.entries.push_back((self.next_id, entry));
        self.next_id += 1;
    }

    // Get the next batch
    async fn next_batch(
        &mut self,
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
    ) -> Option<NextBatch> {
        if self.entries.is_empty() {
            tracing::debug!("No queue");
            return None;
        }

        // Check if we have enough entries
        if let Some(min_size) = min_size {
            if self.entries.len() < min_size {
                tracing::debug!("Not enough entries");
                return None;
            }
        }

        if let Some(max_size) = max_size {
            if max_size == 0 {
                tracing::debug!("No capacity");
                return None;
            }
        }

        // Pad prefill_token_budget to be a multiple of block size
        let prefill_token_budget =
            ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size;
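        // e.g. with block_size = 16, a budget of 100 rounds up to 112 (7 full blocks),
        // so a prefill never stops partway through a block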

        // Create span for this batch to add context to inference calls
        let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
        next_batch_span.follows_from(Span::current());

        let mut batch = Vec::with_capacity(self.entries.len());
        let mut max_input_length = 0;
        let mut prefill_tokens: u32 = 0;
        let mut decode_tokens: u32 = 0;
        let mut max_blocks = 0;

        // Pop entries starting from the front of the queue
        'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
            // Filter entries where the response receiver was dropped (== entries where the request
            // was dropped by the client)
            if entry.response_tx.is_closed() {
                metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
                tracing::debug!("Dropping entry");
                continue;
            }

            let block_allocation = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
                    max_input_length = max_input_length.max(entry.request.input_length);
                    prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
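                    // e.g. two requests of lengths 10 and 50 padded together count as
                    // 2 * 50 = 100 prefill tokens, not 60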

                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;

                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
                    None
                }
                Some(block_allocator) => {
                    // If the user wants prefill logprobs, we cannot reuse the cache.
                    // So no input_ids for the radix tree.
                    let input_ids = if entry.request.decoder_input_details {
                        None
                    } else {
                        entry.request.input_ids.clone()
                    };

                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;
                    tracing::debug!("Allocating {tokens} with {input_ids:?}");

                    let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
                        None => {
                            // Entry is over budget
                            // Add it back to the front
                            tracing::debug!("Over budget: not enough free blocks");
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                        Some(mut block_allocation) => {
                            tracing::debug!("Allocation: {block_allocation:?}");
                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);

                            if block_allocation.prefix_len == entry.request.input_length {
                                // The whole request was found in the radix trie
                                // However, for the transformer forward to work, we need to
                                // have at least one token of postfix.
                                block_allocation.prefix_len -= 1;
                            }

                            block_allocation
                        }
                    };

                    let postfix_len = entry.request.input_length - block_allocation.prefix_len;

                    if prefill_tokens + postfix_len > prefill_token_budget {
                        // Entry is over budget
                        if self.support_chunking {
                            // We support chunking, just set postfix_len to exactly match prefill_token_budget
                            let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
                            if chunk_len > 0 {
                                // Push this entry inside the batch
                                batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
                            } else {
                                // We cannot prefill even one token for this entry
                                // Add it back to the queue
                                self.entries.push_front((id, entry));
                            }
                            tracing::debug!(
                                "Matched budget: prefill_tokens={} == {prefill_token_budget}",
                                prefill_tokens + postfix_len
                            );
                            break 'entry_loop;
                        } else {
                            // We don't support chunking, this entry needs to go back to the buffer
                            // Add it back to the front
                            tracing::debug!(
                                "Over budget: prefill_tokens={} > {prefill_token_budget}",
                                prefill_tokens + postfix_len
                            );
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                    }

                    prefill_tokens += postfix_len;

                    Some(block_allocation)
                }
            };
            batch.push((id, entry, block_allocation, None));
            if Some(batch.len()) == max_size {
                break;
            }
        }

        // Empty batch
        if batch.is_empty() {
            tracing::debug!("Filterered out all entries");
            return None;
        }

        // XXX We haven't allocated yet, so we're allowed to ditch the results.
        // Check if our batch is big enough
        if let Some(min_size) = min_size {
            // Batch is too small
            if batch.len() < min_size {
                // Add back entries to the queue in the correct order
                for (id, entry, _, _) in batch.into_iter().rev() {
                    self.entries.push_front((id, entry));
                }
                return None;
            }
        }

        let mut batch_requests = Vec::with_capacity(self.entries.len());
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        for (id, mut entry, block_allocation, chunk_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
            next_batch_span.follows_from(&entry_batch_span);
            entry_batch_span.follows_from(&next_batch_span);
            // Update entry
            entry.temp_span = Some(entry_batch_span);

            let (blocks, slots, prefix_len) = match &block_allocation {
                None => (Vec::new(), Vec::new(), 0),
                Some(block_allocation) => (
                    block_allocation.blocks.clone(),
                    block_allocation.slots.clone(),
                    block_allocation.prefix_len,
                ),
            };

            entry.block_allocation = block_allocation;

            batch_requests.push(Request {
                id,
                prefill_logprobs: entry.request.decoder_input_details,
                input_chunks: Some(client::Input {
                    chunks: entry
                        .request
                        .inputs
                        .clone()
                        .into_iter()
                        .map(|c| client::InputChunk {
                            chunk: Some(match c {
                                Chunk::Text(text) => client::Chunk::Text(text),
                                Chunk::Image(image) => client::Chunk::Image(client::Image {
                                    data: image.data,
                                    mimetype: image.mimetype,
                                }),
                            }),
                        })
                        .collect(),
                }),
                inputs: entry.request.inputs.chunks_to_string(),
                truncate: entry.request.truncate,
                add_special_tokens: entry.request.add_special_tokens,
                parameters: Some(NextTokenChooserParameters::from(
                    entry.request.parameters.clone(),
                )),
                stopping_parameters: Some(StoppingCriteriaParameters::from(
                    entry.request.stopping_parameters.clone(),
                )),
                top_n_tokens: entry.request.top_n_tokens,
                blocks,
                slots,
                cache_len: prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
                chunk_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
            // Insert in batch_entries IntMap
            batch_entries.insert(id, entry);
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);

        let batch = Batch {
            id: self.next_batch_id,
            requests: batch_requests,
            size,
            max_tokens: (prefill_tokens + decode_tokens),
            max_blocks,
        };
        // Increment batch id
        self.next_batch_id += 1;

        metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
}

/// Entries mapped by request id, the protobuf batch to send to the model, and the batch span
type NextBatch = (IntMap<u64, Entry>, Batch, Span);

#[derive(Debug)]
enum QueueCommand {
    Append(Box<Entry>, Span),
    NextBatch {
        min_size: Option<usize>,
        max_size: Option<usize>,
        prefill_token_budget: u32,
        token_budget: u32,
        response_sender: oneshot::Sender<Option<NextBatch>>,
        span: Span,
    },
}

impl From<ValidParameters> for NextTokenChooserParameters {
    fn from(value: ValidParameters) -> Self {
        let (grammar, grammar_type) = match value.grammar {
            None => (String::new(), GrammarType::None),

            Some(grammar) => match grammar {
                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),
                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),
            },
        };

        Self {
            temperature: value.temperature,
            top_k: value.top_k,
            top_p: value.top_p,
            typical_p: value.typical_p,
            do_sample: value.do_sample,
            seed: value.seed,
            repetition_penalty: value.repetition_penalty,
            frequency_penalty: value.frequency_penalty,
            watermark: value.watermark,
            grammar,
            grammar_type: grammar_type.into(),
        }
    }
}

impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
    fn from(value: ValidStoppingParameters) -> Self {
        Self {
            max_new_tokens: value.max_new_tokens,
            stop_sequences: value.stop_sequences,
            ignore_eos_token: value.ignore_eos_token,
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use tracing::info_span;

    fn default_entry() -> (
        Entry,
        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,
    ) {
        let (response_tx, receiver_tx) = mpsc::unbounded_channel();

        let entry = Entry {
            request: ValidGenerateRequest {
                inputs: vec![],
                input_ids: Some(Arc::new(vec![])),
                input_length: 1,
                add_special_tokens: true,
                truncate: 0,
                decoder_input_details: false,
                parameters: ValidParameters {
                    temperature: 0.0,
                    top_k: 0,
                    top_p: 0.0,
                    typical_p: 0.0,
                    do_sample: false,
                    seed: 0,
                    repetition_penalty: 0.0,
                    frequency_penalty: 0.0,
                    watermark: false,
                    grammar: None,
                },
                stopping_parameters: ValidStoppingParameters {
                    ignore_eos_token: false,
                    max_new_tokens: 1,
                    max_total_new_tokens: 1024,
                    stop_sequences: vec![],
                },
                top_n_tokens: 0,
                adapter_id: None,
            },
            response_tx,
            span: info_span!("entry"),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
            block_allocation: None,
        };
        (entry, receiver_tx)
    }

    #[tokio::test]
    async fn test_append() {
        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();

        assert_eq!(state.next_id, 0);
        assert_eq!(state.entries.len(), 0);

        state.append(entry);

        assert_eq!(state.next_id, 1);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn test_next_batch_empty() {
        let mut state = State::new(false, 1, false, None, 0, 16, false);

        assert!(state.next_batch(None, None, 1, 1).await.is_none());
        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_next_batch_min_size() {
        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 1);
        let (id, _) = state.entries.remove(0).unwrap();
        assert_eq!(id, 2);
    }

    #[tokio::test]
    async fn test_next_batch_max_size() {
        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);
    }

    #[tokio::test]
    async fn test_next_batch_token_budget() {
        let mut state = State::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
        state.append(entry2);

        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        assert_eq!(state.next_id, 2);
        assert_eq!(state.entries.len(), 1);
        assert_eq!(state.next_batch_id, 1);

        let (entry3, _guard3) = default_entry();
        state.append(entry3);

        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);

        assert_eq!(state.next_id, 3);
        assert_eq!(state.entries.len(), 0);
        assert_eq!(state.next_batch_id, 2);
    }

    #[tokio::test]
    async fn test_queue_append() {
        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _guard) = default_entry();
        queue.append(entry);
    }

    #[tokio::test]
    async fn test_queue_next_batch_empty() {
        let queue = Queue::new(false, 1, false, None, 0, 16, false);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
    }

    #[tokio::test]
    async fn test_queue_next_batch_min_size() {
        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert!(entries.get(&1).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        // Not enough requests pending
        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());
        // Not enough token budget
        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());
        // Ok
        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();
        assert_eq!(entries2.len(), 1);
        assert!(entries2.contains_key(&2));
        assert!(entries2.get(&2).unwrap().batch_time.is_some());
        assert_eq!(batch2.id, 1);
        assert_eq!(batch2.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_max_size() {
        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert!(entries.get(&0).unwrap().batch_time.is_some());
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_budget() {
        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries.contains_key(&0));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 1);

        let (entry3, _guard3) = default_entry();
        queue.append(entry3);

        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&1));
        assert!(entries.contains_key(&2));
        assert_eq!(batch.id, 1);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_token_speculate() {
        let queue = Queue::new(true, 1, false, None, 2, 16, false);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
        queue.append(entry2);

        // Budget of 1 is not enough
        assert!(queue.next_batch(None, None, 1, 1).await.is_none());

        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains_key(&0));
        assert!(entries.contains_key(&1));
        assert_eq!(batch.id, 0);
        assert_eq!(batch.size, 2);
    }

    #[tokio::test]
    async fn test_queue_next_batch_dropped_receiver() {
        let queue = Queue::new(false, 1, false, None, 0, 16, false);
        let (entry, _) = default_entry();
        queue.append(entry);

        assert!(queue.next_batch(None, None, 1, 1).await.is_none());
    }
}