cpu_ops.cpp 3.08 KB
Newer Older
Max Ryabinin's avatar
Max Ryabinin committed
1
#include <BinSearch.h>
James Wyatt's avatar
James Wyatt committed
2
3
4
#ifdef _WIN32
#include <thread>
#else
Max Ryabinin's avatar
Max Ryabinin committed
5
#include <pthread.h>
James Wyatt's avatar
James Wyatt committed
6
#endif
Max Ryabinin's avatar
Max Ryabinin committed
7
8
9
10
#include <common.h>

using namespace BinSearch;

/*
 * Blockwise dequantization: out[i] = code[A[i]] * absmax[i / blocksize] for i in [0, n).
 *
 * @param code      lookup table indexed by the byte values stored in A (so up to 256 entries are read)
 * @param A         n quantized values, each used as an index into code
 * @param absmax    one scale factor per block of `blocksize` elements; the last block may be partial
 * @param out       receives the n dequantized floats
 * @param blocksize number of elements per block; a non-positive value writes nothing
 * @param n         total number of elements
 */
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
    // A non-positive blocksize would never advance (or would run backwards through) the block loop below.
    if (blocksize <= 0)
        return;

    for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
        const long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
        const long long block_end = block_idx + valid_items;
        // The scale is the same for every element of the block, so read it once.
        const float scale = absmax[block_idx / blocksize];
        for (long long i = block_idx; i < block_end; i++)
            out[i] = code[A[i]] * scale;
    }
}

20
21
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
{
Max Ryabinin's avatar
Max Ryabinin committed
22
23
24
25

    // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
    code[0] = -1.0f;

26
27
    long long num_blocks = n / blocksize;
    num_blocks += n % blocksize == 0 ? 0 : 1;
Max Ryabinin's avatar
Max Ryabinin committed
28
29
30
31

    const uint32 elements_code = 256;
    BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);

32
33
34
35
36
    int thread_wave_size = 256;
    // we chunk the thresds into waves of 256 since the max limit is
    // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
    for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
    {
37
      long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
James Wyatt's avatar
James Wyatt committed
38
39
40
#ifdef _WIN32
      std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks);
#else
41
      pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
James Wyatt's avatar
James Wyatt committed
42
#endif
43

44
      struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
45

46
      for(long long i = 0; i < valid_chunks; i++)
47
          args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
Max Ryabinin's avatar
Max Ryabinin committed
48

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
      int chunks_processed = 0;
      for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
      {
          long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
          long long block_end = block_idx + valid_items;

          struct quantize_block_args *arg = args[chunks_processed];
          arg->bin_searcher = &bin_searcher;
          arg->code = code;
          arg->A = A;
          arg->absmax = absmax;
          arg->out = out;
          arg->block_end = block_end;
          arg->block_idx = block_idx;
          arg->threadidx = block_idx / blocksize;
          arg->blocksize = blocksize;

James Wyatt's avatar
James Wyatt committed
66
67
68
#ifdef _WIN32
          new (&threads[chunks_processed]) std::thread(quantize_block, arg);
#else
69
          pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
James Wyatt's avatar
James Wyatt committed
70
#endif
71
          chunks_processed += 1;
72
          if(chunks_processed == valid_chunks){ break; }
73
74
      }

75
      for (int i = 0; i < valid_chunks; i++)
James Wyatt's avatar
James Wyatt committed
76
77
78
79
      {
#ifdef _WIN32
          threads[i].join();
#else
80
          int err = pthread_join(threads[i], NULL);
James Wyatt's avatar
James Wyatt committed
81
82
#endif
      }
83
      free(threads);
84
      for (int i = 0; i < valid_chunks; i++)
85
86
87
88
          free(args[i]);
      free(args);

    }
Max Ryabinin's avatar
Max Ryabinin committed
89

Max Ryabinin's avatar
Max Ryabinin committed
90
}