config.rs 6.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
use serde::{Deserialize, Serialize};

/// Configuration of a `LlavaNext` model.
///
/// Bundles the text and vision sub-configurations with the list of candidate
/// image resolutions used to compute how many features an image expands to
/// (see `get_number_of_features`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct LlavaNext {
    /// Text model configuration (currently carries no fields).
    text_config: TextConfig,
    /// Vision tower configuration; supplies `image_size` and `patch_size`.
    vision_config: VisionConfig,
    /// Candidate target resolutions, each stored as `(height, width)`.
    image_grid_pinpoints: Vec<(usize, usize)>,
}

/// Computes the shape of the patch grid for an image of the given size.
///
/// The best matching resolution is picked from `grid_pinpoints` (entries are
/// `(height, width)`), and each of its dimensions is divided (integer division)
/// by `patch_size`.
///
/// Returns `(rows, columns)` of the resulting grid.
fn get_anyres_image_grid_shape(
    height: usize,
    width: usize,
    grid_pinpoints: &[(usize, usize)],
    patch_size: usize,
) -> (usize, usize) {
    let best = select_best_resolution(height, width, grid_pinpoints);
    let rows = best.0 / patch_size;
    let columns = best.1 / patch_size;
    (rows, columns)
}

/// Selects the best resolution from a list of possible resolutions based on the original size.
///
/// For every candidate `(height, width)` the original image is scaled, preserving its aspect
/// ratio, by the largest factor that still fits inside the candidate. The *effective*
/// resolution is the area of that scaled image, capped at the original area (upscaling adds
/// no information). The *wasted* resolution is the part of the candidate's area that the
/// scaled image does not cover.
///
/// The best fit is the candidate that maximizes the effective resolution; ties are broken by
/// the smallest wasted resolution, and remaining ties keep the earliest candidate.
///
/// Returns the chosen `(height, width)`, or `(original_height, original_width)` when
/// `possible_resolutions` is empty.
fn select_best_resolution(
    original_height: usize,
    original_width: usize,
    possible_resolutions: &[(usize, usize)],
) -> (usize, usize) {
    let original_area = original_width * original_height;

    let mut best_fit = None;
    let mut max_effective_resolution = 0;
    // Start at the largest possible value so that the tie-break on wasted area can
    // actually trigger for the first candidates.
    let mut min_wasted_resolution = usize::MAX;

    for &(height, width) in possible_resolutions {
        let wscale = width as f64 / original_width as f64;
        let hscale = height as f64 / original_height as f64;
        let scale = wscale.min(hscale);

        // Size of the ORIGINAL image once scaled to fit inside the candidate.
        let downscaled_width = (original_width as f64 * scale) as usize;
        let downscaled_height = (original_height as f64 * scale) as usize;

        let effective_resolution = (downscaled_width * downscaled_height).min(original_area);
        // The scaled image fits inside the candidate, so this never goes negative;
        // `saturating_sub` only guards against floating point rounding.
        let wasted_resolution = (width * height).saturating_sub(effective_resolution);

        let is_better = effective_resolution > max_effective_resolution
            || (effective_resolution == max_effective_resolution
                && wasted_resolution < min_wasted_resolution);
        if is_better {
            max_effective_resolution = effective_resolution;
            min_wasted_resolution = wasted_resolution;
            best_fit = Some((height, width));
        }
    }

    best_fit.unwrap_or((original_height, original_width))
}

Nicolas Patry's avatar
Nicolas Patry committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
/// Computes the number of image features left after removing the padding of the patch grid.
///
/// The grid measures `npatches * num_patch_height` by `npatches * num_patch_width` features.
/// The image (`height` x `width`) is fitted into that grid preserving its aspect ratio, and
/// the padding along the non-matching dimension is removed symmetrically: the same number of
/// rows (or columns), `(current - new) / 2`, is dropped on each side. When the difference is
/// odd, one row (or column) of padding therefore remains.
///
/// Returns `(unpadded_features, newline_features)`, where `unpadded_features` is the area of
/// the unpadded grid and `newline_features` is its height (one extra feature per row).
///
/// # Panics
///
/// Panics on an integer division by zero, e.g. when both `height` and `width` are zero.
fn get_unpadded_features(
    height: usize,
    width: usize,
    npatches: usize,
    num_patch_height: usize,
    num_patch_width: usize,
) -> (usize, usize) {
    let current_height = npatches * num_patch_height;
    let current_width = npatches * num_patch_width;

    let aspect_ratio: f64 = width as f64 / height as f64;
    let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
    let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
        // Image is wider than the grid: padding sits above and below.
        let new_height = (height * current_width) / width;
        let padding = current_height.saturating_sub(new_height) / 2;
        (current_height - 2 * padding, current_width)
    } else {
        // Image is taller than (or matches) the grid: padding sits left and right.
        let new_width = (width * current_height) / height;
        let padding = current_width.saturating_sub(new_width) / 2;
        (current_height, current_width - 2 * padding)
    };

    let unpadded_features = current_height * current_width;
    let newline_features = current_height;
    (unpadded_features, newline_features)
}

85
86
87
88
89
90
91
92
impl LlavaNext {
    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
        let image_size = self.vision_config.image_size;
        let patch_size = self.vision_config.patch_size;
        assert!(image_size % patch_size == 0);
        let npatches = image_size / patch_size;
        let (num_patch_height, num_patch_width) =
            get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
Nicolas Patry's avatar
Nicolas Patry committed
93
94
95

        let (unpadded_features, newline_features) =
            get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
96
97
98
99
100
101
102
103
104
105
106
107
108
109
        // The base patch covers the entire image
        let base_features = npatches.pow(2);
        unpadded_features + newline_features + base_features
    }
}

/// Configuration of a `ClipVisionModel`.
///
/// Only the two size fields below are kept; nothing in this file reads them yet.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
    /// Input image size (presumably the side length in pixels; confirm against the model config).
    image_size: usize,
    /// Patch size (presumably the side length in pixels; confirm against the model config).
    patch_size: usize,
}

Nicolas Patry's avatar
Nicolas Patry committed
110
111
112
113
114
115
116
117
118
119
120
/// Configuration of an `Idefics2` model. No configuration fields are needed here.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}

impl Idefics2 {
    /// Returns the number of features for an image.
    ///
    /// The image dimensions are ignored: the result is always `320`.
    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
        // NOTE(review): hard-coded constant; presumably the fixed per-image feature count
        // of the model — confirm against the model implementation.
        320
    }
}

121
122
123
124
125
126
127
128
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
    LlavaNext(LlavaNext),
    ClipVisionModel(ClipVisionModel),
    Mistral,
    Idefics,
Nicolas Patry's avatar
Nicolas Patry committed
129
    Idefics2(Idefics2),
130
131
132
133
134
    Ssm,
    GptBigcode,
    Santacoder,
    Bloom,
    Mpt,
135
    Gpt2,
136
137
138
139
    GptNeox,
    Phi,
    #[serde(rename = "phi-msft")]
    PhiMsft,
140
    Phi3,
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    Llama,
    Baichuan,
    Gemma,
    Cohere,
    Drbx,
    Falcon,
    Mixtral,
    Starcoder2,
    Qwen2,
    Opt,
    T5,
}

/// Text model sub-configuration. No fields are read from it at the moment.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TextConfig {}

/// Vision model sub-configuration.
///
/// `LlavaNext::get_number_of_features` requires `image_size` to be a multiple of
/// `patch_size` and uses `image_size / patch_size` as the number of patches per side.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct VisionConfig {
    /// Size of the image fed to the vision model.
    image_size: usize,
    /// Size of one patch; must divide `image_size`.
    patch_size: usize,
}

#[cfg(test)]
mod test {
    use super::*;

    /// Checks the feature count of a LLaVA-NeXT configuration for several image sizes.
    #[test]
    fn test_llava_next_features() {
        let vision_config = VisionConfig {
            image_size: 336,
            patch_size: 14,
        };
        let config = LlavaNext {
            text_config: TextConfig {},
            vision_config,
            image_grid_pinpoints: vec![
                (336, 672),
                (672, 336),
                (672, 672),
                (1008, 336),
                (336, 1008),
            ],
        };

        // (height, width, expected number of features)
        let cases = [
            (20, 20, 1176),
            (640, 640, 2928),
            (480, 640, 2340),
            (899, 1024, 2634),
            (1024, 899, 2640),
            (1067, 1600, 2144),
        ];
        for (height, width, expected) in cases {
            let slots = config.get_number_of_features(height, width);
            assert_eq!(slots, expected);
        }
    }
}