config.rs 6.79 KB
Newer Older
1
2
3
4
5
6
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub struct LlavaNext {
7
8
9
    pub(crate) text_config: TextConfig,
    pub(crate) vision_config: VisionConfig,
    pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
}

/// Computes the tile grid of the best-fitting resolution for an image.
///
/// The best resolution is chosen among `grid_pinpoints` (each `(height, width)`)
/// for an image of `height` x `width`, then both of its dimensions are divided
/// (integer division) by `patch_size`.
///
/// Returns `(best_height / patch_size, best_width / patch_size)`.
fn get_anyres_image_grid_shape(
    height: usize,
    width: usize,
    grid_pinpoints: &[(usize, usize)],
    patch_size: usize,
) -> (usize, usize) {
    let best = select_best_resolution(height, width, grid_pinpoints);
    let rows = best.0 / patch_size;
    let cols = best.1 / patch_size;
    (rows, cols)
}

/// Selects the best resolution from a list of possible resolutions based on the original size.
///
/// For every candidate `(height, width)` the original image is scaled, keeping
/// its aspect ratio, so that it fits inside the candidate. The *effective*
/// resolution is the area of that scaled image, capped at the original area;
/// the *wasted* resolution is the part of the candidate area it does not cover.
/// The best fit is the candidate that maximizes the effective resolution and,
/// on ties, minimizes the wasted resolution. On a full tie the earliest
/// candidate wins.
///
/// Returns the chosen `(height, width)`, or `(original_height, original_width)`
/// when `possible_resolutions` is empty.
fn select_best_resolution(
    original_height: usize,
    original_width: usize,
    possible_resolutions: &[(usize, usize)],
) -> (usize, usize) {
    let mut best_fit = None;
    let mut max_effective_resolution = 0;
    // A running minimum must start at the largest value so that the first
    // candidate always wins the tie-break.
    let mut min_wasted_resolution = usize::MAX;

    for &(height, width) in possible_resolutions {
        let wscale = width as f32 / original_width as f32;
        let hscale = height as f32 / original_height as f32;
        // f32 is only `PartialOrd`, so the smaller scale is picked by hand.
        let scale = if wscale > hscale { hscale } else { wscale };
        // The *original* image is what gets scaled to fit inside the candidate.
        let downscaled_width = (original_width as f32 * scale) as usize;
        let downscaled_height = (original_height as f32 * scale) as usize;
        let effective_resolution = std::cmp::min(
            downscaled_width * downscaled_height,
            original_width * original_height,
        );
        // Saturate so float rounding can never cause an unsigned underflow.
        let wasted_resolution = (width * height).saturating_sub(effective_resolution);

        if effective_resolution > max_effective_resolution
            || (effective_resolution == max_effective_resolution
                && wasted_resolution < min_wasted_resolution)
        {
            max_effective_resolution = effective_resolution;
            min_wasted_resolution = wasted_resolution;
            best_fit = Some((height, width));
        }
    }

    best_fit.unwrap_or((original_height, original_width))
}

Nicolas Patry's avatar
Nicolas Patry committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/// Computes the feature counts of the tiled image once padding is removed.
///
/// The tile grid spans `npatches * num_patch_height` rows and
/// `npatches * num_patch_width` columns. The dimension along which the grid is
/// proportionally larger than the `height` x `width` image is shrunk by twice
/// the (integer) half of the excess; the other dimension is kept as is.
///
/// Returns `(unpadded_features, newline_features)`, where `unpadded_features`
/// is `rows * columns` after shrinking and `newline_features` is the number of
/// remaining rows.
fn get_unpadded_features(
    height: usize,
    width: usize,
    npatches: usize,
    num_patch_height: usize,
    num_patch_width: usize,
) -> (usize, usize) {
    let current_height = npatches * num_patch_height;
    let current_width = npatches * num_patch_width;

    let aspect_ratio: f64 = width as f64 / height as f64;
    let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
    let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
        // The image is relatively wider than the grid: rows are trimmed.
        let new_height = (height * current_width) / width;
        let padding = (current_height - new_height) / 2;
        (current_height - (2 * padding), current_width)
    } else {
        // The image is relatively taller (or equal): columns are trimmed.
        let new_width = (width * current_height) / height;
        let padding = (current_width - new_width) / 2;
        (current_height, current_width - (2 * padding))
    };

    let unpadded_features = current_height * current_width;
    let newline_features = current_height;
    (unpadded_features, newline_features)
}

87
88
89
90
91
92
impl LlavaNext {
    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
        let image_size = self.vision_config.image_size;
        let patch_size = self.vision_config.patch_size;
        assert!(image_size % patch_size == 0);
        let npatches = image_size / patch_size;
93
94
95
        // Dimensions are intentionally swapped to be bug-compatible with
        // upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
        let (num_patch_width, num_patch_height) =
96
            get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
Nicolas Patry's avatar
Nicolas Patry committed
97
98
99

        let (unpadded_features, newline_features) =
            get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
100
101
102
103
104
105
106
107
108
109
110
111
112
        // The base patch covers the entire image
        let base_features = npatches.pow(2);
        unpadded_features + newline_features + base_features
    }
}

/// Configuration carried by the `ClipVisionModel` variant of [`Config`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct ClipVisionModel {
    // NOTE(review): presumably the side length of the input image — confirm.
    image_size: usize,
    // NOTE(review): presumably the side length of one patch — confirm.
    patch_size: usize,
}

Nicolas Patry's avatar
Nicolas Patry committed
113
114
115
116
117
118
/// Configuration carried by the `Idefics2` variant of [`Config`]; it declares
/// no fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Idefics2 {}

impl Idefics2 {
    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
119
        64
Nicolas Patry's avatar
Nicolas Patry committed
120
121
122
    }
}

drbh's avatar
drbh committed
123
124
125
/// Text configuration nested inside [`Paligemma`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct PaliTextConfig {
    /// Value reported by `Paligemma::get_number_of_features` for every image.
    pub(crate) num_image_tokens: usize,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Paligemma {
132
    pub(crate) text_config: PaliTextConfig,
drbh's avatar
drbh committed
133
134
135
136
137
138
139
140
}

impl Paligemma {
    /// Returns the number of features for an image, read from
    /// `text_config.num_image_tokens`; the image dimensions are ignored.
    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
        let text_config = &self.text_config;
        text_config.num_image_tokens
    }
}

141
142
143
144
145
146
147
148
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "model_type")]
#[serde(rename_all = "snake_case")]
pub enum Config {
    LlavaNext(LlavaNext),
    ClipVisionModel(ClipVisionModel),
    Mistral,
    Idefics,
Nicolas Patry's avatar
Nicolas Patry committed
149
    Idefics2(Idefics2),
150
151
152
153
154
    Ssm,
    GptBigcode,
    Santacoder,
    Bloom,
    Mpt,
155
    Gpt2,
156
157
158
159
    GptNeox,
    Phi,
    #[serde(rename = "phi-msft")]
    PhiMsft,
160
    Phi3,
161
162
    Llama,
    Baichuan,
drbh's avatar
drbh committed
163
    Paligemma(Paligemma),
164
    Gemma,
165
    Gemma2,
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
    Cohere,
    Drbx,
    Falcon,
    Mixtral,
    Starcoder2,
    Qwen2,
    Opt,
    T5,
}

/// Text model configuration nested inside [`LlavaNext`]; it declares no fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct TextConfig {}

/// Vision model configuration nested inside [`LlavaNext`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct VisionConfig {
    /// Size of one image tile; must be a multiple of `patch_size` for
    /// `LlavaNext::get_number_of_features` not to panic.
    pub(crate) image_size: usize,
    /// Size of one patch; `image_size / patch_size` is the number of patches
    /// along one side of a tile.
    pub(crate) patch_size: usize,
}

#[cfg(test)]
mod test {
    use super::*;

    /// Checks the feature count of a LLaVA-NeXT configuration (336 image size,
    /// 14 patch size) for several image dimensions.
    #[test]
    fn test_llava_next_features() {
        let config = LlavaNext {
            text_config: TextConfig {},
            vision_config: VisionConfig {
                image_size: 336,
                patch_size: 14,
            },
            image_grid_pinpoints: vec![
                (336, 672),
                (672, 336),
                (672, 672),
                (1008, 336),
                (336, 1008),
            ],
        };

        // (height, width, expected number of features)
        let cases = [
            (20, 20, 1176),
            (640, 640, 2928),
            (480, 640, 2340),
            (899, 1024, 2634),
            (1024, 899, 2640),
            (1067, 1600, 2144),
        ];
        for &(height, width, expected) in cases.iter() {
            let slots = config.get_number_of_features(height, width);
            assert_eq!(slots, expected, "height={height} width={width}");
        }
    }
}