// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! NIXL backend configuration with Figment support.
//!
//! This module provides configuration extraction for NIXL backends from
//! environment variables with the pattern: `DYN_KVBM_NIXL_BACKEND_<backend>=<value>`

use anyhow::{Result, bail};
10
11
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
12
13
14
15
16
17
18
19
20
21
22

use dynamo_config::parse_bool;

/// Configuration for NIXL backends.
///
/// Supports extracting backend configurations from environment variables:
/// - `DYN_KVBM_NIXL_BACKEND_UCX=true` - Enable UCX backend with default params
/// - `DYN_KVBM_NIXL_BACKEND_GDS=false` - Explicitly disable GDS backend
/// - Valid values: true/false, 1/0, on/off, yes/no (case-insensitive)
/// - Invalid values (e.g., "maybe", "random") will cause an error
/// - Custom params (e.g., `DYN_KVBM_NIXL_BACKEND_UCX_PARAM1=value`) will cause an error
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
///
/// # Data Structure
///
/// Uses a single HashMap where:
/// - Key presence = backend is enabled
/// - Value (inner HashMap) = backend-specific parameters (empty = defaults)
///
/// # TOML Example
///
/// ```toml
/// [backends.UCX]
/// # UCX with default params (empty map)
///
/// [backends.GDS]
/// threads = "4"
/// buffer_size = "1048576"
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
41
pub struct NixlBackendConfig {
42
43
44
45
46
47
48
    /// Map of backend name (uppercase) -> optional parameters.
    ///
    /// If a backend is present in the map, it's enabled.
    /// The inner HashMap contains optional override parameters.
    /// An empty inner map means use default parameters.
    #[serde(default)]
    backends: HashMap<String, HashMap<String, String>>,
49
50
51
}

impl NixlBackendConfig {
52
53
54
55
56
    /// Creates a new configuration with the given backends.
    ///
    /// For an empty configuration with no backends, use [`Default::default()`].
    pub fn new(backends: HashMap<String, HashMap<String, String>>) -> Self {
        Self { backends }
57
58
59
60
61
62
63
64
65
66
67
    }

    /// Create configuration from environment variables.
    ///
    /// Extracts backends from `DYN_KVBM_NIXL_BACKEND_<backend>=<value>` variables.
    ///
    /// # Errors
    /// Returns an error if:
    /// - Custom parameters are detected (not yet supported)
    /// - Invalid boolean values are provided (must be truthy or falsey)
    pub fn from_env() -> Result<Self> {
68
        let mut backends = HashMap::new();
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

        // Extract all environment variables that match our pattern
        for (key, value) in std::env::vars() {
            if let Some(remainder) = key.strip_prefix("DYN_KVBM_NIXL_BACKEND_") {
                // Check if there's an underscore (indicating custom params)
                if remainder.contains('_') {
                    bail!(
                        "Custom NIXL backend parameters are not yet supported. \
                         Found: {}. Please use only DYN_KVBM_NIXL_BACKEND_<backend>=true \
                         to enable backends with default parameters.",
                        key
                    );
                }

                // Simple backend enablement (e.g., DYN_KVBM_NIXL_BACKEND_UCX=true)
                let backend_name = remainder.to_uppercase();
                match parse_bool(&value) {
                    Ok(true) => {
87
                        backends.insert(backend_name, HashMap::new());
88
89
90
91
92
93
94
95
96
97
98
99
100
                    }
                    Ok(false) => {
                        // Explicitly disabled, don't add to backends
                        continue;
                    }
                    Err(e) => bail!("Invalid value for {}: {}", key, e),
                }
            }
        }

        Ok(Self { backends })
    }

101
102
    /// Add a backend with default parameters.
    /// Backend name is normalized to uppercase.
103
    pub fn with_backend(mut self, backend: impl Into<String>) -> Self {
104
105
        self.backends
            .insert(backend.into().to_uppercase(), HashMap::new());
106
107
108
        self
    }

109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    /// Add a backend with custom parameters.
    /// Backend name is normalized to uppercase.
    pub fn with_backend_params(
        mut self,
        backend: impl Into<String>,
        params: HashMap<String, String>,
    ) -> Self {
        self.backends.insert(backend.into().to_uppercase(), params);
        self
    }

    /// Get the list of enabled backend names (uppercase).
    pub fn backends(&self) -> Vec<String> {
        self.backends.keys().cloned().collect()
    }

    /// Get parameters for a specific backend.
    /// Backend name is normalized to uppercase for lookup.
    ///
    /// Returns None if the backend is not enabled.
    pub fn backend_params(&self, backend: &str) -> Option<&HashMap<String, String>> {
        self.backends.get(&backend.to_uppercase())
131
132
133
134
    }

    /// Check if a specific backend is enabled.
    pub fn has_backend(&self, backend: &str) -> bool {
135
        self.backends.contains_key(&backend.to_uppercase())
136
137
138
139
140
    }

    /// Merge another configuration into this one.
    ///
    /// Backends from the other configuration will be added to this one.
141
    /// If both have the same backend, params from `other` take precedence.
142
143
144
145
    pub fn merge(mut self, other: NixlBackendConfig) -> Self {
        self.backends.extend(other.backends);
        self
    }
146
147
148
149
150

    /// Iterate over all enabled backends and their parameters.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &HashMap<String, String>)> {
        self.backends.iter()
    }
151
152
153
154
155
156
157
158
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `new` with an empty map yields a configuration without backends.
    #[test]
    fn test_new_config_is_empty() {
        let config = NixlBackendConfig::new(HashMap::new());
        assert!(config.backends().is_empty());
        assert_eq!(config.iter().count(), 0);
    }

    /// `Default` yields a configuration without backends.
    #[test]
    fn test_default_is_empty() {
        let config = NixlBackendConfig::default();
        assert!(config.backends().is_empty());
        assert!(!config.has_backend("UCX"));
    }

    /// Backends added by name are found regardless of the case used to query.
    #[test]
    fn test_with_backend() {
        let config = NixlBackendConfig::default()
            .with_backend("ucx")
            .with_backend("gds_mt");

        assert!(config.has_backend("ucx"));
        assert!(config.has_backend("UCX"));
        assert!(config.has_backend("gds_mt"));
        assert!(config.has_backend("GDS_MT"));
        assert!(!config.has_backend("other"));
    }

    /// Default backends carry an empty parameter map; custom ones keep theirs.
    #[test]
    fn test_with_backend_params() {
        let mut params = HashMap::new();
        params.insert("threads".to_string(), "4".to_string());
        params.insert("buffer_size".to_string(), "1048576".to_string());

        let config = NixlBackendConfig::default()
            .with_backend("UCX")
            .with_backend_params("GDS", params);

        // UCX should have empty params
        let ucx_params = config.backend_params("UCX").unwrap();
        assert!(ucx_params.is_empty());

        // GDS should have custom params, reachable with any casing
        let gds_params = config.backend_params("gds").unwrap();
        assert_eq!(gds_params.get("threads"), Some(&"4".to_string()));
        assert_eq!(gds_params.get("buffer_size"), Some(&"1048576".to_string()));

        // A backend that was never added has no params at all
        assert!(config.backend_params("OTHER").is_none());
    }

    /// Merging keeps the backends of both configurations.
    #[test]
    fn test_merge_configs() {
        let config1 = NixlBackendConfig::default().with_backend("ucx");
        let config2 = NixlBackendConfig::default().with_backend("gds");

        let merged = config1.merge(config2);

        assert!(merged.has_backend("ucx"));
        assert!(merged.has_backend("gds"));
        assert_eq!(merged.backends().len(), 2);
    }

    /// When both sides define a backend, the params of `other` win.
    #[test]
    fn test_merge_other_takes_precedence() {
        let mut first = HashMap::new();
        first.insert("threads".to_string(), "1".to_string());
        let mut second = HashMap::new();
        second.insert("threads".to_string(), "2".to_string());

        let merged = NixlBackendConfig::default()
            .with_backend_params("ucx", first)
            .merge(NixlBackendConfig::default().with_backend_params("UCX", second));

        assert_eq!(merged.backends().len(), 1);
        let params = merged.backend_params("ucx").unwrap();
        assert_eq!(params.get("threads"), Some(&"2".to_string()));
    }

    /// Names are stored uppercase, so any casing works for insert and lookup.
    #[test]
    fn test_backend_name_case_insensitive() {
        let config = NixlBackendConfig::default()
            .with_backend("ucx")
            .with_backend("Gds_mt")
            .with_backend("OTHER");

        assert!(config.has_backend("UCX"));
        assert!(config.has_backend("ucx"));
        assert!(config.has_backend("GDS_MT"));
        assert!(config.has_backend("gds_mt"));
        assert!(config.has_backend("OTHER"));
        assert!(config.has_backend("other"));
    }

    /// `iter` visits every enabled backend exactly once with its params.
    #[test]
    fn test_iter() {
        let mut params = HashMap::new();
        params.insert("key".to_string(), "value".to_string());

        let config = NixlBackendConfig::default()
            .with_backend("UCX")
            .with_backend_params("GDS", params);

        // Iteration order is unspecified, so compare the sorted names.
        let mut names: Vec<&str> = config.iter().map(|(name, _)| name.as_str()).collect();
        names.sort_unstable();
        assert_eq!(names, ["GDS", "UCX"]);

        for (name, backend_params) in config.iter() {
            match name.as_str() {
                "UCX" => assert!(backend_params.is_empty()),
                "GDS" => assert_eq!(backend_params.get("key"), Some(&"value".to_string())),
                other => panic!("unexpected backend {other}"),
            }
        }
    }

    // Note: Testing from_env() would require mutating the process-wide
    // environment, which is awkward in unit tests. This is better covered by
    // integration tests.
}