// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Prompt Formatting Module
//!
//! Handles formatting of LLM request prompts, including:
//! - Chat template rendering
//! - Tool usage formatting
//! - Generation prompt handling
//!
//! The module supports different prompt formatting strategies through the
//! [`PromptFormatter`] enum.

// TODO:
// 1. Query if `add_generation_prompt` is present in the prompt template
// 2. Support for models with add_generation_prompt:
//    - PALS (Prefix-Assisted Language Sampling)
//    - Continuation - Detected on user turns, where we can return
//      partial assistant responses without add_generation_prompt

use anyhow::Result;
use minijinja::value::Value;
23
use std::collections::HashMap;
Biswa Panda's avatar
Biswa Panda committed
24
25
use std::sync::Arc;

26
27
use crate::preprocessor::media::MediaDecoder;

28
pub mod deepseek_v32;
Biswa Panda's avatar
Biswa Panda committed
29
30
mod template;

31
pub use template::{ChatTemplate, ContextMixins};
Biswa Panda's avatar
Biswa Panda committed
32

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/// Pre-tokenized prompt input: either one token sequence or a batch of them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenInput {
    /// A single sequence of token ids.
    Single(Vec<u32>),
    /// Several token-id sequences supplied together.
    Batch(Vec<Vec<u32>>),
}

/// Raw text prompt input: either one string or a batch of strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextInput {
    /// A single prompt string.
    Single(String),
    /// Several prompt strings supplied together.
    Batch(Vec<String>),
}

/// The form in which a prompt is supplied: token ids or text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptInput {
    /// The prompt is already tokenized.
    Tokens(TokenInput),
    /// The prompt is plain text.
    Text(TextInput),
}

Biswa Panda's avatar
Biswa Panda committed
51
52
/// Trait that defines a request that can map to an OpenAI-like request.
pub trait OAIChatLikeRequest {
53
    fn model(&self) -> String;
Biswa Panda's avatar
Biswa Panda committed
54
    fn messages(&self) -> Value;
55
    fn typed_messages(&self) -> Option<&[dynamo_protocols::types::ChatCompletionRequestMessage]> {
56
57
        None
    }
Biswa Panda's avatar
Biswa Panda committed
58
59
60
61
62
63
    fn tools(&self) -> Option<Value> {
        None
    }
    fn tool_choice(&self) -> Option<Value> {
        None
    }
64
65
66
    fn response_format(&self) -> Option<Value> {
        None
    }
Biswa Panda's avatar
Biswa Panda committed
67
68

    fn should_add_generation_prompt(&self) -> bool;
69

70
71
72
73
74
    /// Optional additional args to merge into the chat template context
    fn chat_template_args(&self) -> Option<&HashMap<String, serde_json::Value>> {
        None
    }

75
76
77
78
79
80
81
82
83
84
85
86
87
    /// Returns the type of input for the prompt. Default is Text.
    fn prompt_input_type(&self) -> PromptInput {
        PromptInput::Text(TextInput::Single(String::new()))
    }

    /// Extract tokens if the input is pre-tokenized
    fn extract_tokens(&self) -> Option<TokenInput> {
        None
    }

    fn extract_text(&self) -> Option<TextInput> {
        None
    }
88
89
90
91

    fn media_io_kwargs(&self) -> Option<&MediaDecoder> {
        None
    }
92
93
94
95

    fn mm_processor_kwargs(&self) -> Option<&serde_json::Value> {
        None
    }
Biswa Panda's avatar
Biswa Panda committed
96
97
98
99
100
101
102
}

/// A formatter that turns an [`OAIChatLikeRequest`] into a prompt string.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait OAIPromptFormatter: Send + Sync + 'static {
    /// Reports whether this formatter supports `add_generation_prompt`.
    fn supports_add_generation_prompt(&self) -> bool;
    /// Renders `req` into the final prompt string.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be rendered.
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>;
}

103
#[derive(Clone)]
Biswa Panda's avatar
Biswa Panda committed
104
105
106
pub enum PromptFormatter {
    OAI(Arc<dyn OAIPromptFormatter>),
}
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

// No-op formatter: used for models without chat_template
#[derive(Debug, Default)]
pub struct NoOpFormatter;

impl OAIPromptFormatter for NoOpFormatter {
    fn supports_add_generation_prompt(&self) -> bool {
        false
    }

    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let messages = req.messages();

        let first_message = messages
            .get_item_by_index(0)
            .map_err(|_| anyhow::Error::msg("No message at index 0 or messages array is empty"))?;

        let content = first_message
            .get_attr("content")
            .map_err(|_| anyhow::Error::msg("First message has no 'content' field"))?;

        let content_str = content
            .as_str()
            .ok_or_else(|| anyhow::Error::msg("Message content is not a string"))?
            .to_string();
        Ok(content_str)
    }
}

impl PromptFormatter {
    pub fn no_op() -> Self {
        Self::OAI(Arc::new(NoOpFormatter))
    }
}