"lib/llm/src/gguf/gguf_tokenizer.rs" did not exist on "d29f7fcc820e4eca241430fa3e4cfd9edd172097"
config.rs 2.46 KB
Newer Older
Graham King's avatar
Graham King committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use derive_builder::Builder;
use serde::{Deserialize, Serialize};

/// Executor settings, constructed either through [`ExecutorConfig::builder`] or
/// [`ExecutorConfig::new`].
///
/// Every `Option` field carries `skip_serializing_if = "Option::is_none"`, so only options that
/// were explicitly set show up in the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize, Builder)]
pub struct ExecutorConfig {
    /// Location of the model. This is the one field the builder requires: it has no
    /// `#[builder(default)]`, so `build()` fails if it was never set.
    model_path: String,

    /// Logging verbosity; the builder falls back to [`LogLevel::Error`] when it is not set.
    #[builder(default = "LogLevel::Error")]
    log_level: LogLevel,

    /// Optional chunked-context switch; left out of serialized output when `None`.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    enable_chunked_context: Option<bool>,

    /// Optional log-prob normalization switch; left out of serialized output when `None`.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    normalize_log_probs: Option<bool>,

    /// Optional cap on iteration statistics; left out of serialized output when `None`.
    /// NOTE(review): exact semantics are defined by the consumer of this config — confirm.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    iter_stats_max_iterations: Option<u32>,

    /// The number of processes for tensor parallelism. Defaults to 1.
    /// When `None` the field is not serialized at all, so the default of 1 is presumably
    /// applied by whatever reads the serialized config — confirm against the consumer.
    #[builder(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    tensor_parallel_size: Option<u32>,
}

/// Logging verbosity for the executor.
///
/// Serialized (and deserialized) as the lowercase variant name, e.g. `LogLevel::Warn` <->
/// `"warn"`. `PartialEq`/`Eq` are derived so callers can compare levels directly.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    /// The default level (`#[default]`).
    #[default]
    Error,
    /// Serialized as `"warn"`.
    Warn,
    /// Serialized as `"info"`.
    Info,
    /// Serialized as `"debug"`.
    Debug,
    /// Serialized as `"trace"`.
    Trace,
}

impl From<&str> for LogLevel {
    fn from(value: &str) -> Self {
        match value.to_lowercase().as_str() {
            "error" => LogLevel::Error,
            "warn" => LogLevel::Warn,
            "info" => LogLevel::Info,
            "debug" => LogLevel::Debug,
            "trace" => LogLevel::Trace,
            _ => LogLevel::default(), // Default to Error if no match
        }
    }
}

impl ExecutorConfig {
    /// Starts an [`ExecutorConfigBuilder`] with nothing set; `model_path` must be supplied
    /// before `build()` succeeds.
    pub fn builder() -> ExecutorConfigBuilder {
        Default::default()
    }

    /// Creates a config for `model_path` with `log_level` set to [`LogLevel::Error`] and every
    /// optional setting left as `None`.
    pub fn new(model_path: String) -> Self {
        Self {
            model_path,
            // Spelled out explicitly (rather than inherited from `Default`) so it stays in
            // lockstep with the builder's `#[builder(default = "LogLevel::Error")]`.
            log_level: LogLevel::Error,
            ..Self::default()
        }
    }
}