"components/backends/trtllm/engine_configs/decode.yaml" did not exist on "86bc5442b4171d9a7c3de4b854dd07ca1b7a4f65"
cache.rs 4.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use std::path::PathBuf;

use dynamo_runtime::config::environment_names::llm;

/// Handle to an on-disk cache of LoRA adapters.
///
/// Each adapter lives in its own entry directly under `cache_root`; the entry
/// name is the LoRA ID passed to the lookup methods.
#[derive(Clone, Debug)]
pub struct LoRACache {
    /// Directory under which every cached LoRA entry is resolved.
    cache_root: PathBuf,
}

impl LoRACache {
    pub fn new(cache_root: PathBuf) -> Self {
        Self { cache_root }
    }

    /// Build a cache whose root comes from the `DYN_LORA_PATH` environment variable.
    ///
    /// When the variable is not set, the root defaults to
    /// `<home>/.cache/dynamo_loras`, where `<home>` is `$HOME`, then
    /// `$USERPROFILE`, then `/tmp` if neither is set.
    ///
    /// Environment values are read as OS strings, so a path that is not valid
    /// UTF-8 is used verbatim instead of being ignored or lossily re-encoded.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`; the `Result` is part of the
    /// existing signature.
    pub fn from_env() -> Result<Self> {
        let cache_root = std::env::var_os(llm::DYN_LORA_PATH)
            .map(PathBuf::from)
            .unwrap_or_else(|| {
                // No explicit override: derive the default from the user's home directory.
                let home = std::env::var_os("HOME")
                    .or_else(|| std::env::var_os("USERPROFILE"))
                    .map(PathBuf::from)
                    .unwrap_or_else(|| PathBuf::from("/tmp"));
                home.join(".cache").join("dynamo_loras")
            });
        Ok(Self::new(cache_root))
    }

    /// Resolve the on-disk location of the entry for `lora_id`.
    ///
    /// `lora_id` is appended to the cache root as-is; no sanitisation is done
    /// here and the returned path is not checked for existence.
    pub fn get_cache_path(&self, lora_id: &str) -> PathBuf {
        let mut entry = self.cache_root.clone();
        entry.push(lora_id);
        entry
    }

    /// Report whether an entry for `lora_id` exists under the cache root.
    ///
    /// Only existence of the path is checked; its contents are not inspected.
    pub fn is_cached(&self, lora_id: &str) -> bool {
        let entry = self.get_cache_path(lora_id);
        entry.exists()
    }

46
47
48
49
50
51
52
    /// Derive the cache key for a LoRA URI.
    ///
    /// Every `://` becomes `__`; afterwards each `/`, `\` and `.` becomes `_`.
    /// All other characters are kept unchanged.
    ///
    /// This is an associated function (no `self`) so that cache keys are
    /// generated consistently across Rust and Python code.
    pub fn uri_to_cache_key(uri: &str) -> String {
        // The scheme separator is rewritten first so its two slashes are not
        // consumed by the per-character pass below.
        let scheme_flattened = uri.replace("://", "__");
        scheme_flattened
            .chars()
            .map(|ch| match ch {
                '/' | '\\' | '.' => '_',
                other => other,
            })
            .collect()
    }

53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
    /// Check that the cached entry for `lora_id` contains the files a LoRA needs.
    ///
    /// Returns `Ok(true)` only when the entry exists, holds
    /// `adapter_config.json`, and holds at least one recognised weight file.
    /// Every other case yields `Ok(false)`; this implementation never returns `Err`.
    ///
    /// TODO: Add support for other weight file formats supported by trtllm
    pub fn validate_cached(&self, lora_id: &str) -> Result<bool> {
        let adapter_dir = self.get_cache_path(lora_id);

        // The entry itself and its adapter_config.json are both mandatory.
        let has_config =
            adapter_dir.exists() && adapter_dir.join("adapter_config.json").exists();
        if !has_config {
            return Ok(false);
        }

        // Any one of these weight files is sufficient; they are probed in order.
        let weight_candidates = [
            "adapter_model.safetensors",
            "adapter_model.bin",
            "model.lora_weights.npy",
        ];
        let has_weights = weight_candidates
            .iter()
            .any(|file_name| adapter_dir.join(file_name).exists());

        Ok(has_weights)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// The constructor must store the supplied root unchanged.
    #[test]
    fn test_cache_creation() {
        let root = TempDir::new().unwrap();
        let expected_root = root.path();

        let cache = LoRACache::new(expected_root.to_path_buf());

        assert_eq!(cache.cache_root, expected_root);
    }

    /// An entry path is the cache root joined with the LoRA ID.
    #[test]
    fn test_get_cache_path() {
        let root = TempDir::new().unwrap();
        let cache = LoRACache::new(root.path().to_path_buf());

        let expected = root.path().join("my-lora");

        assert_eq!(cache.get_cache_path("my-lora"), expected);
    }

    /// `is_cached` is true for an existing entry and false for a missing one.
    #[test]
    fn test_is_cached() {
        let root = TempDir::new().unwrap();

        // Materialise a single entry directory before querying the cache.
        fs::create_dir(root.path().join("test-lora")).unwrap();

        let cache = LoRACache::new(root.path().to_path_buf());

        assert!(cache.is_cached("test-lora"));
        assert!(!cache.is_cached("non-existent"));
    }

    /// Validation needs both the adapter config and at least one weight file.
    #[test]
    fn test_validate_cached() {
        let root = TempDir::new().unwrap();
        let cache = LoRACache::new(root.path().to_path_buf());

        // Complete entry: config plus a safetensors weight file.
        let complete = root.path().join("valid-lora");
        fs::create_dir(&complete).unwrap();
        fs::write(complete.join("adapter_config.json"), "{}").unwrap();
        fs::write(complete.join("adapter_model.safetensors"), "").unwrap();

        // Incomplete entry: config only, no weight file.
        let config_only = root.path().join("invalid-lora");
        fs::create_dir(&config_only).unwrap();
        fs::write(config_only.join("adapter_config.json"), "{}").unwrap();

        assert!(cache.validate_cached("valid-lora").unwrap());
        assert!(!cache.validate_cached("invalid-lora").unwrap());
    }
131
132
133
134
135
136
137
138
139
140
141
142

    /// Cache keys flatten `://` to `__` and each of `/`, `\`, `.` to `_`.
    #[test]
    fn test_uri_to_cache_key() {
        let cases = [
            // Scheme separator and forward slashes.
            ("s3://bucket/path/to/lora", "s3__bucket_path_to_lora"),
            ("file:///local/path", "file___local_path"),
            // Dots are flattened as well.
            ("hf://org/model.v2", "hf__org_model_v2"),
            // Backslashes are treated like forward slashes.
            ("C:\\models\\lora", "C:_models_lora"),
            // Input without any special character is returned unchanged.
            ("plain-name_1", "plain-name_1"),
            ("", ""),
        ];

        for (uri, expected) in cases.iter() {
            assert_eq!(
                LoRACache::uri_to_cache_key(uri),
                *expected,
                "unexpected cache key for {:?}",
                uri
            );
        }
    }
143
}