/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::{
    ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
    HubTokenizerConfig, Message, PrefillToken, Queue, Token,
};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap;
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};
use text_generation_client::{
    Batch, CachedBatch, ClientError, GeneratedText, Generation, ShardedClient, Tokens,
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info_span, instrument, Instrument, Span};

/// Inference struct
#[derive(Clone)]
pub struct Infer {
    /// Validation
    validation: Validation,
    /// Request queue
    queue: Queue,
    /// Shared state
    shared: Arc<Shared>,
    /// Chat template
    chat_template: Option<ChatTemplate>,
    /// Inference limit
    limit_concurrent_requests: Arc<Semaphore>,
}

/// Infer shared state
struct Shared {
    /// Batching background Tokio task notifier
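    /// (`notify_one` stores a permit when no task is awaiting, so a wake-up
    /// sent while the batching loop is busy is not lost)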
    batching_task: Notify,
}

/// Raise an exception (custom function) used in the chat templates
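/// A template can call it as `{{ raise_exception('msg') }}`; the resulting
/// minijinja error surfaces to callers as `InferError::TemplateError`.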
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}

impl Infer {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        client: ShardedClient,
        validation: Validation,
        waiting_served_ratio: f32,
        max_batch_prefill_tokens: u32,
        max_batch_total_tokens: u32,
        max_waiting_tokens: usize,
        max_batch_size: Option<usize>,
        max_concurrent_requests: usize,
        requires_padding: bool,
        window_size: Option<u32>,
        speculate: u32,
        generation_health: Arc<AtomicBool>,
        tokenizer_config: HubTokenizerConfig,
    ) -> Self {
        // Infer shared state
        let queue = Queue::new(requires_padding, 16, window_size, speculate);
        let shared = Arc::new(Shared {
            batching_task: Notify::new(),
        });

        // Spawn batching background task that contains all the inference logic
        tokio::spawn(batching_task(
            client,
            waiting_served_ratio,
            max_batch_prefill_tokens,
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            queue.clone(),
            shared.clone(),
            generation_health,
        ));

        let chat_template = tokenizer_config
            .chat_template
            .and_then(|t| match t {
                ChatTemplateVersions::Single(template) => Some(template),
                ChatTemplateVersions::Multiple(templates) => templates
                    .into_iter()
                    .find(|t| t.name == "default")
                    .map(|t| t.template),
            })
            .map(|t| {
                // .strip() is not supported in minijinja
                let t = t.replace(".strip()", " | trim");
                ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
            });

        // Inference limit with a semaphore
        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

        Self {
            validation,
            queue,
            shared,
            chat_template,
            limit_concurrent_requests: semaphore,
        }
    }

    /// Add a new request to the queue and return a stream of InferStreamResponse
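    ///
    /// The returned semaphore permit keeps the concurrency slot reserved for as
    /// long as the caller holds on to the stream.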
    #[instrument(skip_all)]
    pub(crate) async fn generate_stream(
        &self,
        request: GenerateRequest,
    ) -> Result<GenerateStreamResponse, InferError> {
        // Limit concurrent requests by acquiring a permit from the semaphore
        let permit = self
            .clone()
            .limit_concurrent_requests
            .try_acquire_owned()
            .map_err(|err| {
                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
                tracing::error!("{err}");
                err
            })?;

        // Validate request
        let valid_request = self.validation.validate(request).await.map_err(|err| {
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            err
        })?;

        // MPSC channel to communicate with the background batching task
        let (response_tx, response_rx) = mpsc::unbounded_channel();
        let input_length = valid_request.input_length;

        // Append the request to the queue
        self.queue.append(Entry {
            request: valid_request,
            response_tx,
            span: Span::current(),
            temp_span: None,
            queue_time: Instant::now(),
            batch_time: None,
        });

        // Notify the background task that we have a new entry in the queue that needs
        // to be batched
        self.shared.batching_task.notify_one();

        // Return stream
        Ok((
            permit,
            input_length,
            UnboundedReceiverStream::new(response_rx),
        ))
    }

    /// Tokenize the input
    #[instrument(skip_all)]
    pub(crate) async fn tokenize(
        &self,
        request: GenerateRequest,
    ) -> Result<Option<tokenizers::Encoding>, InferError> {
        // Tokenize request
        let inputs = request.inputs;
        let truncate = request.parameters.truncate;
        let encoding = self
            .validation
            .tokenize(inputs, truncate)
            .await
            .map_err(|err| {
                tracing::error!("Tokenization {err}");
                err
            })?;

        // Return Encoding
        Ok(encoding.map(|(encoding, _)| encoding))
    }

    /// Apply the chat template to the chat request
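    ///
    /// Returns `InferError::TemplateError` when no template was configured for
    /// the model or when rendering fails.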
    #[instrument(skip_all)]
    pub(crate) fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String, InferError> {
        self.chat_template
            .as_ref()
            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
            .apply(messages)
            .map_err(|e| {
                metrics::increment_counter!("tgi_request_failure", "err" => "template");
                tracing::error!("{e}");
                e
            })
    }

    /// Add a new request to the queue and return an InferResponse
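    ///
    /// # Example (sketch; assumes an `Infer` handle and a populated `GenerateRequest`)
    /// ```ignore
    /// let response = infer.generate(request).await?;
    /// println!("generated {} tokens", response.tokens.len());
    /// ```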
    #[instrument(skip_all)]
    pub(crate) async fn generate(
        &self,
        request: GenerateRequest,
    ) -> Result<InferResponse, InferError> {
        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);

        // Create stream and keep semaphore permit as long as generate lives
        let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;

        // Return values
        let mut result_prefill = Vec::new();
        let mut result_tokens = Vec::new();
        let mut result_top_tokens = Vec::new();
        let mut result_generated_text = None;
        let mut result_start = None;
        let mut result_queued = None;

        // Iterate on stream
        while let Some(response) = stream.next().await {
            match response? {
                // Add prefill tokens
                InferStreamResponse::Prefill(tokens) => {
                    // Create Token objects
                    // We do that here instead of in the Python code as Rust for loops are faster
                    result_prefill = tokens
                        .ids
                        .into_iter()
                        .zip(tokens.logprobs.into_iter())
                        .zip(tokens.texts.into_iter())
                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
                        .collect();
                }
                // Push last token
                InferStreamResponse::Intermediate { token, top_tokens } => {
                    result_tokens.push(token);
                    result_top_tokens.push(top_tokens);
                }
                // Final message
                // Set return values
                InferStreamResponse::End {
                    token,
                    generated_text,
                    start,
                    queued,
                    top_tokens,
                } => {
                    result_tokens.push(token);
                    result_top_tokens.push(top_tokens);
                    result_generated_text = Some(generated_text);
                    result_start = Some(start);
                    result_queued = Some(queued)
                }
            }
        }

        // Check that we received a `InferStreamResponse::End` message
        if let (Some(generated_text), Some(queued), Some(start)) =
            (result_generated_text, result_queued, result_start)
        {
            Ok(InferResponse {
                prefill: result_prefill,
                _input_length,
                tokens: result_tokens,
                generated_text,
                queued,
                start,
                top_tokens: if use_top_tokens {
                    result_top_tokens
                } else {
                    Vec::new()
                },
            })
        } else {
            let err = InferError::IncompleteGeneration;
            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
            tracing::error!("{err}");
            Err(err)
        }
    }
    /// Add best_of new requests to the queue and return an InferResponse of the sequence with
    /// the highest log probability per token
    #[instrument(skip(self, request))]
    pub(crate) async fn generate_best_of(
        &self,
        request: GenerateRequest,
        best_of: usize,
    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
        // validate the best_of parameter separately
        let best_of = self.validation.validate_best_of(best_of)?;

        // create multiple generate requests
        let mut infer_responses: Vec<InferResponse> =
            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;

        // get the sequence with the highest log probability per token
        let mut max_index = 0;
        let mut max_logprob: f32 = f32::MIN;

        for (i, response) in infer_responses.iter().enumerate() {
            // mean logprobs of the generated tokens
            let sequence_logprob = response
                .tokens
                .iter()
                .map(|token| token.logprob)
                .sum::<f32>()
                / response.tokens.len() as f32;
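            // e.g. token logprobs [-0.5, -1.0, -1.5] give a mean of -1.0; the
            // candidate whose mean is closest to zero wins below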

            // set best sequence
            if sequence_logprob > max_logprob {
                max_index = i;
                max_logprob = sequence_logprob;
            }
        }
        let best_response = infer_responses.remove(max_index);
        Ok((best_response, infer_responses))
    }
}

#[derive(Clone)]
struct ChatTemplate {
    template: Template<'static, 'static>,
    bos_token: Option<String>,
    eos_token: Option<String>,
}

impl ChatTemplate {
    fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
        let mut env = Box::new(Environment::new());
        let template_str = template.into_boxed_str();
        env.add_function("raise_exception", raise_exception);
        // leaking env and template_str as read-only, static resources for performance.
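        // This is fine in practice: a ChatTemplate is built once per server
        // lifetime in `Infer::new`, so the leaked allocation is bounded.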
        let template = Box::leak(env)
            .template_from_str(Box::leak(template_str))
            .unwrap();

        Self {
            template,
            bos_token,
            eos_token,
        }
    }

    fn apply(&self, messages: Vec<Message>) -> Result<String, InferError> {
        self.template
            .render(ChatTemplateInputs {
                messages,
                bos_token: self.bos_token.as_deref(),
                eos_token: self.eos_token.as_deref(),
                add_generation_prompt: true,
            })
            .map_err(InferError::TemplateError)
    }
}

/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
async fn batching_task(
    mut client: ShardedClient,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    queue: Queue,
    shared: Arc<Shared>,
    generation_health: Arc<AtomicBool>,
) {
    // Infinite loop
    loop {
        // Wait for a notification from the Infer struct
        shared.batching_task.notified().await;

        // Get the next batch from the queue
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the queue
        while let Some((mut entries, batch, span)) = queue
            .next_batch(
                None,
                max_batch_size,
                max_batch_prefill_tokens,
                max_batch_total_tokens,
            )
            .await
        {
            let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;

            // We loop until we do not receive any cached batch from the inference server (== until
            // all requests have met their stopping criteria)
            while let Some(batch) = cached_batch {
                // Get current batch info
                let batch_size = batch.size;
                let batch_max_tokens = batch.max_tokens;
                let mut batches = vec![batch];
                metrics::gauge!("tgi_batch_current_size", batch_size as f64);
                metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);

                let min_size = if waiting_tokens >= max_waiting_tokens {
                    // If we didn't onboard any new requests for >= max_waiting_tokens decode
                    // steps, try to add a new batch even though its size might be small
                    None
                } else {
                    // Minimum batch size
                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
                };
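                // e.g. with batch_size = 32 and waiting_served_ratio = 1.2, we only
                // prefetch a new batch if at least 38 requests are waiting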

                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
                let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);

                // Try to get a new batch
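                // The prefetched batch may only consume the token budget left over by the
                // running batch, and may not push the combined size past `max_batch_size`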
                if let Some((mut new_entries, new_batch, span)) = queue
                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                    .await
                {
                    // Tracking metrics
                    if min_size.is_some() {
                        metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
                    } else {
                        metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
                    }

                    entries.iter_mut().for_each(|(_, entry)| {
                        // Create a new span to add the info that this entry is waiting
                        // because a new batch is being computed
                        let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
                        // Add relationships
                        span.follows_from(&entry_waiting_span);
                        entry_waiting_span.follows_from(&span);
                        // Update entry
                        entry.temp_span = Some(entry_waiting_span);
                    });

                    // Generate one token for this new batch so that its past attention
                    // keys/values are in cache
                    let new_cached_batch =
                        prefill(&mut client, new_batch, &mut new_entries, &generation_health)
                            .instrument(span)
                            .await;
                    // Reset waiting counter
                    waiting_tokens = 1;
                    // Extend current batch with the new batch
                    if let Some(new_cached_batch) = new_cached_batch {
                        entries.extend(new_entries);
                        batches.push(new_cached_batch);
                    }
                }

                // Create span for this batch to add context to inference calls
                let next_batch_size = entries.len();
                let next_batch_span =
                    info_span!(parent: None, "batch", batch_size = next_batch_size);
                entries.iter_mut().for_each(|(_, entry)| {
                    // Create a new span to link the batch back to this entry
                    let entry_batch_span = info_span!(parent: &entry.span, "infer");
                    // Add relationships
                    next_batch_span.follows_from(&entry_batch_span);
                    entry_batch_span.follows_from(&next_batch_span);
                    // Update entry
                    entry.temp_span = Some(entry_batch_span);
                });

                cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
                    .instrument(next_batch_span)
                    .await;
                waiting_tokens += 1;
            }
            metrics::gauge!("tgi_batch_current_size", 0.0);
            metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
        }
    }
}

#[instrument(skip_all)]
async fn prefill(
    client: &mut ShardedClient,
    batch: Batch,
    entries: &mut IntMap<u64, Entry>,
    generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
    let start_time = Instant::now();
    let batch_id = batch.id;
    metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

    match client.prefill(batch).await {
        Ok((generations, next_batch, timings)) => {
            // Update health
            generation_health.store(true, Ordering::SeqCst);
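            // (the flag is shared with the server health check: a successful
            // prefill marks the generation pipeline as healthy again)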

            let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);

            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;

            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            // Update health
            generation_health.store(false, Ordering::SeqCst);
            let _ = client.clear_cache(Some(batch_id)).await;
            send_errors(err, entries);
            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
            None
        }
    }
}

#[instrument(skip_all)]
async fn decode(
    client: &mut ShardedClient,
    batches: Vec<CachedBatch>,
    entries: &mut IntMap<u64, Entry>,
    generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
    let start_time = Instant::now();
    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
    metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

    match client.decode(batches).await {
        Ok((generations, next_batch, timings)) => {
            // Update health
            generation_health.store(true, Ordering::SeqCst);

            let start_filtering_time = Instant::now();
            // Send generated tokens and filter stopped entries
            filter_send_generations(generations, entries);

            // Filter next batch and remove requests that were stopped
            let next_batch = filter_batch(client, next_batch, entries).await;

            if let Some(concat_duration) = timings.concat {
                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
            }
            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
            next_batch
        }
        // If we have an error, we discard the whole batch
        Err(err) => {
            generation_health.store(false, Ordering::SeqCst);
            for id in batch_ids {
                let _ = client.clear_cache(Some(id)).await;
            }
            send_errors(err, entries);
            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
            None
        }
    }
}

/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
    client: &mut ShardedClient,
    next_batch: Option<CachedBatch>,
    entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
    let mut batch = next_batch?;

    // No need to filter
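    // (the batch size matching the number of live entries means no request was
    // stopped, so the shard cache is already in sync)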
    if batch.size as usize == entries.len() {
        return Some(batch);
    }

    let id = batch.id;

    // Retain only requests that are still in entries
    batch.request_ids.retain(|id| entries.contains_key(id));

    if batch.request_ids.is_empty() {
        // All requests have been filtered out
        // Next batch is now empty
        // Clear it from the Python shards cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.clear_cache(Some(id)).await.unwrap();
        None
    } else {
        // Filter Python shard cache
        // We unwrap here as we need to panic since we cannot recover if this method fails
        client.filter_batch(id, batch.request_ids).await.unwrap()
    }
}

/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
    generations.into_iter().for_each(|generation| {
        let id = generation.request_id;
        // Get entry
        // We can `expect` here as the request id should always be in the entries
        let entry = entries
            .get(&id)
            .expect("ID not found in entries. This is a bug.");

        // Create and enter a span to link this function back to the entry
        let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
        // Send generation responses back to the infer task
        // If we receive an error from the mpsc channel, it means that the client dropped the
        // request and we need to stop generating, hence the unwrap_or(true)
        let stopped = send_responses(generation, entry).map_err(|err| {
            tracing::error!("Entry response channel error.");
            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
            err
        }).unwrap_or(true);
        if stopped {
            entries.remove(&id).expect("ID not found in entries. This is a bug.");
        }
    });
}

/// Send responses through the `entry` response channel
fn send_responses(
    generation: Generation,
    entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
    // Return directly if the channel is disconnected
    if entry.response_tx.is_closed() {
        metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
        return Ok(true);
    }

    let mut stopped = false;
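    // `stopped` is returned to `filter_send_generations`, which removes the
    // entry once its generation has ended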

    if let Some(prefill_tokens) = generation.prefill_tokens {
        // Send message
        entry
            .response_tx
            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
    }

    // Create last Token
    let tokens_ = generation.tokens.expect("Non empty tokens in generation");
    let n = tokens_.ids.len();
    metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
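    // With speculative decoding, one step can yield several tokens, so `n - 1`
    // counts the extra tokens beyond the usual single token per step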
    let mut iterator = tokens_
        .ids
        .into_iter()
        .zip(tokens_.logprobs)
        .zip(tokens_.texts)
        .zip(tokens_.is_special)
        .enumerate()
        .peekable();
    while let Some((i, (((id, logprob), text), special))) = iterator.next() {
        let token = Token {
            id,
            text,
            logprob,
            special,
        };
        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
            top_tokens_
                .ids
                .iter()
                .zip(top_tokens_.logprobs.iter())
                .zip(top_tokens_.texts.iter())
                .zip(top_tokens_.is_special.iter())
                .map(|(((&id, &logprob), text), &special)| Token {
                    id,
                    text: text.to_string(),
                    logprob,
                    special,
                })
                .collect()
        } else {
            vec![]
        };
        match (&generation.generated_text, iterator.peek()) {
            (Some(generated_text), None) => {
                // Generation has ended
                stopped = true;
                // Send message
                entry.response_tx.send(Ok(InferStreamResponse::End {
                    token,
                    top_tokens,
                    generated_text: generated_text.clone(),
                    queued: entry.queue_time,
                    start: entry.batch_time.unwrap(),
                }))?;
            }
            _ => {
                // Send message
                entry
                    .response_tx
                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
            }
        }
    }

    Ok(stopped)
}

/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
    entries.drain().for_each(|(_, entry)| {
        // Create and enter a span to link this function back to the entry
        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
        let err = InferError::GenerationError(error.to_string());
        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
        tracing::error!("{err}");

        // unwrap_or is valid here as we don't care if the receiver is gone.
        entry
            .response_tx
            .send(Err(err))
            .unwrap_or(());
    });
}

#[derive(Debug)]
pub(crate) enum InferStreamResponse {
    // Optional first message
    Prefill(Tokens),
    // Intermediate messages
    Intermediate {
        token: Token,
        top_tokens: Vec<Token>,
    },
    // Last message
    End {
        token: Token,
        top_tokens: Vec<Token>,
        generated_text: GeneratedText,
        start: Instant,
        queued: Instant,
    },
}

#[derive(Debug)]
pub(crate) struct InferResponse {
    /// input_length is the input as perceived by the rust tokenizer in the
    /// validation pathway. It is redundant with prefill.len() but prefill
    /// has data only if the user asked for it. This will always be filled.
    pub(crate) _input_length: u32,
    pub(crate) prefill: Vec<PrefillToken>,
    pub(crate) tokens: Vec<Token>,
    pub(crate) generated_text: GeneratedText,
    pub(crate) queued: Instant,
    pub(crate) start: Instant,
    pub(crate) top_tokens: Vec<Vec<Token>>,
}

#[derive(Debug, Error)]
pub enum InferError {
    #[error("Request failed during generation: {0}")]
    GenerationError(String),
    #[error("Model is overloaded")]
    Overloaded(#[from] TryAcquireError),
    #[error("Input validation error: {0}")]
    ValidationError(#[from] ValidationError),
    #[error("Incomplete generation")]
    IncompleteGeneration,
    #[error("Template error: {0}")]
    TemplateError(#[from] minijinja::Error),
}

impl InferError {
    pub(crate) fn error_type(&self) -> &str {
        match self {
            InferError::GenerationError(_) => "generation",
            InferError::Overloaded(_) => "overloaded",
            InferError::ValidationError(_) => "validation",
            InferError::IncompleteGeneration => "incomplete_generation",
            InferError::TemplateError(_) => "template_error",
        }
    }
}

// tests
#[cfg(test)]
mod tests {
    use crate::infer::raise_exception;
    use crate::ChatTemplateInputs;
    use crate::Message;
    use minijinja::Environment;

    #[test]
    fn test_chat_template() {
        let env = Environment::new();

        let source = r#"
        {% for message in messages %}
            {% if message['role'] == 'system' %}
                {% if message['content']%}
                    {{'### System:\n' + message['content']+'\n\n'}}
                {% endif %}
            {% elif message['role'] == 'user' %}
                {{'### User:\n' + message['content']+'\n\n'}}
            {% elif message['role'] == 'assistant' %}
                {{'### Assistant:\n'  + message['content']}}
            {% endif %}
            {% if loop.last and add_generation_prompt %}
                {{ '### Assistant:\n' }}
            {% endif %}
        {% endfor %}"#;

        // trim all the whitespace
        let source = source
            .lines()
            .map(|line| line.trim())
            .collect::<Vec<&str>>()
            .join("");

        let tmpl = env.template_from_str(&source);

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                Message {
                    role: "user".to_string(),
                    content: Some("Hi!".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("Hello how can I help?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "user".to_string(),
                    content: Some("What is Deep Learning?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("magic!".to_string()),
                    name: None,
                    tool_calls: None,
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
        };

        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();

        assert_eq!(
            result,
            "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
        );
    }

    #[test]
    fn test_chat_template_invalid_with_raise() {
        let mut env = Environment::new();
        env.add_function("raise_exception", raise_exception);

        let source = r#"
        {{ bos_token }}
        {% for message in messages %}
        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
        {% endif %}
        {% if message['role'] == 'user' %}
        {{ '[INST] ' + message['content'] + ' [/INST]' }}
        {% elif message['role'] == 'assistant' %}
        {{ message['content'] + eos_token}}
        {% else %}
        {{ raise_exception('Only user and assistant roles are supported!') }}
        {% endif %}
        {% endfor %}"#;

        // trim all the whitespace
        let source = source
            .lines()
            .map(|line| line.trim())
            .collect::<Vec<&str>>()
            .join("");

        let tmpl = env.template_from_str(&source);

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                Message {
                    role: "user".to_string(),
                    content: Some("Hi!".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "user".to_string(),
                    content: Some("Hi again!".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("Hello how can I help?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "user".to_string(),
                    content: Some("What is Deep Learning?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("magic!".to_string()),
                    name: None,
                    tool_calls: None,
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
        };

        let result = tmpl.unwrap().render(chat_template_inputs);

        match result {
            Ok(_) => panic!("Should have failed"),
            Err(e) => {
                assert_eq!(
                    e.detail().unwrap(),
                    "Conversation roles must alternate user/assistant/user/assistant/..."
                );
            }
        }
    }

    #[test]
    fn test_chat_template_valid_with_raise() {
        let mut env = Environment::new();
        env.add_function("raise_exception", raise_exception);

        let source = r#"
        {{ bos_token }}
        {% for message in messages %}
        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
        {% endif %}
        {% if message['role'] == 'user' %}
        {{ '[INST] ' + message['content'] + ' [/INST]' }}
        {% elif message['role'] == 'assistant' %}
        {{ message['content'] + eos_token}}
        {% else %}
        {{ raise_exception('Only user and assistant roles are supported!') }}
        {% endif %}
        {% endfor %}"#;

        // trim all the whitespace
        let source = source
            .lines()
            .map(|line| line.trim())
            .collect::<Vec<&str>>()
            .join("");

        let tmpl = env.template_from_str(&source);

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                Message {
                    role: "user".to_string(),
                    content: Some("Hi!".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("Hello how can I help?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "user".to_string(),
                    content: Some("What is Deep Learning?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("magic!".to_string()),
                    name: None,
                    tool_calls: None,
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
        };

        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
        assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
    }

    #[test]
    fn test_chat_template_valid_with_add_generation_prompt() {
        let mut env = Environment::new();
        env.add_function("raise_exception", raise_exception);

        let source = r#"
        {% for message in messages %}
        {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
        {% endfor %}
        {% if add_generation_prompt %}
            {{ '<|im_start|>assistant\n' }}
        {% endif %}"#;

        // trim all the whitespace
        let source = source
            .lines()
            .map(|line| line.trim())
            .collect::<Vec<&str>>()
            .join("");

        let tmpl = env.template_from_str(&source);

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                Message {
                    role: "user".to_string(),
                    content: Some("Hi!".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("Hello how can I help?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "user".to_string(),
                    content: Some("What is Deep Learning?".to_string()),
                    name: None,
                    tool_calls: None,
                },
                Message {
                    role: "assistant".to_string(),
                    content: Some("magic!".to_string()),
                    name: None,
                    tool_calls: None,
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
        };

        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
        assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
    }

    struct ChatTemplateTestItem {
        name: &'static str,
        chat_template: &'static str,
        input: ChatTemplateInputs<'static>,
        target: &'static str,
    }

    #[test]
    fn test_many_chat_templates() {
        let example_chat = vec![
            Message {
                role: "user".to_string(),
                content: Some("Hello, how are you?".to_string()),
                name: None,
                tool_calls: None,
            },
            Message {
                role: "assistant".to_string(),
                content: Some("I'm doing great. How can I help you today?".to_string()),
                name: None,
                tool_calls: None,
            },
            Message {
                role: "user".to_string(),
                content: Some("I'd like to show off how chat templating works!".to_string()),
                name: None,
                tool_calls: None,
            },
        ];

        let example_chat_with_system = vec![Message {
            role: "system".to_string(),
            content: Some(
                "You are a friendly chatbot who always responds in the style of a pirate"
                    .to_string(),
            ),
            name: None,
            tool_calls: None,
        }]
        .iter()
        .chain(&example_chat)
        .cloned()
        .collect::<Vec<_>>();

        let test_default_templates = vec![
            ChatTemplateTestItem {
                name: "_base",
                chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some(""),
                },
                target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "blenderbot",
                chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ '  ' }}{% endif %}{% endfor %}{{ eos_token }}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
                },
                target: " Hello, how are you?  I'm doing great. How can I help you today?   I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "blenderbot_small",
                chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ '  ' }}{% endif %}{% endfor %}{{ eos_token }}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
                },
                target: " Hello, how are you?  I'm doing great. How can I help you today?   I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "bloom",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
                },
                target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "gpt_neox",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("<|endoftext|>"),
                },
                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
            },
            ChatTemplateTestItem {
                name: "gpt2",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("<|endoftext|>"),
                },
                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
            },
            ChatTemplateTestItem {
                name: "llama",
                // NOTE: the `.strip()` has been replaced with `| trim` in the following template
                chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat_with_system.clone(),
                    add_generation_prompt: true,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
            },
            ChatTemplateTestItem {
                name: "whisper",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: true,
                    bos_token: Some(""),
                    eos_token: Some("<|endoftext|>"),
                },
                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
            },
        ];

        #[allow(unused_variables)] // name is unused
        for ChatTemplateTestItem {
            name,
            chat_template,
            input,
            target,
        } in test_default_templates
        {
            let mut env = Environment::new();
            env.add_function("raise_exception", raise_exception);
            let tmpl = env.template_from_str(&chat_template);
            let result = tmpl.unwrap().render(input).unwrap();
            assert_eq!(result, target);
        }

        let test_custom_templates = vec![
            ChatTemplateTestItem {
                name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=false)",
                chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat_with_system.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
                },
                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=true)",
                chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: vec![
                        Message {
                            role: "system".to_string(),
                            content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
                            name: None,
                            tool_calls: None,
                        },
                        Message {
                            role: "user".to_string(),
                            content: Some("How many helicopters can a human eat in one sitting?".to_string()),
                            name: None,
                            tool_calls: None,
                        },
                    ],
                    add_generation_prompt: true,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
                },
                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
            },
            ChatTemplateTestItem {
                name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
                chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<bos>"),
                    eos_token: Some("<eos>"),
                },
                target: "<bos><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "mistralai/Mistral-7B-Instruct-v0.1",
                chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
            },
            ChatTemplateTestItem {
                name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
                chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]",
            },
            ChatTemplateTestItem {
                name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
                chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "openchat/openchat-3.5-0106",
                // `.title()` has been replaced with the `| title` filter in the following template
                chat_template: "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + (message['role'] | title) + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>",
            },
            ChatTemplateTestItem {
                name: "upstage/SOLAR-10.7B-Instruct-v1.0",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "codellama/CodeLlama-70b-Instruct-hf",
                // NOTE: `.strip()` has been replaced with `| trim` in the following template
                chat_template: "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '<s>' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\\n\\n ' + message['content'] | trim %}{{ content + ' <step> ' }}{% endfor %}{{'Source: assistant\\nDestination: user\\n\\n '}}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<s>Source: user\n\n Hello, how are you? <step> Source: assistant\n\n I'm doing great. How can I help you today? <step> Source: user\n\n I'd like to show off how chat templating works! <step> Source: assistant\nDestination: user\n\n ",
            },
            ChatTemplateTestItem {
                name: "Deci/DeciLM-7B-instruct",
                chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\\n'  + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "Qwen/Qwen1.5-72B-Chat",
                chat_template: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "deepseek-ai/deepseek-llm-7b-chat",
                chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\\n\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|begin▁of▁sentence|>"),
                    eos_token: Some("<|end▁of▁sentence|>"),
                },
                target: "<|begin▁of▁sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\n",
            },
            ChatTemplateTestItem {
                name: "h2oai/h2o-danube-1.8b-chat",
                chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>'  + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "internlm/internlm2-chat-7b",
                chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "TheBloke/deepseek-coder-33B-instruct-AWQ",
                chat_template: "{%- set found_item = false -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set found_item = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|begin▁of▁sentence|>"),
                    eos_token: Some("<|EOT|>"),
                },
                target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n",
            },
            ChatTemplateTestItem {
                name: "ericzzz/falcon-rw-1b-chat",
                // `.strip()` has been replaced with `| trim` in the following template
                chat_template: "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'] | trim }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim }}{% elif message['role'] == 'assistant' %}{{ '[RESP] '  + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|endoftext|>"),
                    eos_token: Some("<|endoftext|>"),
                },
                target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "abacusai/Smaug-34B-v0.1",
                chat_template: "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
            },
            ChatTemplateTestItem {
                name: "maywell/Synatra-Mixtral-8x7B",
                chat_template: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "deepseek-ai/deepseek-coder-33b-instruct",
                chat_template: "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|begin▁of▁sentence|>"),
                    eos_token: Some("</EOT>"),
                },
                target: "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n",
            },
            // NOT INCLUDED
            // - meetkai/functionary-medium-v2.2
            // - fireworks-ai/firefunction-v1
            // https://github
            ChatTemplateTestItem {
                name: "maywell/PiVoT-MoE",
                chat_template: "{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat_with_system.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                },
                target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
            },
        ];

        #[allow(unused_variables)] // name is unused
        for ChatTemplateTestItem {
            name,
            chat_template,
            input,
            target,
        } in test_custom_templates
        {
            let mut env = Environment::new();
            env.add_function("raise_exception", raise_exception);
            // Strip per-line indentation and newlines so the multi-line
            // templates above compile as single-line templates
            let chat_template = chat_template
                .lines()
                .map(|line| line.trim())
                .collect::<Vec<&str>>()
                .join("");

            let tmpl = env.template_from_str(&chat_template);
            let result = tmpl.unwrap().render(input).unwrap();
            assert_eq!(result, target);
        }
    }
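
    // A minimal sketch, not part of the original suite: when a template calls
    // `raise_exception` (as the Mistral and CodeLlama templates above do for
    // out-of-order roles), the error surfaces through `render` as an ordinary
    // minijinja error. The template string here is hypothetical and only
    // exercises that path.
    #[test]
    fn test_chat_template_raise_exception_surfaces_as_render_error() {
        let mut env = Environment::new();
        env.add_function("raise_exception", raise_exception);
        let tmpl = env
            .template_from_str(
                "{% if messages[0]['role'] != 'user' %}{{ raise_exception('Conversation must start with a user message') }}{% endif %}{{ messages[0]['content'] }}",
            )
            .unwrap();
        let input = ChatTemplateInputs {
            messages: vec![Message {
                role: "assistant".to_string(),
                content: Some("Ahoy, matey!".to_string()),
                name: None,
                tool_calls: None,
            }],
            add_generation_prompt: false,
            bos_token: Some(""),
            eos_token: Some(""),
        };
        let err = tmpl.render(input).unwrap_err();
        assert!(err
            .to_string()
            .contains("Conversation must start with a user message"));
    }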
}