l0.rs 5.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
//! L0 Cache: Whole-string exact match cache
//!
//! This is the simplest and most effective cache layer.
//! Key: input string → Value: full encoding result
//!
//! Expected hit rate: 60-90% for workloads with repeated system prompts

use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};

use dashmap::DashMap;

use super::super::traits::Encoding;

/// L0 cache implementation using DashMap for lock-free reads
///
/// Safe for concurrent use through `&self`: DashMap shards its locks
/// internally, and the hit/miss counters are atomics updated with
/// `Ordering::Relaxed` (statistics only — they synchronize nothing).
pub struct L0Cache {
    /// The cache map: input string → encoding
    // NOTE(review): the map is only ever accessed through `&self` in this
    // file, so the `Arc` wrapper looks unnecessary — confirm no external
    // sharing relies on it before removing.
    map: Arc<DashMap<String, Encoding>>,
    /// Maximum number of entries before eviction
    // Checked in `insert`; `len()` may briefly exceed this under
    // concurrent inserts since the check and insert are not atomic.
    max_entries: usize,
    /// Cache hit counter
    hits: AtomicU64,
    /// Cache miss counter
    misses: AtomicU64,
}

impl L0Cache {
    /// Create a new L0 cache holding at most `max_entries` encodings.
    ///
    /// The initial bucket allocation is capped at 1024 so a very large
    /// configured capacity does not pre-allocate memory that may never
    /// be used; the map still grows on demand up to `max_entries`.
    pub fn new(max_entries: usize) -> Self {
        Self {
            map: Arc::new(DashMap::with_capacity(max_entries.min(1024))),
            max_entries,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Look up `key`, returning a clone of the cached encoding on a hit.
    ///
    /// Hit/miss counters are bumped with `Relaxed` ordering — they are
    /// statistics only and impose no ordering on other memory operations.
    pub fn get(&self, key: &str) -> Option<Encoding> {
        match self.map.get(key) {
            Some(entry) => {
                self.hits.fetch_add(1, Ordering::Relaxed);
                // Clone the value so the shard read-guard held by `entry`
                // is released before this function returns.
                Some(entry.value().clone())
            }
            None => {
                self.misses.fetch_add(1, Ordering::Relaxed);
                None
            }
        }
    }

    /// Insert (or replace) an encoding under `key`.
    ///
    /// When the cache is at capacity and `key` is NOT already present, an
    /// arbitrary entry is evicted first (DashMap keeps no access-order
    /// metadata, so true LRU is not available). Replacing an existing key
    /// skips eviction entirely: the map does not grow in that case, and
    /// the previous behavior of evicting unconditionally dropped an
    /// unrelated entry for no benefit.
    pub fn insert(&self, key: String, value: Encoding) {
        if self.map.len() >= self.max_entries && !self.map.contains_key(&key) {
            // Confine the iterator to its own scope so its shard lock is
            // fully released before `remove` re-acquires that shard.
            let victim = { self.map.iter().next().map(|entry| entry.key().clone()) };

            if let Some(k) = victim {
                self.map.remove(&k);
            }
        }

        self.map.insert(key, value);
    }

    /// Get the current number of entries in the cache
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Check if the cache is empty
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Get a snapshot of cache statistics.
    ///
    /// The counters are read independently, so under concurrent traffic the
    /// snapshot is approximate rather than a single consistent instant.
    pub fn stats(&self) -> CacheStats {
        let hits = self.hits.load(Ordering::Relaxed);
        let misses = self.misses.load(Ordering::Relaxed);
        let total_requests = hits + misses;

        CacheStats {
            hits,
            misses,
            entries: self.len(),
            // Guard against 0/0 before any request has been served.
            hit_rate: if total_requests > 0 {
                hits as f64 / total_requests as f64
            } else {
                0.0
            },
        }
    }

    /// Clear the cache and reset the hit/miss counters.
    pub fn clear(&self) {
        self.map.clear();
        self.hits.store(0, Ordering::Relaxed);
        self.misses.store(0, Ordering::Relaxed);
    }

    /// Estimate memory usage in bytes.
    ///
    /// Rough heuristic only: assumes ~2.2 KB per entry (key string plus an
    /// encoding of roughly 250 tokens at 4 bytes each, plus map overhead).
    /// Actual usage varies with key and encoding sizes.
    pub fn memory_usage(&self) -> usize {
        self.len() * 2200
    }
}

/// Point-in-time snapshot of L0 cache counters, produced by `L0Cache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Successful lookups since creation (or the last `clear`).
    pub hits: u64,
    /// Failed lookups since creation (or the last `clear`).
    pub misses: u64,
    /// Number of entries currently cached.
    pub entries: usize,
    /// hits / (hits + misses); 0.0 when no lookups have occurred.
    pub hit_rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::traits::Encoding;

    /// Build a throwaway SentencePiece-variant encoding from raw token ids.
    fn enc(ids: Vec<u32>) -> Encoding {
        Encoding::Sp(ids)
    }

    #[test]
    fn test_basic_get_set() {
        let cache = L0Cache::new(10);

        // Unknown key is a miss.
        assert!(cache.get("hello").is_none());

        cache.insert("hello".to_string(), enc(vec![1, 2, 3]));

        // The same key now hits and round-trips its tokens.
        let hit = cache.get("hello");
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().token_ids(), &[1, 2, 3]);
    }

    #[test]
    fn test_eviction() {
        let cache = L0Cache::new(2);

        cache.insert("a".to_string(), enc(vec![1]));
        cache.insert("b".to_string(), enc(vec![2]));

        // A third distinct key forces an eviction at capacity.
        cache.insert("c".to_string(), enc(vec![3]));

        // The cache never grows past its configured capacity.
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_stats() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), enc(vec![1, 2, 3]));

        let _ = cache.get("missing"); // one miss
        let _ = cache.get("test"); // one hit

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.hit_rate, 0.5);
    }

    #[test]
    fn test_clear() {
        let cache = L0Cache::new(10);
        cache.insert("test".to_string(), enc(vec![1, 2, 3]));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.get("test").is_none());
    }

    #[test]
    fn test_concurrent_access() {
        use std::thread;

        let cache = Arc::new(L0Cache::new(1000));

        // Ten threads, each inserting and then reading back its own key.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&cache);
                thread::spawn(move || {
                    let key = format!("key_{}", i);
                    shared.insert(key.clone(), enc(vec![i as u32]));
                    assert!(shared.get(&key).is_some());
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Every thread's entry survived.
        assert_eq!(cache.len(), 10);
    }
}