// cpu_ops.cpp (last committed by Max Ryabinin)
#include <BinSearch.h>
#include <pthread.h>
#include <common.h>

using namespace BinSearch;

// Blockwise dequantization: out[i] = code[A[i]] * absmax[i / blocksize].
//
//   code      lookup table indexed by the bytes of A; A is unsigned char, so it
//             must hold an entry for every value in A (up to 256 entries).
//   A         n quantized indices.
//   absmax    one scale per block of `blocksize` elements (ceil(n / blocksize) entries).
//   out       n output floats.
//   blocksize number of elements that share one absmax scale; the final block may be shorter.
//   n         total number of elements; n <= 0 writes nothing.
//
// A non-positive blocksize is rejected (nothing is written): it would otherwise never
// advance the block loop and would divide by zero when selecting the scale.
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
    if (blocksize <= 0)
        return;

    for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
        // The last block may be partial when n is not a multiple of blocksize.
        long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
        long long block_end = block_idx + valid_items;
        // Every element of this block shares the same scale factor.
        float scale = absmax[block_idx / blocksize];
        for (long long i = block_idx; i < block_end; i++)
            out[i] = code[A[i]] * scale;
    }
}
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
{
Max Ryabinin's avatar
Max Ryabinin committed
18
19
20
21

    // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
    code[0] = -1.0f;

22
23
    long long num_blocks = n / blocksize;
    num_blocks += n % blocksize == 0 ? 0 : 1;
Max Ryabinin's avatar
Max Ryabinin committed
24
25
26
27

    const uint32 elements_code = 256;
    BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);

28
29
30
31
32
33
34
35
36
37
38
    int thread_wave_size = 256;
    // we chunk the thresds into waves of 256 since the max limit is
    // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
    for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
    {
      pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * thread_wave_size);

      struct quantize_block_args **args = (quantize_block_args **) malloc(thread_wave_size * sizeof(quantize_block_args *));

      for(long long i = 0; i < thread_wave_size; i++)
          args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
Max Ryabinin's avatar
Max Ryabinin committed
39

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
      int chunks_processed = 0;
      for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
      {
          long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
          long long block_end = block_idx + valid_items;

          struct quantize_block_args *arg = args[chunks_processed];
          arg->bin_searcher = &bin_searcher;
          arg->code = code;
          arg->A = A;
          arg->absmax = absmax;
          arg->out = out;
          arg->block_end = block_end;
          arg->block_idx = block_idx;
          arg->threadidx = block_idx / blocksize;
          arg->blocksize = blocksize;

          pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
          chunks_processed += 1;
          if(chunks_processed == thread_wave_size){ break; }
      }

      for (int i = 0; i < thread_wave_size; i++)
          int err = pthread_join(threads[i], NULL);
      
      free(threads);
      for (int i = 0; i < thread_wave_size; i++)
          free(args[i]);
      free(args);

    }
Max Ryabinin's avatar
Max Ryabinin committed
71

Max Ryabinin's avatar
Max Ryabinin committed
72
}