Commit cb74e4ee authored by ilyas@huggingface.co's avatar ilyas@huggingface.co
Browse files

added exllama kernels

parent a22d67a3
#ifndef _q_gemm_cuh
#define _q_gemm_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q_matrix.cuh"
void gemm_half_q_half_cuda
(
cublasHandle_t cublas_handle,
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
bool clear = false,
half* reconstruct = NULL,
bool force_cuda = false
);
void clear_tensor_cuda
(
half* c,
int size_m,
int size_n
);
#endif
\ No newline at end of file
#include "compat.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
typedef void (*fp_gemm_half_q_half_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const int,
const int,
const int,
const int,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0 = a_ptr[b_q_perm[offset_k + t]];
block_a_ptr[t] = a0;
}
}
// Clear
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
// Preload scales
float scales[MAX_GROUPS_IN_BLOCK][4];
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
for (int g = 0; g < groups_in_block; g++)
{
int qscales[4];
b_q_scale_.item4(qscales, group + g, n);
qscales[0]++;
qscales[1]++;
qscales[2]++;
qscales[3]++;
float maxscale = __half2float(b_q_scale_max[group + g]);
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
}
// a, b offset
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int scales_idx = 0;
float qs_f0 = scales[scales_idx][0];
float qs_f1 = scales[scales_idx][1];
float qs_f2 = scales[scales_idx][2];
float qs_f3 = scales[scales_idx][3];
int nextgroup = offset_k + groupsize;
// Column result
float block_c[m_count][4] = {};
// Dequantize groups
int k = offset_k;
while (k < rows_8 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[5];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
// Accumulate column sums in c
for (int m = 0; m < m_count; m++)
{
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
return NULL;
}
#include "compat.cuh"
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hadd2(result, g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_4,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
float scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
// __syncthreads();
// Column result
float block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][4];
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
}
b_ptr += size_n;
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}
#include "q_matrix.cuh"
#include "matrix_view.cuh"
#include "util.cuh"
#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define THREADS_X 32
#define THREADS_Y 32
// Shuffle quantized data on load
__global__ void shuffle_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
// QMatrix constructor
QMatrix::QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
) :
device(_device),
height(_height),
width(_width),
groups(_groups),
temp_dq(_temp_dq)
{
cudaSetDevice(device);
failed = false;
cuda_q_weight = _q_weight;
cuda_q_perm = _q_perm;
cuda_q_invperm = _q_invperm;
cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups;
cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL);
groupsize = 1;
while (groupsize * groups < height) groupsize *= 2;
// Create group map
rows_8 = 0;
rows_6 = 0;
rows_5 = 0;
rows_4 = 0;
rows_3 = 0;
rows_2 = 0;
if (!is_gptq)
{
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
for (int i = 0; i < groups; i++)
{
int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize;
if (bits == 5) rows_5 += groupsize;
if (bits == 4) rows_4 += groupsize;
if (bits == 3) rows_3 += groupsize;
if (bits == 2) rows_2 += groupsize;
}
free(cpu_q_groups);
rows_6 += rows_8;
rows_5 += rows_6;
rows_4 += rows_5;
rows_3 += rows_4;
rows_2 += rows_3;
}
else
{
rows_4 = height;
rows_3 = height;
rows_2 = height;
if (_gptq_g_idx)
{
if (!make_sequential(_gptq_g_idx))
{
failed = true;
//printf("FAIL\n");
return;
}
}
}
// Shuffle quantized data
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
QMatrix::~QMatrix()
{
}
// Reconstruct b[k,n] (GPTQ)
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_4
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ uint16_t perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
// Reconstruct b[k,n]
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
// Preload remapping table
int t = threadIdx.x;
__shared__ uint16_t perm[BLOCK_KN_SIZE];
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
// Column
int n = offset_n + t;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k;
int lk = 0;
__syncthreads();
while (k < rows_8 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
dequant_8bit_8(q_0, q_1, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
uint32_t q_3 = *b_ptr; b_ptr += size_n;
uint32_t q_4 = *b_ptr; b_ptr += size_n;
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_4bit_8(q_0, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_2bit_16(q_0, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
}
void QMatrix::reconstruct(half* out)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
if (!is_gptq)
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
reconstruct_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_q_scale,
cuda_q_scale_max,
//cuda_q_groups,
height,
width,
groupsize,
groups,
out,
rows_8,
rows_6,
rows_5,
rows_4,
rows_3,
rows_2
);
}
else
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_gptq_qzeros,
cuda_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
height,
width,
groupsize,
groups,
out,
rows_4
);
}
}
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint16_t* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
if (err != cudaSuccess) {
cudaError_t cuda_status = cudaGetLastError(); // Clear error
return false;
}
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Reduce to uint16_t
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
// Move to CUDA
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
make_sequential_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_new_qweight,
cuda_q_perm,
height / 8,
width
);
// Replace qweights
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
return true;
}
#ifndef _q_matrix_cuh
#define _q_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#define MAX_SUPERGROUPS 16
class QMatrix
{
public:
int device;
bool is_gptq;
int height;
int width;
int groups;
int groupsize;
int rows_8;
int rows_6;
int rows_5;
int rows_4;
int rows_3;
int rows_2;
uint32_t* cuda_q_weight = NULL;
uint16_t* cuda_q_perm = NULL;
uint16_t* cuda_q_invperm = NULL;
uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL;
uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL;
half* temp_dq;
bool failed;
QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
);
~QMatrix();
void reconstruct(half* out);
bool make_sequential(const uint32_t* cpu_g_idx);
private:
};
#endif
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_2BIT == 1
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 2.0f);
const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z4 = __halves2half2(z4_, z4_);
const half2 z16 = __halves2half2(z16_, z16_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
#else
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_3BIT == 1
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_4BIT == 1
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1z16)[2],
half2(&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled)
{
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
}
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
#else
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_5_cuh
#define _qdq_5_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_5BIT == 1
// Permutation:
//
// v5555533 33311111 u4444422 22200000 (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
uint32_t qd = q[3 * stride];
uint32_t qe = q[4 * stride];
// qa: 66555554 44443333 32222211 11100000
// qb: ccccbbbb baaaaa99 99988888 77777666
// qc: jiiiiihh hhhggggg fffffeee eedddddc
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
uint32_t qf = qe >> 22;
qe <<= 8;
qe |= qd >> 24;
qd <<= 6;
qd |= qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: 555554 44443333 32222211 11100000
// qb: bbbbba aaaa9999 98888877 77766666
// qc: hhhhhg ggggffff feeeeedd dddccccc
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
// qf: vv vvvuuuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
uint32_t zd = 0;
uint32_t ze = 0;
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
// za: 5555533 33311111 4444422 22200000
// zb: bbbbb99 99977777 aaaaa88 88866666
// zc: hhhhhff fffddddd gggggee eeeccccc
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
// ze: tttttrr rrrppppp sssssqq qqqooooo
// qf: vv vvvuuuuu
za |= ((qf & 0x001) >> 0) << 15;
zb |= ((qf & 0x002) >> 1) << 15;
zc |= ((qf & 0x004) >> 2) << 15;
zd |= ((qf & 0x008) >> 3) << 15;
ze |= ((qf & 0x010) >> 4) << 15;
za |= ((qf & 0x020) >> 5) << 31;
zb |= ((qf & 0x040) >> 6) << 31;
zc |= ((qf & 0x080) >> 7) << 31;
zd |= ((qf & 0x100) >> 8) << 31;
ze |= ((qf & 0x200) >> 9) << 31;
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
// zb: vbbbbb99 99977777 uaaaaa88 88866666
// zc: vhhhhhff fffddddd ugggggee eeeccccc
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
// ze: vtttttrr rrrppppp usssssqq qqqooooo
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
q[3 * stride] = zd;
q[4 * stride] = ze;
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y32_ = __float2half_rn(1.0f / 32.0f);
const half2 y32 = __halves2half2(y32_, y32_);
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z32 = __halves2half2(z32_, z32_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
uint32_t qd = q_3;
uint32_t qe = q_4;
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
qa >>= 10;
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
qa >>= 5;
qa &= 0x00010001;
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
qb >>= 10;
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
qb >>= 4;
qb &= 0x00020002;
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
qc >>= 10;
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
qc >>= 3;
qc &= 0x00040004;
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
qd >>= 10;
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
qd >>= 2;
qd &= 0x00080008;
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
qe >>= 10;
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
qe >>= 1;
qe &= 0x00100010;
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hadd2( q3.as_half2, z1);
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hadd2( q6.as_half2, z1);
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
dq[ 8] = __hadd2( q8.as_half2, z1);
dq[ 9] = __hadd2( q9.as_half2, z1);
dq[10] = __hfma2(q10.as_half2, y32, z32);
dq[11] = __hadd2(q11.as_half2, z1);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y32, z32);
dq[14] = __hadd2(q14.as_half2, z1);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_6_cuh
#define _qdq_6_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_6BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_6bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_6bit_16
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_8BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
#endif
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
qs_h = __hmul(qs_h, qs_h);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ float clamp(float x, float a, float b)
{
return fmaxf(a, fminf(b, x));
}
#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "config.h"
#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"
#include "cpp/util.h"
// Some decluttering macros
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
// Quant matrix
uintptr_t make_q_matrix
(
torch::Tensor q_weight,
torch::Tensor q_perm,
torch::Tensor q_invperm,
torch::Tensor q_scale,
torch::Tensor q_scale_max,
torch::Tensor q_groups,
torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx,
torch::Tensor temp_dq
)
{
TORCH_CHECK_DTYPE(q_weight, kInt);
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
int device = q_weight.device().index();
int width = q_weight.size(1);
int groups;
int height;
if (!q_scale.device().is_meta())
{
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
groups = q_scale.size(0);
height = q_invperm.size(0);
}
else
{
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
groups = gptq_qzeros.size(0);
height = q_weight.size(0) * 8;
}
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
QMatrix* m = new QMatrix
(
device,
height,
width,
groups,
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr()
);
return reinterpret_cast<uintptr_t> (m);
}
void gemm_half_q_half
(
torch::Tensor a,
uintptr_t b,
torch::Tensor c,
bool force_cuda
)
{
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
TORCH_CHECK_DTYPE(a, kHalf);
TORCH_CHECK_DTYPE(c, kHalf);
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
gemm_half_q_half_cuda
(
at::cuda::getCurrentCUDABlasHandle(),
(const half*) a.data_ptr(),
qm,
(half*) c.data_ptr(),
c.size(0), // m
c.size(1), // n
a.size(1), // k
true,
NULL,
force_cuda
);
}
// Bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}
......@@ -12,7 +12,9 @@ PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
if not PYPI_BUILD:
try:
CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
CUDA_VERSION = "".join(
os.environ.get("CUDA_VERSION", torch.version.cuda).split(".")
)[:3]
AUTOAWQ_KERNELS_VERSION += f"+cu{CUDA_VERSION}"
except Exception as ex:
raise RuntimeError("Your system must have an Nvidia GPU for installing AutoAWQ")
......@@ -24,7 +26,9 @@ common_setup_kwargs = {
"license": "MIT",
"python_requires": ">=3.8.0",
"description": "AutoAWQ Kernels implements the AWQ kernels.",
"long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
"long_description": (Path(__file__).parent / "README.md").read_text(
encoding="UTF-8"
),
"long_description_content_type": "text/markdown",
"url": "https://github.com/casper-hansen/AutoAWQ_kernels",
"keywords": ["awq", "autoawq", "quantization", "transformers"],
......@@ -39,17 +43,20 @@ common_setup_kwargs = {
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: C++",
]
],
}
requirements = [
"torch>=2.0.1",
]
def get_include_dirs():
include_dirs = []
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
conda_cuda_include_dir = os.path.join(
get_python_lib(), "nvidia/cuda_runtime/include"
)
if os.path.isdir(conda_cuda_include_dir):
include_dirs.append(conda_cuda_include_dir)
this_dir = os.path.dirname(os.path.abspath(__file__))
......@@ -57,18 +64,24 @@ def get_include_dirs():
return include_dirs
def get_generator_flag():
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
if os.path.exists(
os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")
):
generator_flag = ["-DOLD_GENERATOR_PATH"]
return generator_flag
def check_dependencies():
if CUDA_HOME is None:
raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
f"Cannot find CUDA_HOME. CUDA must be available to build the package."
)
def get_compute_capabilities():
# Collect the compute capabilities of all available GPUs.
......@@ -77,7 +90,9 @@ def get_compute_capabilities():
cc = major * 10 + minor
if cc < 75:
raise RuntimeError("GPUs with compute capability less than 7.5 are not supported.")
raise RuntimeError(
"GPUs with compute capability less than 7.5 are not supported."
)
# figure out compute capability
compute_capabilities = {75, 80, 86, 89, 90}
......@@ -88,6 +103,7 @@ def get_compute_capabilities():
return capability_flags
check_dependencies()
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
......@@ -98,14 +114,14 @@ if os.name == "nt":
# Relaxed args on Windows
if include_arch:
extra_compile_args={"nvcc": arch_flags}
extra_compile_args = {"nvcc": arch_flags}
else:
extra_compile_args={}
extra_compile_args = {}
else:
extra_compile_args={
extra_compile_args = {
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
"nvcc": [
"-O3",
"-O3",
"-std=c++17",
"-DENABLE_BF16",
"-U__CUDA_NO_HALF_OPERATORS__",
......@@ -117,7 +133,9 @@ else:
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
] + arch_flags + generator_flags
]
+ arch_flags
+ generator_flags,
}
extensions = [
......@@ -128,10 +146,36 @@ extensions = [
"awq_ext/quantization/gemm_cuda_gen.cu",
"awq_ext/layernorm/layernorm.cu",
"awq_ext/position_embedding/pos_encoding_kernels.cu",
"awq_ext/quantization/gemv_cuda.cu"
], extra_compile_args=extra_compile_args
"awq_ext/quantization/gemv_cuda.cu",
],
extra_compile_args=extra_compile_args,
)
]
extensions.append(
CUDAExtension(
"exllama_kernels",
[
"awq_ext/exllama/exllama_ext.cpp",
"awq_ext/exllama/cuda_buffers.cu",
"awq_ext/exllama/cuda_func/column_remap.cu",
"awq_ext/exllama/cuda_func/q4_matmul.cu",
"awq_ext/exllama/cuda_func/q4_matrix.cu",
],
extra_compile_args=extra_compile_args,
)
)
extensions.append(
CUDAExtension(
"exllamav2_kernels",
[
"awq_ext/exllamav2/ext.cpp",
"awq_ext/exllamav2/cuda/q_matrix.cu",
"awq_ext/exllamav2/cuda/q_gemm.cu",
],
extra_compile_args=extra_compile_args,
)
)
if os.name != "nt":
extensions.append(
......@@ -140,14 +184,15 @@ if os.name != "nt":
[
"awq_ext/pybind_awq_ft.cpp",
"awq_ext/attention/ft_attention.cpp",
"awq_ext/attention/decoder_masked_multihead_attention.cu"
], extra_compile_args=extra_compile_args
"awq_ext/attention/decoder_masked_multihead_attention.cu",
],
extra_compile_args=extra_compile_args,
)
)
additional_setup_kwargs = {
"ext_modules": extensions,
"cmdclass": {'build_ext': BuildExtension}
"cmdclass": {"build_ext": BuildExtension},
}
common_setup_kwargs.update(additional_setup_kwargs)
......@@ -156,5 +201,5 @@ setup(
packages=find_packages(),
install_requires=requirements,
include_dirs=include_dirs,
**common_setup_kwargs
**common_setup_kwargs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment