prompt.rs 3.48 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Biswa Panda's avatar
Biswa Panda committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// SPDX-License-Identifier: Apache-2.0

//! Prompt Formatting Module
//!
//! Handles formatting of LLM request prompts, including:
//! - Chat template rendering
//! - Tool usage formatting
//! - Generation prompt handling
//!
//! The module supports different prompt formatting strategies through the
//! [`PromptFormatter`] enum.

// TODO:
// 1. Query if `add_generation_prompt` is present in the prompt template
// 2. Support for models with add_generation_prompt:
//    - PALS (Prefix-Assisted Language Sampling)
//    - Continuation - Detected on user turns, where we can return
//      partial assistant responses without add_generation_prompt

use anyhow::Result;
use minijinja::value::Value;
23
use std::collections::HashMap;
Biswa Panda's avatar
Biswa Panda committed
24
25
use std::sync::Arc;

26
27
use crate::preprocessor::media::MediaDecoder;

28
pub mod deepseek_v32;
Biswa Panda's avatar
Biswa Panda committed
29
30
31
32
mod template;

pub use template::ContextMixins;

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/// Pre-tokenized prompt input.
///
/// `Single` carries one sequence of token ids; `Batch` carries several
/// sequences.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenInput {
    /// A single token-id sequence.
    Single(Vec<u32>),
    /// Multiple token-id sequences.
    Batch(Vec<Vec<u32>>),
}

/// Raw-text prompt input.
///
/// `Single` carries one prompt string; `Batch` carries several prompt strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextInput {
    /// A single prompt string.
    Single(String),
    /// Multiple prompt strings.
    Batch(Vec<String>),
}

/// Prompt payload of a request: either pre-tokenized ids or raw text.
#[derive(Debug)]
pub enum PromptInput {
    /// Already-tokenized input; see [`TokenInput`].
    Tokens(TokenInput),
    /// Raw text input; see [`TextInput`].
    Text(TextInput),
}

Biswa Panda's avatar
Biswa Panda committed
51
52
/// Trait that defines a request that can map to an OpenAI-like request.
pub trait OAIChatLikeRequest {
53
    fn model(&self) -> String;
Biswa Panda's avatar
Biswa Panda committed
54
    fn messages(&self) -> Value;
55
56
57
58
59
    fn typed_messages(
        &self,
    ) -> Option<&[dynamo_async_openai::types::ChatCompletionRequestMessage]> {
        None
    }
Biswa Panda's avatar
Biswa Panda committed
60
61
62
63
64
65
66
67
    fn tools(&self) -> Option<Value> {
        None
    }
    fn tool_choice(&self) -> Option<Value> {
        None
    }

    fn should_add_generation_prompt(&self) -> bool;
68

69
70
71
72
73
    /// Optional additional args to merge into the chat template context
    fn chat_template_args(&self) -> Option<&HashMap<String, serde_json::Value>> {
        None
    }

74
75
76
77
78
79
80
81
82
83
84
85
86
    /// Returns the type of input for the prompt. Default is Text.
    fn prompt_input_type(&self) -> PromptInput {
        PromptInput::Text(TextInput::Single(String::new()))
    }

    /// Extract tokens if the input is pre-tokenized
    fn extract_tokens(&self) -> Option<TokenInput> {
        None
    }

    fn extract_text(&self) -> Option<TextInput> {
        None
    }
87
88
89
90

    fn media_io_kwargs(&self) -> Option<&MediaDecoder> {
        None
    }
Biswa Panda's avatar
Biswa Panda committed
91
92
93
94
95
96
97
98
99
100
}

/// A strategy for turning an [`OAIChatLikeRequest`] into a prompt string.
///
/// Implementors must be `Send + Sync + 'static` so they can be shared behind an
/// `Arc`, as [`PromptFormatter::OAI`] does.
pub trait OAIPromptFormatter: Send + Sync + 'static {
    /// Whether this formatter supports `add_generation_prompt`.
    fn supports_add_generation_prompt(&self) -> bool;

    /// Renders `req` into the final prompt text.
    ///
    /// # Errors
    ///
    /// Returns an error when this formatter cannot render the request.
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>;
}

/// Top-level prompt formatting strategy.
///
/// The only variant today is [`PromptFormatter::OAI`], which wraps any shared
/// [`OAIPromptFormatter`] implementation.
pub enum PromptFormatter {
    /// An OpenAI-style formatter behind a shared `Arc` handle (`Send + Sync`
    /// via the trait's supertraits).
    OAI(Arc<dyn OAIPromptFormatter>),
}
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

/// Pass-through prompt formatter, used for models that have no `chat_template`.
///
/// No template is rendered; the first message's content is returned unchanged.
#[derive(Default, Debug)]
pub struct NoOpFormatter;

impl OAIPromptFormatter for NoOpFormatter {
    fn supports_add_generation_prompt(&self) -> bool {
        false
    }

    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let messages = req.messages();

        let first_message = messages
            .get_item_by_index(0)
            .map_err(|_| anyhow::Error::msg("No message at index 0 or messages array is empty"))?;

        let content = first_message
            .get_attr("content")
            .map_err(|_| anyhow::Error::msg("First message has no 'content' field"))?;

        let content_str = content
            .as_str()
            .ok_or_else(|| anyhow::Error::msg("Message content is not a string"))?
            .to_string();
        Ok(content_str)
    }
}

impl PromptFormatter {
    pub fn no_op() -> Self {
        Self::OAI(Arc::new(NoOpFormatter))
    }
}