// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Prompt Formatting Module
//!
//! Handles formatting of LLM request prompts, including:
//! - Chat template rendering
//! - Tool usage formatting
//! - Generation prompt handling
//!
//! The module supports different prompt formatting strategies through the
//! [`PromptFormatter`] enum.

// TODO:
// 1. Query if `add_generation_prompt` is present in the prompt template
// 2. Support for models with add_generation_prompt:
//    - PALS (Prefix-Assisted Language Sampling)
//    - Continuation - Detected on user turns, where we can return
//      partial assistant responses without add_generation_prompt

use anyhow::Result;
use minijinja::value::Value;
23
use std::collections::HashMap;
Biswa Panda's avatar
Biswa Panda committed
24
25
use std::sync::Arc;

26
27
use crate::preprocessor::media::MediaDecoder;

28
pub mod deepseek_v32;
29
pub mod deepseek_v4;
Biswa Panda's avatar
Biswa Panda committed
30
31
mod template;

32
pub use template::{ChatTemplate, ContextMixins};
Biswa Panda's avatar
Biswa Panda committed
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
/// Pre-tokenized prompt input, expressed as `u32` token ids.
#[derive(Debug)]
pub enum TokenInput {
    /// One sequence of token ids.
    Single(Vec<u32>),
    /// Several sequences of token ids, one inner `Vec` per sequence.
    Batch(Vec<Vec<u32>>),
}

/// Raw (un-tokenized) prompt input.
#[derive(Debug)]
pub enum TextInput {
    /// One prompt string.
    Single(String),
    /// Several prompt strings.
    Batch(Vec<String>),
}

/// Prompt payload of a request: either already tokenized or still plain text.
#[derive(Debug)]
pub enum PromptInput {
    /// The prompt is supplied as token ids.
    Tokens(TokenInput),
    /// The prompt is supplied as text.
    Text(TextInput),
}

Biswa Panda's avatar
Biswa Panda committed
52
53
/// Trait that defines a request that can map to an OpenAI-like request.
pub trait OAIChatLikeRequest {
54
    fn model(&self) -> String;
Biswa Panda's avatar
Biswa Panda committed
55
    fn messages(&self) -> Value;
56
    fn typed_messages(&self) -> Option<&[dynamo_protocols::types::ChatCompletionRequestMessage]> {
57
58
        None
    }
Biswa Panda's avatar
Biswa Panda committed
59
60
61
62
63
64
    fn tools(&self) -> Option<Value> {
        None
    }
    fn tool_choice(&self) -> Option<Value> {
        None
    }
65
66
67
    fn response_format(&self) -> Option<Value> {
        None
    }
Biswa Panda's avatar
Biswa Panda committed
68
69

    fn should_add_generation_prompt(&self) -> bool;
70

71
72
73
74
75
    /// Optional additional args to merge into the chat template context
    fn chat_template_args(&self) -> Option<&HashMap<String, serde_json::Value>> {
        None
    }

76
77
78
79
80
81
82
83
84
85
86
87
88
    /// Returns the type of input for the prompt. Default is Text.
    fn prompt_input_type(&self) -> PromptInput {
        PromptInput::Text(TextInput::Single(String::new()))
    }

    /// Extract tokens if the input is pre-tokenized
    fn extract_tokens(&self) -> Option<TokenInput> {
        None
    }

    fn extract_text(&self) -> Option<TextInput> {
        None
    }
89
90
91
92

    fn media_io_kwargs(&self) -> Option<&MediaDecoder> {
        None
    }
93
94
95
96

    fn mm_processor_kwargs(&self) -> Option<&serde_json::Value> {
        None
    }
Biswa Panda's avatar
Biswa Panda committed
97
98
99
100
101
102
103
}

/// A prompt formatter for OpenAI-style chat requests.
///
/// Implementors must be `Send + Sync + 'static`; [`PromptFormatter`] stores
/// them behind an `Arc<dyn OAIPromptFormatter>`.
pub trait OAIPromptFormatter: Send + Sync + 'static {
    /// Reports whether this formatter supports `add_generation_prompt`.
    fn supports_add_generation_prompt(&self) -> bool;

    /// Renders `req` into the final prompt string.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be rendered; the exact
    /// conditions are defined by each implementor.
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String>;
}

104
#[derive(Clone)]
Biswa Panda's avatar
Biswa Panda committed
105
106
107
pub enum PromptFormatter {
    OAI(Arc<dyn OAIPromptFormatter>),
}
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141

/// No-op formatter: used for models without chat_template.
///
/// A field-less unit struct; obtain one with `NoOpFormatter` or
/// `NoOpFormatter::default()`.
#[derive(Debug, Default)]
pub struct NoOpFormatter;

impl OAIPromptFormatter for NoOpFormatter {
    fn supports_add_generation_prompt(&self) -> bool {
        false
    }

    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let messages = req.messages();

        let first_message = messages
            .get_item_by_index(0)
            .map_err(|_| anyhow::Error::msg("No message at index 0 or messages array is empty"))?;

        let content = first_message
            .get_attr("content")
            .map_err(|_| anyhow::Error::msg("First message has no 'content' field"))?;

        let content_str = content
            .as_str()
            .ok_or_else(|| anyhow::Error::msg("Message content is not a string"))?
            .to_string();
        Ok(content_str)
    }
}

impl PromptFormatter {
    pub fn no_op() -> Self {
        Self::OAI(Arc::new(NoOpFormatter))
    }
}