//! tree.rs — a radix tree over sequences of `u32` ids in which every node
//! counts how many inserted sequences pass through it.
use std::collections::HashMap;
use std::mem;

/// One node of the radix tree.
///
/// A node stores the run of ids on the edge leading into it, the number of
/// inserted sequences that pass through it, and its child nodes.
#[derive(Debug)]
pub struct Node {
    /// Child nodes, keyed by the first id of each child's `ids`. Every child
    /// must start with a different id, so the first id identifies it uniquely.
    pub children: HashMap<u32, Node>,
    /// The ids stored on this node (empty for the root of a tree).
    pub ids: Vec<u32>,
    /// How many currently stored sequences pass through this node.
    pub count: u32,
}

11
#[derive(Debug)]
12
13
14
15
pub struct RadixTree {
    pub root: Node,
}

/// Returns the number of leading elements that `a` and `b` have in common.
///
/// The result is at most `min(a.len(), b.len())` and is `0` when either slice
/// is empty or the first elements differ.
fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
    // `zip` stops at the shorter slice, `take_while` at the first mismatch.
    a.iter().zip(b).take_while(|(x, y)| x == y).count()
}

impl Default for RadixTree {
    fn default() -> Self {
        Self::new()
    }
}

impl RadixTree {
    pub fn new() -> Self {
        RadixTree {
            root: Node {
                children: HashMap::new(),
                ids: Vec::new(),
                count: 0,
            },
        }
    }

41
    pub fn insert(&mut self, input_ids: &[u32]) {
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
        let mut curr = &mut self.root;
        curr.count += 1;

        let mut curr_idx = 0;
        let input_ids_len = input_ids.len();

        while curr_idx < input_ids_len {
            let first_id = &input_ids[curr_idx];
            // TODO: changing this get_mut causes error
            if curr.children.contains_key(first_id) {
                let child = curr.children.get_mut(first_id).unwrap();

                let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);

                if prefix_len == child.ids.len() {
                    // move curr to child
                    curr = child;
                    curr.count += 1;
                    curr_idx += prefix_len;
                } else {
                    // split child
                    // [child]->... => [child]->[new child]->...
                    let new_child = Node {
                        // to avoid clone: replace child.children with default value (empty vector) and return the original value
                        children: mem::take(&mut child.children),
                        ids: child.ids[prefix_len..].to_vec(),
                        count: child.count,
                    };

                    child.ids = child.ids[..prefix_len].to_vec();
                    child.children = HashMap::new();
                    child.children.insert(new_child.ids[0], new_child);

                    curr = child;
                    curr.count += 1;
                    curr_idx += prefix_len;
                }
            } else {
                // create new child
                let new_child = Node {
                    children: HashMap::new(),
                    ids: input_ids[curr_idx..].to_vec(),
                    count: 0,
                };

                let first_id = new_child.ids[0];
                curr.children.insert(first_id, new_child);

                curr = curr.children.get_mut(&first_id).unwrap();
                curr.count += 1;
                curr_idx = input_ids_len;
            }
        }
    }

97
    pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] {
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
        let mut curr = &self.root;

        let mut curr_idx = 0;
        let input_ids_len = input_ids.len();

        while curr_idx < input_ids_len {
            match curr.children.get(&input_ids[curr_idx]) {
                Some(child) => {
                    let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);

                    if prefix_len == child.ids.len() {
                        curr_idx += prefix_len;
                        curr = child;
                    } else {
                        curr_idx += prefix_len;
                        break;
                    }
                }
                None => {
                    break;
                }
            }
        }

        &input_ids[..curr_idx]
    }

125
    pub fn delete(&mut self, input_ids: &[u32]) {
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
        let mut curr = &mut self.root;
        curr.count -= 1;

        let mut curr_idx = 0;
        let input_ids_len = input_ids.len();

        while curr_idx < input_ids_len {
            let first_id = &input_ids[curr_idx];

            if curr.children.contains_key(first_id) {
                let child = curr.children.get(first_id).unwrap();

                let prefix_len = common_prefix_len(&input_ids[curr_idx..], &child.ids);

                if prefix_len == child.ids.len() {
                    if child.count == 1 {
                        // If count will become 0, remove the child
                        let child = curr.children.get_mut(first_id).unwrap();
                        child.count -= 1;
                        curr.children.remove(first_id);
                        break;
                    } else {
                        // Otherwise decrement count and continue
                        let child = curr.children.get_mut(first_id).unwrap();

                        child.count -= 1;
                        curr = child;
                        curr_idx += prefix_len;
                    }
                } else {
                    panic!("No match found for {:?}", input_ids);
                }
            } else {
                panic!("No match found for {:?}", input_ids);
            }
        }
    }

    // for debug
    pub fn pretty_print(&self) {
        println!("RadixTree:");
        Self::print_node(&self.root, String::from(""));
    }

    fn print_node(node: &Node, prefix: String) {
        // Print current node info with "count" word
        println!("{}└── {:?} (count: {})", prefix, node.ids, node.count);

        // Print children with proper prefixes
        for (i, child) in node.children.values().enumerate() {
            let is_last = i == node.children.len() - 1;
            let child_prefix = if is_last {
                format!("{}    ", prefix) // Add space for last child
            } else {
                format!("{}│   ", prefix) // Add vertical line for other children
            };
            Self::print_node(child, child_prefix);
        }
    }
}