Unverified Commit 564199ba authored by OlivierDehaene, committed by GitHub

feat: update exllamav2 kernels (#1370)


Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 987c959f
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define _config_h #define _config_h
#define MAX_Q_GEMM_ROWS 50 #define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
#define QMODE_2BIT 1 #define QMODE_2BIT 1
#define QMODE_3BIT 1 #define QMODE_3BIT 1
...@@ -10,4 +11,5 @@ ...@@ -10,4 +11,5 @@
#define QMODE_6BIT 0 #define QMODE_6BIT 0
#define QMODE_8BIT 0 #define QMODE_8BIT 0
#endif #endif
...@@ -10,16 +10,19 @@ ...@@ -10,16 +10,19 @@
#include "quant/qdq_6.cuh" #include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh" #include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128 #define GPTQ_BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8 #define GPTQ_BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) #define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
#define EXL2_BLOCK_KN_SIZE 64
#define EXL2_BLOCK_M_SIZE_MAX 8
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
#define CLEAR_N_SIZE 256 #define CLEAR_N_SIZE 256
#include "q_gemm_kernel.cuh" #include "q_gemm_kernel.cuh"
#include "q_gemm_kernel_gptq.cuh" #include "q_gemm_kernel_gptq.cuh"
#include "compat_gemm.cuh"
void gemm_half_q_half_cuda_part void gemm_half_q_half_cuda_part
( (
const half* a, const half* a,
...@@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part ...@@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
int size_n, int size_n,
int size_k, int size_k,
int m_count, int m_count,
bool clear bool clear,
const half* r_weights,
int r_weights_stride,
bool mul_r_weights
) )
{ {
if (!b->is_gptq) if (!b->is_gptq)
{ {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = EXL2_BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
blockDim.z = 1; blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count); gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count); fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
kernel<<<gridDim, blockDim>>> kernel<<<gridDim, blockDim>>>
( (
...@@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part ...@@ -55,7 +61,7 @@ void gemm_half_q_half_cuda_part
size_n, size_n,
size_k, size_k,
b->groups, b->groups,
b->groupsize, b->cuda_q_group_map,
b->cuda_q_perm, b->cuda_q_perm,
b->rows_8, b->rows_8,
b->rows_6, b->rows_6,
...@@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part ...@@ -63,24 +69,27 @@ void gemm_half_q_half_cuda_part
b->rows_4, b->rows_4,
b->rows_3, b->rows_3,
b->rows_2, b->rows_2,
clear clear,
r_weights,
r_weights_stride
); );
} }
else else
{ {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = GPTQ_BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
blockDim.z = 1; blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count); gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
// DBGX((uint64_t) b->cuda_q_perm); // DBGX((uint64_t) r_weights);
// DBGI(b->rows_4); // if (r_weights)
// DBGI(b->height); // print_global_mem(r_weights, 1, 1, 1);
// DBGI(r_weights_stride);
kernel<<<gridDim, blockDim>>> kernel<<<gridDim, blockDim>>>
( (
...@@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part ...@@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
size_n, size_n,
size_k, size_k,
b->groups, b->groups,
b->groupsize, b->gptq_groupsize,
b->cuda_q_perm, b->cuda_q_perm,
b->rows_4, b->rows_4,
clear clear,
r_weights,
r_weights_stride
); );
} }
} }
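Both branches above tile the output the same way; only the block constant differs (EXL2_BLOCK_KN_SIZE = 64 for exl2 weights, GPTQ_BLOCK_KN_SIZE = 128 for GPTQ), and the k dimension is split across blockIdx.z, which is why the kernels accumulate into c with atomicAdd. A minimal Python sketch of the grid math, with illustrative sizes:

```python
# Minimal sketch of the launch geometry used in gemm_half_q_half_cuda_part.
EXL2_BLOCK_KN_SIZE = 64
GPTQ_BLOCK_KN_SIZE = 128

def divide(x, size):
    # Same ceiling division as the DIVIDE macro in util.cuh.
    return (x + size - 1) // size

def launch_geometry(size_m, size_n, size_k, m_count, is_gptq):
    block_kn = GPTQ_BLOCK_KN_SIZE if is_gptq else EXL2_BLOCK_KN_SIZE
    block_dim = (block_kn, 1, 1)              # one thread handles 4 output columns
    grid_dim = (
        divide(size_n, block_kn * 4),         # column tiles
        divide(size_m, m_count),              # m_count rows per block
        divide(size_k, block_kn),             # k slices; partial sums meet via atomicAdd
    )
    return block_dim, grid_dim

print(launch_geometry(size_m=8, size_n=4096, size_k=4096, m_count=8, is_gptq=False))
```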
...@@ -112,13 +123,14 @@ void gemm_half_q_half_cuda ...@@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
int size_k, int size_k,
bool clear, bool clear,
half* temp_dq, half* temp_dq,
bool force_cuda bool force_cuda,
const half* r_weights,
const int r_weights_stride,
bool mul_r_weights
) )
{ {
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
{ {
//printf("cublas\n");
// Reconstruct FP16 matrix, then cuBLAS // Reconstruct FP16 matrix, then cuBLAS
if (!temp_dq) temp_dq = b->temp_dq; if (!temp_dq) temp_dq = b->temp_dq;
...@@ -139,12 +151,12 @@ void gemm_half_q_half_cuda ...@@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
//const float alpha = 1.0f; //const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f; //const float beta = clear ? 0.0f : 1.0f;
//cublasSgemmEx(cublas_handle, //cublasSgemmEx(cublas_handle,
// CUBLAS_OP_N, // CUBLAS_OP_N,
// CUBLAS_OP_N, // CUBLAS_OP_N,
// size_n, size_m, size_k, // size_n, size_m, size_k,
// &alpha, temp_dq, CUDA_R_16F, size_n, // &alpha, temp_dq, CUDA_R_16F, size_n,
// a, CUDA_R_16F, size_k, // a, CUDA_R_16F, size_k,
// &beta, c, CUDA_R_16F, size_n); // &beta, c, CUDA_R_16F, size_n);
//const float alpha = 1.0f; //const float alpha = 1.0f;
//const float beta = clear ? 0.0f : 1.0f; //const float beta = clear ? 0.0f : 1.0f;
...@@ -158,24 +170,21 @@ void gemm_half_q_half_cuda ...@@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
} }
else else
{ {
//printf("cuda\n");
// Quantized matmul // Quantized matmul
//if (clear) clear_tensor_cuda(c, size_m, size_n); int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
int max_chunks = size_m / block_m_size_max;
int max_chunks = size_m / BLOCK_M_SIZE_MAX; int last_chunk = max_chunks * block_m_size_max;
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
int last_chunk_size = size_m - last_chunk; int last_chunk_size = size_m - last_chunk;
if (max_chunks) if (max_chunks)
{ {
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear); gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
} }
if (last_chunk_size) if (last_chunk_size)
{ {
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear); gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
} }
} }
} }
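The dispatch and chunking logic above is compact, so here is a small sketch of the host-side plan it produces: when size_m exceeds MAX_Q_GEMM_ROWS the weights are dequantized and handed to cuBLAS, otherwise size_m is split into full chunks of block_m_size_max plus one remainder launch. Sizes below are made up for illustration.

```python
MAX_Q_GEMM_ROWS = 50   # from config.h

def plan_matmul(size_m, block_m_size_max, force_cuda=False):
    if size_m > MAX_Q_GEMM_ROWS and not force_cuda:
        return ["reconstruct fp16 weights, then cuBLAS"]
    plan = []
    max_chunks = size_m // block_m_size_max
    last_chunk = max_chunks * block_m_size_max
    last_chunk_size = size_m - last_chunk
    if max_chunks:
        # One launch covering rows [0, last_chunk) with m_count = block_m_size_max.
        plan.append(("quant kernel", 0, last_chunk, block_m_size_max))
    if last_chunk_size:
        # Remainder launch with a smaller m_count.
        plan.append(("quant kernel", last_chunk, size_m, last_chunk_size))
    return plan

print(plan_matmul(19, 8))   # rows 0..16 with m_count=8, then rows 16..19 with m_count=3
```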
...@@ -201,11 +210,10 @@ void clear_tensor_cuda ...@@ -201,11 +210,10 @@ void clear_tensor_cuda
int size_n int size_n
) )
{ {
return; // dim3 blockDim, gridDim;
dim3 blockDim, gridDim; // blockDim.x = CLEAR_N_SIZE;
blockDim.x = CLEAR_N_SIZE; // blockDim.y = 1;
blockDim.y = 1; // gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); // gridDim.y = size_m;
gridDim.y = size_m; // clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
} }
...@@ -20,7 +20,10 @@ void gemm_half_q_half_cuda ...@@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
int size_k, int size_k,
bool clear = false, bool clear = false,
half* reconstruct = NULL, half* reconstruct = NULL,
bool force_cuda = false bool force_cuda = false,
const half* r_weights = NULL,
const int r_weights_stride = 0,
bool mul_r_weights = false
); );
void clear_tensor_cuda void clear_tensor_cuda
......
...@@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) ...@@ -18,6 +18,15 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
return __half2float(__low2half(result)) + __half2float(__high2half(result)); return __half2float(__low2half(result)) + __half2float(__high2half(result));
} }
__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return result;
}
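dot22_8_h2 is the fp16 counterpart of dot22_8_f: it multiplies 8 dequantized weights against 8 activations with __hfma2 and keeps the running sum as a half2, leaving the final low-plus-high reduction to the caller. A rough NumPy reference (fp16 rounding will not match the GPU bit for bit):

```python
import numpy as np

def dot22_8_h2_ref(dq8, a8):
    # 8 weights dotted with 8 activations, accumulated as one fp16 pair (half2).
    dq8 = np.asarray(dq8, dtype=np.float16)
    a8 = np.asarray(a8, dtype=np.float16)
    acc = np.zeros(2, dtype=np.float16)
    for i in range(4):
        acc = dq8[2 * i:2 * i + 2] * a8[2 * i:2 * i + 2] + acc   # __hfma2
    return acc                                                    # caller sums low + high

acc = dot22_8_h2_ref(range(8), [0.5] * 8)
print(acc, acc[0] + acc[1])
```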
typedef void (*fp_gemm_half_q_half_gptq_kernel) typedef void (*fp_gemm_half_q_half_gptq_kernel)
( (
const half*, const half*,
...@@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel) ...@@ -32,10 +41,12 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
const int, const int,
const uint16_t*, const uint16_t*,
const int, const int,
const bool const bool,
const half*,
const int
); );
template <bool first_block, int m_count> template <int m_count, bool use_r_weights, bool mul_r_weights>
__global__ void gemm_half_q_half_gptq_kernel __global__ void gemm_half_q_half_gptq_kernel
( (
const half* __restrict__ a, const half* __restrict__ a,
...@@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -50,7 +61,9 @@ __global__ void gemm_half_q_half_gptq_kernel
const int groupsize, const int groupsize,
const uint16_t* __restrict__ b_q_perm, const uint16_t* __restrict__ b_q_perm,
const int rows_4, const int rows_4,
const bool clear const bool clear,
const half* r_weights,
const int r_weights_stride
) )
{ {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
...@@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -62,19 +75,35 @@ __global__ void gemm_half_q_half_gptq_kernel
// Block // Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count; int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE; int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m); int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4; int n = offset_n + t * 4;
// Read weights
half_uint16 weights[MAX_Q_GEMM_WEIGHTS];
if constexpr (use_r_weights)
{
uint16_t any_w = 0;
const half* w_ptr = r_weights;
for (int m = 0; m < m_count; ++m)
{
weights[m].as_half = *w_ptr;
w_ptr += r_weights_stride;
any_w |= weights[m].as_uint16;
}
if (!any_w) return; // Early exit if all weights are zero -- does not zero output (!!!)
}
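With use_r_weights enabled, each block first gathers one fp16 weight per row of its m-tile (stride r_weights_stride between rows) and bails out early when every weight is zero; as the comment warns, that early exit leaves the corresponding output rows untouched rather than zeroing them. A small sketch of the gather, assuming r_weights behaves like a flat array:

```python
def gather_row_weights(r_weights, r_weights_stride, m_count):
    # One weight per row of the m-tile, read with a fixed stride.
    weights = [r_weights[m * r_weights_stride] for m in range(m_count)]
    if not any(weights):
        # Mirrors the kernel's early exit: skip all work for this block.
        return None
    return weights

print(gather_row_weights([1.0, 0.0, 0.25, 0.0], r_weights_stride=1, m_count=4))
```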
// Preload block_a // Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE]; __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];
if (offset_k + t < end_k) if (offset_k + t < end_k)
{ {
...@@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -113,16 +142,16 @@ __global__ void gemm_half_q_half_gptq_kernel
const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0]; const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE; int a_stride = GPTQ_BLOCK_KN_SIZE;
// Initial group // Initial group
int zeros[4]; int zeros[4];
float scales[4]; half2 scales[4];
half2 z1z16[4][2]; half2 z1z16[4][2];
half2 y1y16[4][2]; half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
...@@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -132,7 +161,7 @@ __global__ void gemm_half_q_half_gptq_kernel
// Column result // Column result
float block_c[m_count][4] = {}; half2 block_c[m_count][4] = {};
// Dequantize and multiply // Dequantize and multiply
...@@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -144,7 +173,7 @@ __global__ void gemm_half_q_half_gptq_kernel
group++; group++;
nextgroup += groupsize; nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
...@@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -166,10 +195,11 @@ __global__ void gemm_half_q_half_gptq_kernel
#pragma unroll #pragma unroll
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
} }
b_ptr += size_n; b_ptr += size_n;
...@@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -182,38 +212,62 @@ __global__ void gemm_half_q_half_gptq_kernel
for (int m = 0; m < m_count; m++) for (int m = 0; m < m_count; m++)
{ {
half2 *out = (half2*) c_.item_ptr(offset_m + m, n); half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));
half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));
half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));
half2 result01 = __halves2half2(result0, result1);
half2 result23 = __halves2half2(result2, result3);
if constexpr (mul_r_weights)
{
half2 w_mul2 = __half2half2(weights[m].as_half);
result01 = __hmul2(result01, w_mul2);
result23 = __hmul2(result23, w_mul2);
}
atomicAdd(out , result01); atomicAdd(out , result01);
atomicAdd(out + 1, result23); atomicAdd(out + 1, result23);
} }
} }
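The store path changed along with the accumulator type: each half2 accumulator is reduced to a single value (low lane plus high lane), optionally scaled by the row weight when mul_r_weights is set, and then added into c with atomicAdd because the k dimension is split across blockIdx.z. A rough host-side sketch:

```python
import numpy as np

def store_row(block_c_m, row_weight=None):
    # block_c_m holds 4 half2 accumulators for one output row of the tile.
    results = [np.float16(lo) + np.float16(hi) for lo, hi in block_c_m]
    if row_weight is not None:                      # mul_r_weights path
        results = [np.float16(r) * np.float16(row_weight) for r in results]
    return results                                  # atomically added into 4 output columns

print(store_row([(0.5, 0.25), (1.0, 0.0), (0.125, 0.125), (2.0, 1.0)], row_weight=0.5))
```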
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) template <bool use_r_weights, bool mul_r_weights>
struct map_m_count_gptq {
static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)
{
#if GPTQ_BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;
#endif
#if GPTQ_BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;
#endif
return NULL;
}
};
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)
{ {
#if BLOCK_M_SIZE_MAX >= 1 if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>; if (!r_weights && mul_r_weights) return map_m_count_gptq<false, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
#endif if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);
#if BLOCK_M_SIZE_MAX >= 2 if ( r_weights && mul_r_weights) return map_m_count_gptq< true, true>::pick_gemm_half_q_half_gptq_kernel(m_count);
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL; return NULL;
} }
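The kernel picker now takes two runtime booleans on top of m_count, handled by instantiating every (use_r_weights, mul_r_weights) combination through the map_m_count_gptq helper struct. A rough Python model of the same two-level dispatch (kernel names stand in for the template instantiations):

```python
BLOCK_M_SIZE_MAX = 8   # mirrors GPTQ_BLOCK_M_SIZE_MAX

def make_kernel(m_count, use_r_weights, mul_r_weights):
    # Stand-in for gemm_half_q_half_gptq_kernel<m_count, use_r_weights, mul_r_weights>.
    return f"gptq_kernel<{m_count},{use_r_weights},{mul_r_weights}>"

KERNELS = {
    (w, m): {n: make_kernel(n, w, m) for n in range(1, BLOCK_M_SIZE_MAX + 1)}
    for w in (False, True)
    for m in (False, True)
}

def pick_gemm_half_q_half_gptq_kernel(m_count, r_weights, mul_r_weights):
    return KERNELS[(bool(r_weights), mul_r_weights)].get(m_count)   # None if m_count too large

print(pick_gemm_half_q_half_gptq_kernel(3, r_weights=True, mul_r_weights=False))
```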
...@@ -57,6 +57,7 @@ QMatrix::QMatrix ...@@ -57,6 +57,7 @@ QMatrix::QMatrix
uint32_t* _q_scale, uint32_t* _q_scale,
half* _q_scale_max, half* _q_scale_max,
uint16_t* _q_groups, uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros, uint32_t* _gptq_qzeros,
half* _gptq_scales, half* _gptq_scales,
...@@ -80,13 +81,17 @@ QMatrix::QMatrix ...@@ -80,13 +81,17 @@ QMatrix::QMatrix
cuda_q_scale = _q_scale; cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max; cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups; cuda_q_groups = _q_groups;
cuda_q_group_map = _q_group_map;
cuda_gptq_qzeros = _gptq_qzeros; cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales; cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL); is_gptq = (_gptq_qzeros != NULL);
groupsize = 1; if (is_gptq)
while (groupsize * groups < height) groupsize *= 2; {
gptq_groupsize = 1;
while (gptq_groupsize * groups < height) gptq_groupsize *= 2;
}
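The uniform group size is now only computed for GPTQ tensors (exl2 tensors get a per-row group map instead); it is still inferred as the smallest power of two that covers the height with the given number of groups:

```python
def infer_gptq_groupsize(groups, height):
    # Mirrors the loop above: smallest power of two with groupsize * groups >= height.
    groupsize = 1
    while groupsize * groups < height:
        groupsize *= 2
    return groupsize

print(infer_gptq_groupsize(groups=32, height=4096))   # -> 128
```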
// Create group map // Create group map
...@@ -102,15 +107,26 @@ QMatrix::QMatrix ...@@ -102,15 +107,26 @@ QMatrix::QMatrix
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t)); uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost); cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
int row = 0;
for (int i = 0; i < groups; i++) for (int i = 0; i < groups; i++)
{ {
int bits = cpu_q_groups[i * 2]; int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize; int rows;
if (bits == 5) rows_5 += groupsize; if (i < groups - 1)
if (bits == 4) rows_4 += groupsize; {
if (bits == 3) rows_3 += groupsize; int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];
if (bits == 2) rows_2 += groupsize; rows = qrows * 32 / bits;
}
else rows = height - row;
if (bits == 8) rows_8 += rows;
if (bits == 6) rows_6 += rows;
if (bits == 5) rows_5 += rows;
if (bits == 4) rows_4 += rows;
if (bits == 3) rows_3 += rows;
if (bits == 2) rows_2 += rows;
row += rows;
} }
free(cpu_q_groups); free(cpu_q_groups);
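Per-group row counts are no longer assumed to be a constant groupsize; they are derived from the q_groups table, which in the newer exl2 layout stores a (bits, first packed row) pair per group, so a group's packed-row span times 32 / bits gives its number of weight rows. A Python sketch of that bookkeeping, with a toy table:

```python
def rows_per_group(q_groups, height):
    # q_groups is assumed to hold (bits, first_packed_row) pairs, one per group.
    num_groups = len(q_groups) // 2
    out, row = [], 0
    for i in range(num_groups):
        bits = q_groups[2 * i]
        if i < num_groups - 1:
            qrows = q_groups[2 * i + 3] - q_groups[2 * i + 1]   # packed rows in this group
            rows = qrows * 32 // bits                            # unpacked weight rows
        else:
            rows = height - row                                  # last group takes the rest
        out.append((bits, rows))
        row += rows
    return out

# Two 4-bit groups of 128 rows each, then an 8-bit tail, in a 384-row matrix.
print(rows_per_group([4, 0, 4, 16, 8, 32], height=384))          # [(4, 128), (4, 128), (8, 128)]
```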
...@@ -138,6 +154,13 @@ QMatrix::QMatrix ...@@ -138,6 +154,13 @@ QMatrix::QMatrix
} }
} }
// DBGI(rows_8);
// DBGI(rows_6);
// DBGI(rows_5);
// DBGI(rows_4);
// DBGI(rows_3);
// DBGI(rows_2);
// Shuffle quantized data // Shuffle quantized data
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
...@@ -283,10 +306,10 @@ __global__ void reconstruct_kernel ...@@ -283,10 +306,10 @@ __global__ void reconstruct_kernel
const uint16_t* __restrict__ b_q_perm, const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale, const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max, const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups, const uint16_t* __restrict__ b_q_group_map,
const int size_k, const int size_k,
const int size_n, const int size_n,
const int groupsize, //const int groupsize,
const int groups, const int groups,
half* __restrict__ b, half* __restrict__ b,
const int rows_8, const int rows_8,
...@@ -317,7 +340,8 @@ __global__ void reconstruct_kernel ...@@ -317,7 +340,8 @@ __global__ void reconstruct_kernel
// Find initial group // Find initial group
int group = offset_k / groupsize; // int group = offset_k / groupsize;
int group = b_q_group_map[offset_k * 2];
int pre_rows_8 = min(rows_8, offset_k); int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0; int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
...@@ -337,7 +361,7 @@ __global__ void reconstruct_kernel ...@@ -337,7 +361,7 @@ __global__ void reconstruct_kernel
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h); half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize; int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k; int k = offset_k;
...@@ -347,7 +371,7 @@ __global__ void reconstruct_kernel ...@@ -347,7 +371,7 @@ __global__ void reconstruct_kernel
while (k < rows_8 && k < end_k) while (k < rows_8 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++) for (int p = 0; p < 4; p++)
{ {
half2 dq[4]; half2 dq[4];
...@@ -363,7 +387,7 @@ __global__ void reconstruct_kernel ...@@ -363,7 +387,7 @@ __global__ void reconstruct_kernel
while (k < rows_6 && k < end_k) while (k < rows_6 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++) for (int p = 0; p < 2; p++)
{ {
half2 dq[8]; half2 dq[8];
...@@ -380,7 +404,7 @@ __global__ void reconstruct_kernel ...@@ -380,7 +404,7 @@ __global__ void reconstruct_kernel
while (k < rows_5 && k < end_k) while (k < rows_5 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++) for (int p = 0; p < 1; p++)
{ {
half2 dq[16]; half2 dq[16];
...@@ -399,7 +423,7 @@ __global__ void reconstruct_kernel ...@@ -399,7 +423,7 @@ __global__ void reconstruct_kernel
while (k < rows_4 && k < end_k) while (k < rows_4 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++) for (int p = 0; p < 4; p++)
{ {
half2 dq[4]; half2 dq[4];
...@@ -414,7 +438,7 @@ __global__ void reconstruct_kernel ...@@ -414,7 +438,7 @@ __global__ void reconstruct_kernel
while (k < rows_3 && k < end_k) while (k < rows_3 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++) for (int p = 0; p < 1; p++)
{ {
half2 dq[16]; half2 dq[16];
...@@ -431,8 +455,8 @@ __global__ void reconstruct_kernel ...@@ -431,8 +455,8 @@ __global__ void reconstruct_kernel
while (k < rows_2 && k < end_k) while (k < rows_2 && k < end_k)
{ {
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); } if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++) for (int p = 0; p < 1; p++)
{ {
half2 dq[8]; half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n; uint32_t q_0 = *b_ptr; b_ptr += size_n;
...@@ -441,7 +465,7 @@ __global__ void reconstruct_kernel ...@@ -441,7 +465,7 @@ __global__ void reconstruct_kernel
half* dqh = (half*) dq; half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]); for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
} }
k += 32; k += 16;
} }
} }
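In the reconstruction kernel the fixed groupsize stride is replaced by lookups into b_q_group_map, which stores two values per row: the group index and the number of rows left before the next group begins, so nextgroup advances by map[k * 2 + 1]. This mirrors the make_group_map helper added on the Python side further down in this diff. A small sketch of how the map is walked:

```python
def walk_groups(q_group_map, offset_k, end_k):
    # q_group_map[2 * k] is row k's group, q_group_map[2 * k + 1] the rows left in it.
    group = q_group_map[offset_k * 2]
    nextgroup = offset_k + q_group_map[offset_k * 2 + 1]
    spans = []
    for k in range(offset_k, end_k):
        if k == nextgroup:
            group += 1
            nextgroup += q_group_map[k * 2 + 1]
        spans.append((k, group))
    return spans

# Toy map: group 0 covers rows 0..7, group 1 covers rows 8..15.
q_group_map = [v for k in range(16) for v in (0 if k < 8 else 1, 8 - k % 8)]
print(walk_groups(q_group_map, 0, 16))
```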
...@@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out) ...@@ -461,10 +485,10 @@ void QMatrix::reconstruct(half* out)
cuda_q_perm, cuda_q_perm,
cuda_q_scale, cuda_q_scale,
cuda_q_scale_max, cuda_q_scale_max,
//cuda_q_groups, cuda_q_group_map,
height, height,
width, width,
groupsize, //groupsize,
groups, groups,
out, out,
rows_8, rows_8,
...@@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out) ...@@ -487,7 +511,7 @@ void QMatrix::reconstruct(half* out)
//const uint16_t* __restrict__ b_q_groups, //const uint16_t* __restrict__ b_q_groups,
height, height,
width, width,
groupsize, gptq_groupsize,
groups, groups,
out, out,
rows_4 rows_4
......
...@@ -18,7 +18,7 @@ public: ...@@ -18,7 +18,7 @@ public:
int height; int height;
int width; int width;
int groups; int groups;
int groupsize; int gptq_groupsize;
int rows_8; int rows_8;
int rows_6; int rows_6;
...@@ -33,6 +33,7 @@ public: ...@@ -33,6 +33,7 @@ public:
uint32_t* cuda_q_scale = NULL; uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL; half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL; uint16_t* cuda_q_groups = NULL;
uint16_t* cuda_q_group_map = NULL;
uint32_t* cuda_gptq_qzeros = NULL; uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL; half* cuda_gptq_scales = NULL;
...@@ -53,6 +54,7 @@ public: ...@@ -53,6 +54,7 @@ public:
uint32_t* _q_scale, uint32_t* _q_scale,
half* _q_scale_max, half* _q_scale_max,
uint16_t* _q_groups, uint16_t* _q_groups,
uint16_t* _q_group_map,
uint32_t* _gptq_qzeros, uint32_t* _gptq_qzeros,
half* _gptq_scales, half* _gptq_scales,
......
...@@ -7,6 +7,7 @@ union half2_uint32 ...@@ -7,6 +7,7 @@ union half2_uint32
half2 as_half2; half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {} __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {} __device__ half2_uint32(half2 val) : as_half2(val) {}
__device__ half2_uint32() : as_uint32(0) {}
}; };
union half_uint16 union half_uint16
...@@ -15,6 +16,7 @@ union half_uint16 ...@@ -15,6 +16,7 @@ union half_uint16
half as_half; half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {} __device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {} __device__ half_uint16(half val) : as_half(val) {}
__device__ half_uint16() : as_uint16(0) {}
}; };
// Max_scale premultiplied by 1/256 // Max_scale premultiplied by 1/256
......
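The new default constructors let arrays of these unions (such as the half_uint16 weights[] buffer in the GPTQ kernel) be declared without explicit initialization. The union itself is plain bit reinterpretation between an fp16 value and its 16-bit pattern, which the kernels use for cheap zero tests; roughly:

```python
import struct

def half_to_uint16(x):
    # Reinterpret an fp16 value's bits as a uint16, like half_uint16::as_uint16.
    return struct.unpack("<H", struct.pack("<e", x))[0]

print(half_to_uint16(0.0))    # 0       -> counts as "no weight" in the any_w test
print(half_to_uint16(1.0))    # 15360   (0x3C00)
```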
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
...@@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort= ...@@ -40,3 +48,7 @@ inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=
if (abort) exit(code); if (abort) exit(code);
} }
} }
void print_global_mem(const half* ptr, int rows, int columns, int stride);
#endif
\ No newline at end of file
...@@ -31,6 +31,7 @@ uintptr_t make_q_matrix ...@@ -31,6 +31,7 @@ uintptr_t make_q_matrix
torch::Tensor q_scale, torch::Tensor q_scale,
torch::Tensor q_scale_max, torch::Tensor q_scale_max,
torch::Tensor q_groups, torch::Tensor q_groups,
torch::Tensor q_group_map,
torch::Tensor gptq_qzeros, torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales, torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx, torch::Tensor gptq_g_idx,
...@@ -43,6 +44,7 @@ uintptr_t make_q_matrix ...@@ -43,6 +44,7 @@ uintptr_t make_q_matrix
TORCH_CHECK_DTYPE_OPT(q_scale, kInt); TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort); TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
...@@ -83,12 +85,15 @@ uintptr_t make_q_matrix ...@@ -83,12 +85,15 @@ uintptr_t make_q_matrix
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr() (half*) temp_dq.data_ptr()
); );
if (m->failed) throw std::runtime_error("CUDA out of memory");
return reinterpret_cast<uintptr_t> (m); return reinterpret_cast<uintptr_t> (m);
} }
......
...@@ -32,10 +32,10 @@ def fresh_cache(): ...@@ -32,10 +32,10 @@ def fresh_cache():
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ['HUGGINGFACE_HUB_CACHE'] = d os.environ["HUGGINGFACE_HUB_CACHE"] = d
yield yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value os.environ["HUGGINGFACE_HUB_CACHE"] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
...@@ -47,7 +47,7 @@ def prefetched(): ...@@ -47,7 +47,7 @@ def prefetched():
revision="main", revision="main",
local_files_only=False, local_files_only=False,
repo_type="model", repo_type="model",
allow_patterns=["*.safetensors"] allow_patterns=["*.safetensors"],
) )
yield model_id yield model_id
...@@ -61,7 +61,7 @@ def test_weight_hub_files_offline_error(offline, fresh_cache): ...@@ -61,7 +61,7 @@ def test_weight_hub_files_offline_error(offline, fresh_cache):
def test_weight_hub_files_offline_ok(prefetched, offline): def test_weight_hub_files_offline_ok(prefetched, offline):
# If the model is prefetched then we should be able to get the weight files from local cache # If the model is prefetched then we should be able to get the weight files from local cache
filenames = weight_hub_files(prefetched) filenames = weight_hub_files(prefetched)
assert filenames == ['model.safetensors'] assert filenames == ["model.safetensors"]
def test_weight_hub_files(): def test_weight_hub_files():
......
...@@ -71,7 +71,7 @@ def _load_multi_mqa_gptq( ...@@ -71,7 +71,7 @@ def _load_multi_mqa_gptq(
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device) g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_params() bits, groupsize, _ = weights._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA
......
...@@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda): ...@@ -27,6 +27,32 @@ def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
return output.view(output_shape) return output.view(output_shape)
# Group map needed for irregular group sizes
def make_group_map(q_groups, num_qrows):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
for i in range(num_groups):
bits = gr[i * 2]
if i < num_groups - 1:
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
else:
qrows = num_qrows - gr[i * 2 + 1]
rows = qrows * 32 // bits
for j in range(rows):
group_map += [i]
group_map += [rows - j]
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
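make_group_map emits two entries per weight row, the group index and the number of rows remaining in that group, matching what the CUDA kernels expect in cuda_q_group_map. A toy example, assuming torch and the helper above are in scope, with values chosen so each group is one packed 4-bit qrow (32 / 4 = 8 weight rows):

```python
import torch

q_groups = torch.tensor([4, 0, 4, 1], dtype=torch.short)   # (bits, first qrow) per group
group_map = make_group_map(q_groups, num_qrows=2)
print(group_map.view(-1, 2))
# Row k -> [group index, rows remaining], i.e. [0, 8], [0, 7], ..., [1, 8], ..., [1, 1]
```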
# Create Q matrix
def ext_make_q_matrix(w: dict, temp_dq, key: str = None): def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
""" """
Create Q matrix Create Q matrix
...@@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): ...@@ -37,6 +63,10 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale_max"] /= 256 w["q_scale_max"] /= 256
w["q_perm"] = w["q_perm"].short() w["q_perm"] = w["q_perm"].short()
w["q_invperm"] = w["q_invperm"].short() w["q_invperm"] = w["q_invperm"].short()
if "q_group_map" not in w:
w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
return make_q_matrix( return make_q_matrix(
w["q_weight"], w["q_weight"],
w["q_perm"], w["q_perm"],
...@@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): ...@@ -44,6 +74,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["q_scale"], w["q_scale"],
w["q_scale_max"], w["q_scale_max"],
w["q_groups"], w["q_groups"],
w["q_group_map"],
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
...@@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): ...@@ -70,6 +101,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor,
w["qzeros"], w["qzeros"],
w["scales"], w["scales"],
w["g_idx"].cpu(), w["g_idx"].cpu(),
...@@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None): ...@@ -84,6 +116,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor, none_tensor,
none_tensor,
w["qzeros"], w["qzeros"],
w["scales"], w["scales"],
none_tensor, none_tensor,
......
...@@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) ...@@ -18,7 +18,9 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]: def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory""" """Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(model_id, revision) d = _get_cached_revision_directory(model_id, revision)
if not d: if not d:
...@@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) ...@@ -27,7 +29,9 @@ def _cached_weight_files(model_id: str, revision: Optional[str], extension: str)
return filenames return filenames
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]: def _weight_hub_files_from_model_info(
info: hf_api.ModelInfo, extension: str
) -> List[str]:
return [ return [
s.rfilename s.rfilename
for s in info.siblings for s in info.siblings
...@@ -44,21 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: ...@@ -44,21 +48,27 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
# see _weight_hub_files_from_model_info, that's also what is # see _weight_hub_files_from_model_info, that's also what is
# done there with the len(s.rfilename.split("/")) == 1 condition # done there with the len(s.rfilename.split("/")) == 1 condition
root, _, files = next(os.walk(str(d))) root, _, files = next(os.walk(str(d)))
filenames = [f for f in files filenames = [
if f.endswith(extension) f
and "arguments" not in f for f in files
and "args" not in f if f.endswith(extension)
and "adapter" not in f and "arguments" not in f
and "training" not in f] and "args" not in f
and "adapter" not in f
and "training" not in f
]
return filenames return filenames
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]: def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]:
if revision is None: if revision is None:
revision = "main" revision = "main"
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path( repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
file_download.repo_folder_name(repo_id=model_id, repo_type="model")) file_download.repo_folder_name(repo_id=model_id, repo_type="model")
)
if not repo_cache.is_dir(): if not repo_cache.is_dir():
# No cache for this model # No cache for this model
...@@ -86,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op ...@@ -86,7 +96,7 @@ def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Op
def weight_hub_files( def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]: ) -> List[str]:
"""Get the weights filenames on the hub""" """Get the weights filenames on the hub"""
api = HfApi() api = HfApi()
......
...@@ -19,6 +19,7 @@ from accelerate import init_empty_weights ...@@ -19,6 +19,7 @@ from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.gptq.quant_linear import QuantLinear
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
from text_generation_server.utils.log import log_once
HAS_AWQ = True HAS_AWQ = True
try: try:
...@@ -35,10 +36,11 @@ HAS_EXLLAMA = False ...@@ -35,10 +36,11 @@ HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning( V2 = False
log_once(
logger.warning,
"Disabling exllama v2 and using v1 instead because there are issues when sharding" "Disabling exllama v2 and using v1 instead because there are issues when sharding"
) )
V2 = False
if os.getenv("DISABLE_EXLLAMA") == "True": if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False HAS_EXLLAMA = False
......
from functools import lru_cache
@lru_cache(10)
def log_once(log, msg:str):
log(msg)
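log_once is used elsewhere in this commit to de-duplicate warnings that would otherwise be printed once per layer; the lru_cache keys on the (logging function, message) pair, so only the first identical call actually logs (up to the cache size of 10). Usage sketch:

```python
from loguru import logger

log_once(logger.warning, "Disabling exllama v2 and using v1 instead because there are issues when sharding")
log_once(logger.warning, "Disabling exllama v2 and using v1 instead because there are issues when sharding")  # suppressed
```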
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from loguru import logger from loguru import logger
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import json import json
from text_generation_server.utils.log import log_once
class Weights: class Weights:
...@@ -161,7 +162,7 @@ class Weights: ...@@ -161,7 +162,7 @@ class Weights:
else: else:
g_idx = None g_idx = None
bits, groupsize = self._get_gptq_params() bits, groupsize, _ = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else: else:
slice_ = self._get_slice(f"{prefix}.weight") slice_ = self._get_slice(f"{prefix}.weight")
...@@ -211,10 +212,10 @@ class Weights: ...@@ -211,10 +212,10 @@ class Weights:
else: else:
g_idx = None g_idx = None
bits, groupsize = self._get_gptq_params() bits, groupsize, desc_act = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" use_exllama = bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
...@@ -240,11 +241,15 @@ class Weights: ...@@ -240,11 +241,15 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str): def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq": if quantize == "gptq":
use_exllama = True use_exllama = True
bits, groupsize = self._get_gptq_params() bits, groupsize, desc_act = self._get_gptq_params()
if bits != 4: if bits != 4:
use_exllama = False use_exllama = False
if desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
if self.process_group.size() > 1: if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx") g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None: if g_idx is not None:
...@@ -274,12 +279,18 @@ class Weights: ...@@ -274,12 +279,18 @@ class Weights:
if use_exllama: if use_exllama:
if not HAS_EXLLAMA: if not HAS_EXLLAMA:
if CAN_EXLLAMA: if CAN_EXLLAMA:
logger.warning( log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True" "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True"
) )
use_exllama = False use_exllama = False
else: else:
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}") log_once(
logger.info,
f"Using exllama kernels v{HAS_EXLLAMA}"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama and groupsize != -1: if use_exllama and groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
...@@ -288,14 +299,12 @@ class Weights: ...@@ -288,14 +299,12 @@ class Weights:
qzeros = self.get_tensor(f"{prefix}.qzeros") qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales") scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama: if use_exllama:
g_idx = g_idx - g_idx[0] g_idx = g_idx - g_idx[0]
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq": elif quantize == "awq":
bits, groupsize = self._get_gptq_params() bits, groupsize, _ = self._get_gptq_params()
try: try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0) qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
...@@ -314,18 +323,20 @@ class Weights: ...@@ -314,18 +323,20 @@ class Weights:
weight = self.get_sharded(f"{prefix}.weight", dim=1) weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight return weight
def _get_gptq_params(self) -> Tuple[int, int]: def _get_gptq_params(self) -> Tuple[int, int, int]:
try: try:
bits = self.get_tensor("gptq_bits").item() bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item() groupsize = self.get_tensor("gptq_groupsize").item()
desc_act = False
except (SafetensorError, RuntimeError) as e: except (SafetensorError, RuntimeError) as e:
try: try:
bits = self.gptq_bits bits = self.gptq_bits
groupsize = self.gptq_groupsize groupsize = self.gptq_groupsize
desc_act = getattr(self, "gptq_desc_act", False)
except Exception: except Exception:
raise e raise e
return bits, groupsize return bits, groupsize, desc_act
def _set_gptq_params(self, model_id, revision): def _set_gptq_params(self, model_id, revision):
filename = "config.json" filename = "config.json"
...@@ -340,6 +351,7 @@ class Weights: ...@@ -340,6 +351,7 @@ class Weights:
data = json.load(f) data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"] self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"] self.gptq_groupsize = data["quantization_config"]["group_size"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception: except Exception:
filename = "quantize_config.json" filename = "quantize_config.json"
try: try:
...@@ -353,6 +365,7 @@ class Weights: ...@@ -353,6 +365,7 @@ class Weights:
data = json.load(f) data = json.load(f)
self.gptq_bits = data["bits"] self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"] self.gptq_groupsize = data["group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception: except Exception:
filename = "quant_config.json" filename = "quant_config.json"
try: try:
...@@ -366,5 +379,6 @@ class Weights: ...@@ -366,5 +379,6 @@ class Weights:
data = json.load(f) data = json.load(f)
self.gptq_bits = data["w_bit"] self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"] self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
except Exception: except Exception:
pass pass
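_get_gptq_params now also reports desc_act, read from whichever quantization config file is found, and row-parallel GPTQ weights fall back to the non-exllama path when it is set. For reference, the shape of the config fields being read (values are illustrative):

```python
# quantization_config as found in config.json for a typical GPTQ checkpoint
quantization_config = {
    "bits": 4,
    "group_size": 128,
    "desc_act": False,   # True means activation-order quantization; exllama is then disabled
}
```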