Unverified Commit 01a5d18a authored by CHU Tianxiang's avatar CHU Tianxiang Committed by GitHub
Browse files

Add Support for 2/3/8-bit GPTQ Quantization Models (#2330)

parent 929b4f29
......@@ -98,11 +98,13 @@ torch::Tensor gptq_gemm(
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama);
bool use_exllama,
int bit);
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);
torch::Tensor q_perm,
int bit);
void moe_align_block_size(
torch::Tensor topk_ids,
......
......@@ -146,6 +146,129 @@ public:
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
class MatrixView_q2_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
items[2] = (d >> 4) & 0x03;
items[3] = (d >> 6) & 0x03;
}
};
class MatrixView_q3_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int z_w = column * 3 / 32;
int z_mod = column & 0x1f;
if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
} else {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
}
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x1f);
uint32_t d;
if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
}
items[0] = d & 0x07;
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
}
};
class MatrixView_q8_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
items[2] = (d >> 16) & 0xff;
items[3] = (d >> 24) & 0xff;
}
};
} // namespace gptq
} // namespace vllm
#endif
......@@ -13,7 +13,10 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
#include "compat.cuh"
#include "matrix_view.cuh"
#include "qdq_2.cuh"
#include "qdq_3.cuh"
#include "qdq_4.cuh"
#include "qdq_8.cuh"
namespace vllm {
namespace gptq {
......@@ -22,6 +25,7 @@ namespace gptq {
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_ROWS_8BIT 24
#define MAX_ALT_GEMM_ROWS 8
#define THREADS_X 32
#define THREADS_Y 32
......@@ -75,6 +79,106 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
float result = {};
#pragma unroll
for (int i = 0; i < 4; i++)
{
half2 w01 = dq[i];
float w0 = __low2float(w01);
float w1 = __high2float(w01);
float x0 = __half2float(*a_ptr++);
float x1 = __half2float(*a_ptr++);
result = fma(w0, x0, result);
result = fma(w1, x1, result);
}
float qs = __half2float(qs_h);
result *= qs;
half result_h = __float2half_rn(result);
return __hadd(result_h, g_result);
}
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
......@@ -89,8 +193,9 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
const int*
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
__global__ void gemm_half_q_half_gptq_4bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
......@@ -231,80 +336,794 @@ __global__ void gemm_half_q_half_gptq_kernel
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_2bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
void gemm_half_q_half_cuda_part
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 2);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
b_ptr += size_n;
a_ptr += 16;
}
k += 16;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_3bit_kernel
(
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_q_perm,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
int groups
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
int t = threadIdx.x;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
c,
size_m,
size_n,
size_k,
groups,
b_q_perm
);
}
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / 32 * 3;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
a_ptr += 32;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_8bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 8);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(
bool first_block, const int m_count, const int bit)
{
#define SELECT_KERNEL(M_COUNT) \
if (m_count == M_COUNT) { \
if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel<true, M_COUNT>; \
if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel<true, M_COUNT>; \
if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel<true, M_COUNT>; \
if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel<true, M_COUNT>; \
}
#if BLOCK_M_SIZE_MAX >= 1
SELECT_KERNEL(1);
#endif
#if BLOCK_M_SIZE_MAX >= 2
SELECT_KERNEL(2);
#endif
#if BLOCK_M_SIZE_MAX >= 3
SELECT_KERNEL(3);
#endif
#if BLOCK_M_SIZE_MAX >= 4
SELECT_KERNEL(4);
#endif
#if BLOCK_M_SIZE_MAX >= 5
SELECT_KERNEL(5);
#endif
#if BLOCK_M_SIZE_MAX >= 6
SELECT_KERNEL(6);
#endif
#if BLOCK_M_SIZE_MAX >= 7
SELECT_KERNEL(7);
#endif
#if BLOCK_M_SIZE_MAX >= 8
SELECT_KERNEL(8);
#endif
return NULL;
}
void gemm_half_q_half_cuda_part
(
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_q_perm,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
int groups,
int bit
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
c,
size_m,
size_n,
size_k,
groups,
b_q_perm
);
}
__global__ void reconstruct_exllama_8bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 8);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 4; p++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_4bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_3bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / 32* 3;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 1; p++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
if (b_q_perm)
{
for (int j = 0; j < 16; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 16; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_kernel
__global__ void reconstruct_exllama_2bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
......@@ -317,7 +1136,7 @@ __global__ void reconstruct_exllama_kernel
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
......@@ -345,21 +1164,15 @@ __global__ void reconstruct_exllama_kernel
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
int qk = offset_k / (32 / 2);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
......@@ -374,28 +1187,24 @@ __global__ void reconstruct_exllama_kernel
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
for (int p = 0; p < 2; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
for (int j = 0; j < 8; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
......@@ -404,7 +1213,7 @@ __global__ void reconstruct_exllama_kernel
}
else
{
for (int j = 0; j < 4; j++)
for (int j = 0; j < 8; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
......@@ -416,7 +1225,6 @@ __global__ void reconstruct_exllama_kernel
}
}
void reconstruct_exllama
(
const uint32_t* b_q_weight,
......@@ -426,7 +1234,8 @@ void reconstruct_exllama
half* out,
int height,
int width,
int groups
int groups,
int bit
)
{
dim3 blockDim, gridDim;
......@@ -435,6 +1244,15 @@ void reconstruct_exllama
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel;
if (bit == 2) {
reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel;
} else if (bit == 3) {
reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel;
} else if (bit == 8) {
reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
(
......@@ -450,7 +1268,7 @@ void reconstruct_exllama
}
__global__ void gemm_half_q_half_alt_kernel(
__global__ void gemm_half_q_half_alt_4bit_kernel(
const half2* __restrict__ vec,
const uint32_t* __restrict__ mat,
half* __restrict__ mul,
......@@ -548,6 +1366,95 @@ __global__ void gemm_half_q_half_alt_kernel(
}
__global__ void gemm_half_q_half_alt_8bit_kernel(
const half2* __restrict__ vec,
const uint32_t* __restrict__ mat,
half* __restrict__ mul,
const half* __restrict__ scales,
const uint32_t* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int height,
int width
)
{
int zero_width = width / 4;
int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
int b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
int h = BLOCK_KN_SIZE * blockIdx.z / 4;
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
blockvec[m][threadIdx.x] =
vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
threadIdx.x];
}
}
if (blockIdx.z == 0)
{
for (int m = 0; m < b_end; m++)
mul[(b + m) * width + w] = __int2half_rn(0);
}
__syncthreads();
int i = width * h + w;
int g_h = h * 4;
int k = 0;
int z_w = w / 4;
int z_mod = (w % 4) * 8;
half2 res2;
half res[BLOCK_M_SIZE_MAX] = {};
unsigned int tmp;
while (k < h_end) {
tmp = mat[i];
half2 scales_tmp[2];
half2 zeros_tmp[2];
for (int tmp_k = 0; tmp_k < 2; tmp_k++) {
int g = g_idx[g_h + (k + tmp_k) * 2];
int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
half scale_f = scales[g * width + w];
half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2(
__hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
__hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))
);
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
for (int m = 0; m < b_end; m++) {
#ifndef USE_ROCM
res2 = {};
#else
res2.x = __half_as_ushort(__float2half(0));
res2.y = __half_as_ushort(__float2half(0));
#endif
half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF));
res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF));
res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
#ifndef USE_ROCM
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif
}
i += width;
k += 2;
}
for (int m = 0; m < b_end; m++) {
atomicAdd(&mul[(b + m) * width + w], res[m]);
}
}
void gemm_half_q_half_alt
(
const half* a,
......@@ -558,7 +1465,8 @@ void gemm_half_q_half_alt
half* c,
int size_m,
int size_n,
int size_k
int size_k,
int bit
)
{
dim3 blockDim, gridDim;
......@@ -569,8 +1477,13 @@ void gemm_half_q_half_alt
gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
auto kernel = gemm_half_q_half_alt_4bit_kernel;
if (bit == 8) {
kernel = gemm_half_q_half_alt_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half2*) a,
b_q_weight,
......@@ -579,12 +1492,12 @@ void gemm_half_q_half_alt
b_gptq_qzeros,
b_g_idx,
size_m,
size_k / 8,
size_k / 32 * bit,
size_n
);
}
template<class T, int bit>
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ w,
......@@ -600,30 +1513,79 @@ __global__ void reconstruct_gptq_kernel
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 8;
int row = blockIdx.y * 32 / bit;
if (column >= width) return;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q4_row w_zeros_(w_zeros, group, width);
T w_zeros_(w_zeros, group, width);
uint32_t w_read = w_.item_uint32_t(row, column);
uint32_t w_read = w[blockIdx.y * width + column];
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
for (int s = 0; s < 32; s += bit)
{
int group = g_idx[row + s / 4];
int group = g_idx[row + s / bit];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
__global__ void reconstruct_gptq_3bit_kernel
(
const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int* __restrict__ g_idx,
const int height,
const int width,
const int group,
half* __restrict__ out
)
{
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 32;
if (column >= width) return;
// Views
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q3_row w_zeros_(w_zeros, group, width);
uint32_t w1 = w[(blockIdx.y * 3) * width + column];
uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int i = 0; i < 32; i += 1)
{
int group = g_idx[row + i];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
int w_item;
if (i == 10) {
w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
} else if (i == 21) {
w_item = (w2 >> 31) | ((w3 << 1) & 0x6);
} else if (i < 10) {
w_item = ((w1 >> (i * 3)) & 0x7);
} else if (i < 21) {
w_item = ((w2 >> (i * 3 - 32)) & 0x7);
} else {
w_item = ((w3 >> (i * 3 - 64)) & 0x7);
}
*out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale);
out_ptr += out_.width;
}
}
void reconstruct_gptq
(
......@@ -634,16 +1596,28 @@ void reconstruct_gptq
half* out,
int height,
int width,
int groups
int groups,
int bit
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, 8);
gridDim.y = DIVIDE(height, 32 / bit);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
auto kernel = reconstruct_gptq_kernel<MatrixView_q4_row, 4>;
if (bit == 2) {
kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
} else if (bit == 8) {
kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
} else if (bit == 3) {
kernel = reconstruct_gptq_3bit_kernel;
gridDim.y = DIVIDE(height, 32);
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
b_q_weight,
b_gptq_scales,
......@@ -671,19 +1645,27 @@ void gemm_half_q_half_cuda
int size_n,
int size_k,
int groups,
bool use_exllama
bool use_exllama,
int bit
)
{
if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
bool use_reconstruct;
if (use_exllama) {
use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
} else {
// The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now.
use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
}
if (use_reconstruct) {
// Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
size_k, size_n, groups);
size_k, size_n, groups, bit);
}
else
{
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups);
temp_dq, size_k, size_n, groups, bit);
}
const half alpha = __float2half(1.0f);
......@@ -707,7 +1689,7 @@ void gemm_half_q_half_cuda
{
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
groups);
groups, bit);
}
if (last_chunk_size)
......@@ -715,18 +1697,17 @@ void gemm_half_q_half_cuda
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
b_gptq_scales, b_g_idx, c + last_chunk * size_n,
last_chunk_size, size_n, size_k, last_chunk_size,
groups);
groups, bit);
}
}
else
{
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k);
c, size_m, size_n, size_k, bit);
}
}
__global__ void shuffle_kernel
__global__ void shuffle_4bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
......@@ -740,13 +1721,53 @@ __global__ void shuffle_kernel
while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
}
__global__ void shuffle_8bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
}
__global__ void shuffle_2bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
__global__ void shuffle_3bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
}
__global__ void make_sequential_kernel
__global__ void make_sequential_4bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
......@@ -778,37 +1799,204 @@ __global__ void make_sequential_kernel
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
__global__ void make_sequential_2bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 4;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 16; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 4;
int w2_subrow = source_row & 0x0f;
int w2_row_shift = w2_subrow << 1;
int wnew2_row_shift = i << 1;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000300000003;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
__global__ void make_sequential_3bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w_column >= w_width) return;
int w_new_row = blockIdx.y * 3;
int q_perm_idx = blockIdx.y << 5;
uint32_t dst[3] = {0, 0, 0};
#pragma unroll
for (int i = 0; i < 32; i++)
{
int source_row = q_perm[q_perm_idx++];
int z_w = (source_row / 32) * 3;
int z_mod = source_row % 32;
int z_bit;
if (z_mod != 10){
if (z_mod != 21){
z_bit = z_mod;
if (z_bit > 21){
z_bit *= 3;
z_bit -= 64;
z_w += 2;
} else if (z_bit > 10){
z_bit *= 3;
z_bit -= 32;
z_w += 1;
} else {
z_bit *= 3;
}
} else {
z_w += 1;
}
}
uint64_t src;
if (z_mod == 10) {
src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4);
} else if (z_mod == 21){
src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6);
} else {
src = w[z_w * w_width + w_column];
src >>= z_bit;
src &= 0x07;
}
z_w = 0;
if (i != 10){
if (i != 21){
z_bit = i;
if (z_bit > 21){
z_bit *= 3;
z_bit -= 64;
z_w += 2;
} else if (z_bit > 10){
z_bit *= 3;
z_bit -= 32;
z_w += 1;
} else {
z_bit *= 3;
}
} else {
z_w += 1;
}
}
if (i == 10) {
dst[z_w] |= (src & 0x03) << 30;
dst[z_w + 1] |= ((src & 0x4) >> 2);
} else if (i == 21) {
dst[z_w] |= (src & 0x01) << 31;
dst[z_w + 1] |= ((src & 0x6) >> 1);
} else {
dst[z_w] |= (src << z_bit);
}
}
w_new[w_new_row * w_width + w_column] = dst[0];
w_new[(w_new_row + 1) * w_width + w_column] = dst[1];
w_new[(w_new_row + 2) * w_width + w_column] = dst[2];
}
__global__ void make_sequential_8bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 2;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 2;
int w2_subrow = source_row & 0x03;
int w2_row_shift = w2_subrow << 3;
int wnew2_row_shift = i << 3;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x000000ff000000ff;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void shuffle_exllama_weight
(
uint32_t* q_weight,
int* q_perm,
int height,
int width
int width,
int bit
)
{
if (q_perm)
{
uint32_t* new_qweight = NULL;
cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t));
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
gridDim.y = height / 32 * bit;
auto kernel = make_sequential_4bit_kernel;
if (bit == 2) {
kernel = make_sequential_2bit_kernel;
} else if (bit == 3) {
kernel = make_sequential_3bit_kernel;
gridDim.y = height / 32;
} else if (bit == 8) {
kernel = make_sequential_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
kernel<<<gridDim, blockDim, 0, stream>>>
(
q_weight,
new_qweight,
q_perm,
height / 8,
width
);
// Replace qweights
cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(new_qweight);
......@@ -818,6 +2006,14 @@ void shuffle_exllama_weight
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
auto shuffle_kernel = shuffle_4bit_kernel;
if (bit == 2) {
shuffle_kernel = shuffle_2bit_kernel;
} else if (bit == 3) {
shuffle_kernel = shuffle_3bit_kernel;
} else if (bit == 8) {
shuffle_kernel = shuffle_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
}
......@@ -832,13 +2028,14 @@ torch::Tensor gptq_gemm
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama
bool use_exllama,
int bit
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
vllm::gptq::gemm_half_q_half_cuda
(
......@@ -854,7 +2051,8 @@ torch::Tensor gptq_gemm
c.size(1), // n
a.size(1), // k
b_gptq_qzeros.size(0), // group number
use_exllama
use_exllama,
bit
);
return c;
}
......@@ -862,14 +2060,16 @@ torch::Tensor gptq_gemm
void gptq_shuffle
(
torch::Tensor q_weight,
torch::Tensor q_perm
torch::Tensor q_perm,
int bit
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 8,
q_weight.size(1)
q_weight.size(0) * 32 / bit,
q_weight.size(1),
bit
);
}
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z4 = __half2half2(z4_);
const half2 z16 = __half2half2(z16_);
const half2 z64 = __half2half2(z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
} // namespace gptq
} // namespace vllm
#endif
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
} // namespace gptq
} // namespace vllm
#endif
......@@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z16 = __half2half2(z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
......@@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
} // namespace gptq
} // namespace vllm
#else
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
} // namespace gptq
} // namespace vllm
#endif
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
} // namespace gptq
} // namespace vllm
#endif
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from fractions import Fraction
import torch
from torch.nn.parameter import Parameter
......@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.pack_factor = 32 // self.weight_bits
# exllama kernel v1 only supports 4 bit
if self.weight_bits != 4:
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
"Currently, only 2/3/4/8-bit weight quantization is supported for "
f"GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
......@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
......@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
else:
weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY
ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
self.quant_config.weight_bits)
output = ops.gptq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"],
weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY)
weights["exllama_state"] == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None:
output = output + bias
return output.reshape(out_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment