// cpu_ops.cpp
#include <BinSearch.h>
#include <common.h>

#include <thread>
#include <vector>

using namespace BinSearch;

/**
 * Blockwise dequantization on the CPU.
 *
 * Every byte of `A` is used as an index into the lookup table `code`; the
 * looked-up value is multiplied by the scale of the block the element belongs
 * to, i.e. out[i] = code[A[i]] * absmax[i / blocksize].
 *
 * @param code      lookup table indexed by the byte values stored in `A`
 * @param A         quantized input; `n` bytes are read
 * @param absmax    per-block scale; one entry is read for each started block
 * @param out       output buffer; `n` floats are written
 * @param blocksize number of elements per block (the last block may be shorter)
 * @param n         total number of elements to dequantize
 *
 * Does nothing when `n <= 0` or `blocksize <= 0`.
 */
void dequantize_cpu(float* code, unsigned char* A, float* absmax, float* out, long long blocksize, long long n) {
    // A non-positive blocksize would never advance the outer loop.
    if (blocksize <= 0)
        return;

    for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
        // The last block may contain fewer than `blocksize` elements.
        const long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
        const long long block_end = block_idx + valid_items;

        // The scale is the same for every element of the block.
        const float scale = absmax[block_idx / blocksize];
        for (long long i = block_idx; i < block_end; i++)
            out[i] = code[A[i]] * scale;
    }
}

16
void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n) {
Max Ryabinin's avatar
Max Ryabinin committed
17
18
19
20

    // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
    code[0] = -1.0f;

21
22
    long long num_blocks = n / blocksize;
    num_blocks += n % blocksize == 0 ? 0 : 1;
Max Ryabinin's avatar
Max Ryabinin committed
23
24
25
26

    const uint32 elements_code = 256;
    BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);

27
    int thread_wave_size = 256;
28
    // we chunk the threads into waves of 256 since the max limit is
29
    // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
    for (long long offset = 0; offset < num_blocks; offset += thread_wave_size) {
        long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
        std::vector<std::thread> threads(valid_chunks);
        std::vector<quantize_block_args> args(valid_chunks);

        int chunks_processed = 0;
        for (long long block_idx = offset * blocksize; block_idx < n; block_idx += blocksize) {
            long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
            long long block_end = block_idx + valid_items;

            struct quantize_block_args& arg = args[chunks_processed];
            arg.bin_searcher = &bin_searcher;
            arg.code = code;
            arg.A = A;
            arg.absmax = absmax;
            arg.out = out;
            arg.block_end = block_end;
            arg.block_idx = block_idx;
            arg.threadidx = block_idx / blocksize;
            arg.blocksize = blocksize;

            threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
            chunks_processed += 1;
            if (chunks_processed == valid_chunks) {
                break;
            }
        }

        for (int i = 0; i < valid_chunks; i++)
            threads[i].join();
60
    }
Max Ryabinin's avatar
Max Ryabinin committed
61
}