// mtmd-audio.cpp
#define _USE_MATH_DEFINES // for M_PI
#include "mtmd-audio.h"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <thread>
#include <vector>
#include <fstream>
#include <algorithm>

// most of the code here is copied from whisper.cpp

// set to true to print the non-zero mel filter weights and dump the spectrogram to a JSON file
constexpr bool DEBUG = false;

// Mel filter bank: row-major matrix of n_mel rows, each holding one weight per FFT bin.
// It is zero-initialized so a default-constructed instance is in a well-defined (empty) state.
struct mtmd_audio_mel_filters {
    int32_t n_mel = 0; // number of mel bands (rows)
    int32_t n_fft = 0; // FFT length the bank was built for

    std::vector<float> data;
};

// note: this global cache is shared among all preprocessors
//       if we want to use multiple preprocessors at the same time,
//       we will need to enclose it in the preprocessor class in the future
static struct mtmd_audio_global_cache {
    // precomputed sin/cos table for FFT
    std::vector<float> sin_vals;
    std::vector<float> cos_vals;

    // hann window
    std::vector<float> hann_window;

    // mel filter bank
    mtmd_audio_mel_filters filters;

    void fill_sin_cos_table(int n) {
        sin_vals.resize(n);
        cos_vals.resize(n);
        for (int i = 0; i < n; i++) {
            double theta = (2 * M_PI * i) / n;
42
43
44
45
46
            sin_vals[i] = sinf(theta);
            cos_vals[i] = cosf(theta);
        }
    }

47
48
    void fill_hann_window(int length, bool periodic) {
        hann_window.resize(length);
49
50
51
52
53
        int offset = -1;
        if (periodic) {
            offset = 0;
        }
        for (int i = 0; i < length; i++) {
54
            hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
55
56
        }
    }
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

    // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
    // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
    void fill_mel_filterbank_matrix(
        int n_mel,
        int n_fft,
        int sample_rate,            // e.g. 16000
        float fmin = 0.0f,          // e.g. 0.0
        float fmax = -1.0f,         // e.g. sr/2; pass -1 for auto
        bool slaney_area_norm = true,
        float scale = 1.0f          // optional extra scaling; use 1.0f/1000.0f to mimic your code
    ) {
        GGML_ASSERT(n_mel > 0 && n_fft > 1);
        if (fmax <= 0.0f) {
            fmax = 0.5f * sample_rate;
        }

        // Slaney scale (matches librosa default)
        const double min_log_hz = 1000.0;
        const double lin_slope = 3 / 200.;
        const double min_log_mel = min_log_hz * lin_slope;
        const double log_step = log(6.4) / 27.0;
        auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double {
            return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step;
        };
        auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double {
            return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step);
        };

        // infer N_fft from n_fft_bins
        const double bin_hz_step = double(sample_rate) / double(n_fft);

        // mel grid: n_mel + 2 edges
        const double m_lo = hz_to_mel(fmin);
        const double m_hi = hz_to_mel(fmax);
        std::vector<double> mel_pts(n_mel + 2);
        for (int i = 0; i < n_mel + 2; ++i) {
            mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1));
        }

        // convert to Hz
        std::vector<double> hz_pts(n_mel + 2);
        for (int i = 0; i < n_mel + 2; ++i) {
            hz_pts[i] = mel_to_hz(mel_pts[i]);
        }

        const int n_fft_bins = n_fft / 2 + 1;

        // filterbank
        std::vector<float> out(n_mel * n_fft_bins, 0);
        for (int m = 0; m < n_mel; ++m) {
            const double f_left   = hz_pts[m];
            const double f_center = hz_pts[m + 1];
            const double f_right  = hz_pts[m + 2];

            const double denom_l = std::max(1e-30, f_center - f_left);
            const double denom_r = std::max(1e-30, f_right  - f_center);
            const double enorm   = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0;

            for (int k = 0; k < n_fft_bins; ++k) {
                const double f = k * bin_hz_step;
                double w = 0.0;
                if (f >= f_left && f <= f_center) {
                    w = (f - f_left) / denom_l;
                } else if (f > f_center && f <= f_right) {
                    w = (f_right - f) / denom_r;
                }
                out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale);
            }
        }

        filters.n_mel = n_mel;
        filters.n_fft = n_fft;
        filters.data  = std::move(out);

        if (DEBUG) { // debug
            for (size_t i = 0; i < filters.data.size(); ++i) {
                if (filters.data[i] != 0.0f) {
                    printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f);
                }
            }
        }
    }
} g_cache;
141
142
143
144

// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
145
146
147
static void dft(const float * in, int N, float * out) {
    const int n_sin_cos_vals = g_cache.sin_vals.size();
    const int sin_cos_step = n_sin_cos_vals / N;
148
149
150
151
152
153

    for (int k = 0; k < N; k++) {
        float re = 0;
        float im = 0;

        for (int n = 0; n < N; n++) {
154
155
156
            int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N
            re += in[n] * g_cache.cos_vals[idx]; // cos(t)
            im -= in[n] * g_cache.sin_vals[idx]; // sin(t)
157
158
159
160
161
162
163
164
165
166
167
        }

        out[k*2 + 0] = re;
        out[k*2 + 1] = im;
    }
}

// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
168
169
static void fft(float * in, int N, float * out) {
    const int n_sin_cos_vals = g_cache.sin_vals.size();
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
    if (N == 1) {
        out[0] = in[0];
        out[1] = 0;
        return;
    }

    const int half_N = N / 2;
    if (N - half_N*2 == 1) {
        dft(in, N, out);
        return;
    }

    float* even = in + N;
    for (int i = 0; i < half_N; ++i) {
        even[i]= in[2*i];
    }
    float* even_fft = out + 2 * N;
    fft(even, half_N, even_fft);

    float* odd = even;
    for (int i = 0; i < half_N; ++i) {
        odd[i] = in[2*i + 1];
    }
    float* odd_fft = even_fft + N;
    fft(odd, half_N, odd_fft);

196
    const int sin_cos_step = n_sin_cos_vals / N;
197
198
    for (int k = 0; k < half_N; k++) {
        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
199
200
        float re =  g_cache.cos_vals[idx]; // cos(t)
        float im = -g_cache.sin_vals[idx]; // sin(t)
201
202
203
204
205
206
207
208
209
210
211
212

        float re_odd = odd_fft[2*k + 0];
        float im_odd = odd_fft[2*k + 1];

        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
    }
}

// Parameters controlling log_mel_spectrogram(). All fields have defaults so a
// default-constructed instance never carries indeterminate values.
struct filter_params {
    int32_t n_mel            = 0;     // number of mel bands (rows of the output)
    int32_t n_fft_bins       = 0;     // FFT bins per frame: frame_size / 2 + 1
    int32_t hann_window_size = 0;     // Hann window length; zero-padded (centered) to frame_size if shorter
    int32_t hop_length       = 0;     // samples between consecutive frames
    int32_t sample_rate      = 0;     // input sample rate in Hz
    bool    center_padding   = false; // true: frame_size/2 zeros on both sides; false: whisper-style padding
    float   preemph          = 0.f;   // pre-emphasis setting; 0.f disables it
    bool    use_natural_log  = false; // true: ln(x + 2^-24); false: log10(max(x, 1e-10))
    bool    norm_per_feature = false; // true: per-mel-band mean/std normalization; false: whisper clamp + rescale
};

225
226
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                              int n_samples, int frame_size, int frame_step, int n_threads,
227
                                              const filter_params & params, mtmd_audio_mel & out) {
228
229
230
    std::vector<float> fft_in(frame_size * 2, 0.0);
    std::vector<float> fft_out(frame_size * 2 * 2 * 2);

231
    int n_fft_bins = params.n_fft_bins;
232
233
    int i = ith;

234
    const auto & filters = g_cache.filters;
235

236
237
238
    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
    GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2));
    GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size());
239
    // calculate FFT only when fft_in are not all zero
240
    for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) {
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
        const int offset = i * frame_step;

        // apply Hann window (~10% faster)
        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
            fft_in[j] = hann[j] * samples[offset + j];
        }

        // fill the rest with zeros
        if (n_samples - offset < frame_size) {
            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
        }

        // FFT
        fft(fft_in.data(), frame_size, fft_out.data());

        // Calculate modulus^2 of complex numbers
        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
258
        for (int j = 0; j < n_fft_bins; j++) {
259
260
261
262
            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
        }

        // mel spectrogram
263
        for (int j = 0; j < out.n_mel; j++) {
264
265
266
            double sum = 0.0;
            // unroll loop (suggested by GH user @lunixbochs)
            int k = 0;
267
268
            for (k = 0; k < n_fft_bins - 3; k += 4) {
                size_t idx = size_t(j) * size_t(n_fft_bins) + size_t(k);
269
                sum +=
270
271
272
273
                        fft_out[k + 0] * filters.data[idx + 0] +
                        fft_out[k + 1] * filters.data[idx + 1] +
                        fft_out[k + 2] * filters.data[idx + 2] +
                        fft_out[k + 3] * filters.data[idx + 3];
274
275
            }
            // handle n_fft remainder
276
277
            for (; k < n_fft_bins; k++) {
                sum += fft_out[k] * filters.data[j * n_fft_bins + k];
278
            }
279
280
281
282
            sum = params.use_natural_log
                ? log(sum + 5.960464477539063e-08)
                : log10(std::max(sum, 1e-10));
            out.data[j * out.n_len + i] = sum;
283
284
285
286
        }
    }

    // Otherwise fft_out are all zero
287
288
289
290
    double sum = params.use_natural_log ? log(1e-10) : log10(1e-10);
    for (; i < out.n_len; i += n_threads) {
        for (int j = 0; j < out.n_mel; j++) {
            out.data[j * out.n_len + i] = sum;
291
292
293
294
295
296
297
        }
    }
}

// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
static bool log_mel_spectrogram(
        const float * samples,
298
299
300
301
        const int     n_samples_in,
        const int     n_threads,
        const filter_params & params,
        mtmd_audio_mel & out) {
302
303
    //const int64_t t_start_us = ggml_time_us();

304
305
    out.n_len_org = n_samples_in;
    int n_samples = n_samples_in;
306

307
308
309
310
    // Hann window
    const float * hann = g_cache.hann_window.data();
    const int frame_size = (params.n_fft_bins - 1) * 2;
    const int frame_step = params.hop_length;
311

312
    // Padding
313
    std::vector<float> samples_padded;
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
    if (params.center_padding) {
        const auto pad_amount = frame_size / 2;
        samples_padded = std::vector<float>(n_samples + 2 * pad_amount, 0);
        std::copy(samples, samples + n_samples, samples_padded.data() + pad_amount);
        samples = samples_padded.data();
        n_samples = samples_padded.size();
    } else {
        // existing padding logic
        int64_t stage_1_pad = params.sample_rate * 30;
        int64_t stage_2_pad = frame_size / 2;
        samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
        std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
        // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
        std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
        // reflective pad 200 samples at the beginning of audio
        if (n_samples < stage_2_pad + 1) {
            // TODO: Handle short audio differently or return error
            return false;
        }
        std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
    }
335

336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
    // preemphasis
    if (params.preemph) {
        const int pad_amount = frame_size / 2;
        const float preemph = 0.97f;
        float prev = samples_padded[pad_amount];
        for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) {
            float cur = samples_padded[i];
            samples_padded[i] = cur - preemph * prev;
            prev = cur;
        }
    }

    // pad hann window if it's smaller than frame_size
    // TODO: probably unnecessary here? (or better doing it in g_cache?)
    std::vector<float> hann_window_padded;
    if (params.hann_window_size < frame_size) {
        hann_window_padded.resize(frame_size);
        const int padding = (frame_size - params.hann_window_size) / 2;
        std::copy(hann, hann + params.hann_window_size, &hann_window_padded[padding]);
        hann = hann_window_padded.data();
    }
357
358


359
360
361
362
363
364
365
366
367
368
369
370
    out.n_mel = params.n_mel;
    out.n_len = (n_samples - frame_size) / frame_step + 1;
    // TODO: handle these checks better
    if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) {
        LOG_ERR("%s: size overflow\n", __func__);
        return false;
    }
    if (n_samples < frame_size) {
        LOG_ERR("%s: not enough samples after padding\n", __func__);
        return false;
    }
    out.data.resize(out.n_mel * out.n_len);
371
372
373
374
375
376

    {
        std::vector<std::thread> workers(n_threads - 1);
        for (int iw = 0; iw < n_threads - 1; ++iw) {
            workers[iw] = std::thread(
                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
377
378
                    n_samples, frame_size, frame_step, n_threads,
                    std::cref(params), std::ref(out));
379
380
381
        }

        // main thread
382
        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out);
383
384
385
386
387
        for (int iw = 0; iw < n_threads - 1; ++iw) {
            workers[iw].join();
        }
    }

388
389
390
391
392
393
394
395
    const int effective_n_len = n_samples_in / frame_step;
    if (params.norm_per_feature) {
        for (int i = 0; i < out.n_mel; i++) {
            double mean = 0;
            for (int j = 0; j < effective_n_len; ++j) {
                mean += out.data[i * out.n_len + j];
            }
            mean /= effective_n_len;
396

397
398
399
400
401
402
403
404
405
406
407
408
            double var = 0.0;
            for (int j = 0; j < effective_n_len; ++j) {
                const double value = out.data[i * out.n_len + j] - mean;
                var += value * value;
            }
            var /= effective_n_len - 1;  // unbiased
            const double mstd = std::sqrt(var + 1e-5);

            for (int j = 0; j < effective_n_len; ++j) {
                auto &value = out.data[i * out.n_len + j];
                value = (value - mean) / mstd;
            }
409

410
411
412
413
414
415
416
417
418
419
420
421
            // pad the rest with zeros
            for (int j = effective_n_len; j < out.n_len; ++j) {
                out.data[i * out.n_len + j] = 0.0;
            }
        }
    } else {
        // clamping and normalization
        double mmax = -1e20;
        for (int i = 0; i < out.n_mel*out.n_len; i++) {
            if (out.data[i] > mmax) {
                mmax = out.data[i];
            }
422
423
        }

424
425
426
427
428
429
430
431
        mmax -= 8.0;

        for (int i = 0; i < out.n_mel*out.n_len; i++) {
            if (out.data[i] < mmax) {
                out.data[i] = mmax;
            }
            out.data[i] = (out.data[i] + 4.0)/4.0;
        }
432
433
434
    }

    // Dump log_mel_spectrogram
435
    if (DEBUG) {
436
437
        std::ofstream outFile("log_mel_spectrogram.json");
        outFile << "[";
438
439
        for (uint64_t i = 0; i < out.data.size() - 1; i++) {
            outFile << out.data[i] << ", ";
440
        }
441
        outFile << out.data[out.data.size() - 1] << "]";
442
443
444
445
446
447
        outFile.close();
    }

    return true;
}

448
449
450
451
452
453
454
455
456
457
458
459
460
461
//
// mtmd_audio_preprocessor_whisper
//

void mtmd_audio_preprocessor_whisper::initialize() {
    g_cache.fill_sin_cos_table(hparams.audio_n_fft);
    g_cache.fill_hann_window(hparams.audio_window_len, true);
    g_cache.fill_mel_filterbank_matrix(
        hparams.n_mel_bins,
        hparams.audio_n_fft,
        hparams.audio_sample_rate);
}

bool mtmd_audio_preprocessor_whisper::preprocess(
462
463
        const float * samples,
        size_t n_samples,
464
        std::vector<mtmd_audio_mel> & output) {
465
466
467
468
469
    if (n_samples == 0) {
        // empty audio
        return false;
    }

470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
    std::vector<float> smpl;
    // if input is too short, pad with zeros
    // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram
    // TODO: maybe handle this better
    size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin
    if (n_samples < min_samples) {
        smpl.resize(min_samples, 0.0f);
        std::memcpy(smpl.data(), samples, n_samples * sizeof(float));
        samples   = smpl.data();
        n_samples = smpl.size();
    }

    filter_params params;
    params.n_mel            = hparams.n_mel_bins;
    params.n_fft_bins       = 1 + (hparams.audio_n_fft / 2);
    params.hann_window_size = hparams.audio_window_len;
    params.hop_length       = hparams.audio_hop_len;
    params.sample_rate      = hparams.audio_sample_rate;
    params.center_padding   = false;
    params.preemph          = 0.0f; // disabled
    params.use_natural_log  = false;
    params.norm_per_feature = false;

    // make sure the global cache is initialized
    GGML_ASSERT(!g_cache.sin_vals.empty());
    GGML_ASSERT(!g_cache.cos_vals.empty());
    GGML_ASSERT(!g_cache.filters.data.empty());

    mtmd_audio_mel out_full;
499
500
501
502
    bool ok = log_mel_spectrogram(
                samples,
                n_samples,
                4, // n_threads
503
                params,
504
505
506
507
508
509
510
                out_full);
    if (!ok) {
        return false;
    }

    // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
    // we always expect the mel to have 3000 silent frames at the end
511
512
513
    if (DEBUG) {
        printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len);
    }
514
515
516
517
518
519
520
521
    const size_t frames_per_chunk = 3000;
    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
        if ((size_t)n_len < frames_per_chunk) {
            break; // last uncomplete chunk will always be a padded chunk, safe to ignore
        }

522
        mtmd_audio_mel out_chunk;
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
        out_chunk.n_len     = n_len;
        out_chunk.n_mel     = out_full.n_mel;
        out_chunk.n_len_org = out_full.n_mel; // unused
        out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);

        for (int i = 0; i < out_full.n_mel; i++) {
            auto src = out_full.data.begin() + i*out_full.n_len + off;
            out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
        }

        output.push_back(std::move(out_chunk));
    }

    return true;
}