Unverified commit 332530ba, authored by Rickard, committed by GitHub
Browse files

quantize_block C->C++, use std::thread everywhere (#1024)

parent 8c507d92
#include <common.h> #include <common.h>
#include <float.h> #include <float.h>
void *quantize_block(void *arguments) { void quantize_block(const quantize_block_args& args) {
// 1. find absmax in block // 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0] // 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value // 3. do binary search to find the closest value
// 4. check minimal distance // 4. check minimal distance
// 5. store index // 5. store index
struct quantize_block_args *args = (quantize_block_args *) arguments;
// 1. find absmax in block // 1. find absmax in block
float absmax_block = -FLT_MAX; float absmax_block = -FLT_MAX;
for (long long i = args->block_idx; i < args->block_end; i++) for (long long i = args.block_idx; i < args.block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i])); absmax_block = fmax(absmax_block, fabs(args.A[i]));
args->absmax[args->block_idx / args->blocksize] = absmax_block; args.absmax[args.block_idx / args.blocksize] = absmax_block;
for (long long i = args->block_idx; i < args->block_end; i++) { for (long long i = args.block_idx; i < args.block_end; i++) {
// 2. divide input value by absmax to normalize into [-1.0, 1.0] // 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value // 3. do binary search to find the closest value
float normed_value = args->A[i] / absmax_block; float normed_value = args.A[i] / absmax_block;
long long idx = args->bin_searcher->scalar(normed_value); long long idx = args.bin_searcher->scalar(normed_value);
// 4. check minimal distance // 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value // The binary search returns always the value to the left, which might not be the closest value
if (idx < 255) { if (idx < 255) {
float dist_left = fabs(normed_value - (args->code[idx])); float dist_left = fabs(normed_value - (args.code[idx]));
float dist_right = fabs(normed_value - (args->code[idx + 1])); float dist_right = fabs(normed_value - (args.code[idx + 1]));
if (dist_right < dist_left) { idx += 1; } if (dist_right < dist_left) { idx += 1; }
} }
// 5. store index // 5. store index
args->out[i] = (unsigned char) idx; args.out[i] = (unsigned char) idx;
} }
return NULL;
} }
...@@ -20,6 +20,6 @@ struct quantize_block_args { ...@@ -20,6 +20,6 @@ struct quantize_block_args {
}; };
void *quantize_block(void *arguments); void quantize_block(const quantize_block_args& args);
#endif #endif
#include <BinSearch.h> #include <BinSearch.h>
#ifdef _WIN32
#include <thread>
#else
#include <pthread.h>
#endif
#include <common.h> #include <common.h>
#include <thread>
using namespace BinSearch; using namespace BinSearch;
...@@ -30,21 +26,13 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long ...@@ -30,21 +26,13 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code); BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
int thread_wave_size = 256; int thread_wave_size = 256;
// we chunk the thresds into waves of 256 since the max limit is // we chunk the threads into waves of 256 since the max limit is
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{ {
long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
#ifdef _WIN32 std::vector<std::thread> threads(valid_chunks);
std::thread *threads = (std::thread *) malloc(sizeof(std::thread) * valid_chunks); std::vector<quantize_block_args> args(valid_chunks);
#else
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
#endif
struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
for(long long i = 0; i < valid_chunks; i++)
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
int chunks_processed = 0; int chunks_processed = 0;
for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
...@@ -52,39 +40,24 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long ...@@ -52,39 +40,24 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
long long block_end = block_idx + valid_items; long long block_end = block_idx + valid_items;
struct quantize_block_args *arg = args[chunks_processed]; struct quantize_block_args& arg = args[chunks_processed];
arg->bin_searcher = &bin_searcher; arg.bin_searcher = &bin_searcher;
arg->code = code; arg.code = code;
arg->A = A; arg.A = A;
arg->absmax = absmax; arg.absmax = absmax;
arg->out = out; arg.out = out;
arg->block_end = block_end; arg.block_end = block_end;
arg->block_idx = block_idx; arg.block_idx = block_idx;
arg->threadidx = block_idx / blocksize; arg.threadidx = block_idx / blocksize;
arg->blocksize = blocksize; arg.blocksize = blocksize;
#ifdef _WIN32 threads[chunks_processed] = std::thread([arg] { quantize_block(arg); });
new (&threads[chunks_processed]) std::thread(quantize_block, arg);
#else
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
#endif
chunks_processed += 1; chunks_processed += 1;
if(chunks_processed == valid_chunks){ break; } if(chunks_processed == valid_chunks){ break; }
} }
for (int i = 0; i < valid_chunks; i++) for (int i = 0; i < valid_chunks; i++)
{
#ifdef _WIN32
threads[i].join(); threads[i].join();
#else
int err = pthread_join(threads[i], NULL);
#endif
}
free(threads);
for (int i = 0; i < valid_chunks; i++)
free(args[i]);
free(args);
} }
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment