Unverified Commit 01a5d18a authored by CHU Tianxiang's avatar CHU Tianxiang Committed by GitHub
Browse files

Add Support for 2/3/8-bit GPTQ Quantization Models (#2330)

parent 929b4f29
...@@ -98,11 +98,13 @@ torch::Tensor gptq_gemm( ...@@ -98,11 +98,13 @@ torch::Tensor gptq_gemm(
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx, torch::Tensor b_g_idx,
bool use_exllama); bool use_exllama,
int bit);
void gptq_shuffle( void gptq_shuffle(
torch::Tensor q_weight, torch::Tensor q_weight,
torch::Tensor q_perm); torch::Tensor q_perm,
int bit);
void moe_align_block_size( void moe_align_block_size(
torch::Tensor topk_ids, torch::Tensor topk_ids,
......
...@@ -146,6 +146,129 @@ public: ...@@ -146,6 +146,129 @@ public:
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
}; };
class MatrixView_q2_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
items[2] = (d >> 4) & 0x03;
items[3] = (d >> 6) & 0x03;
}
};
class MatrixView_q3_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int z_w = column * 3 / 32;
int z_mod = column & 0x1f;
if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
} else {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
}
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x1f);
uint32_t d;
if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
}
items[0] = d & 0x07;
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
}
};
class MatrixView_q8_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
items[2] = (d >> 16) & 0xff;
items[3] = (d >> 24) & 0xff;
}
};
} // namespace gptq } // namespace gptq
} // namespace vllm } // namespace vllm
#endif #endif
...@@ -13,7 +13,10 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq ...@@ -13,7 +13,10 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
#include "compat.cuh" #include "compat.cuh"
#include "matrix_view.cuh" #include "matrix_view.cuh"
#include "qdq_2.cuh"
#include "qdq_3.cuh"
#include "qdq_4.cuh" #include "qdq_4.cuh"
#include "qdq_8.cuh"
namespace vllm { namespace vllm {
namespace gptq { namespace gptq {
...@@ -22,6 +25,7 @@ namespace gptq { ...@@ -22,6 +25,7 @@ namespace gptq {
#define BLOCK_M_SIZE_MAX 8 #define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) #define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define MAX_Q_GEMM_ROWS 50 #define MAX_Q_GEMM_ROWS 50
#define MAX_Q_GEMM_ROWS_8BIT 24
#define MAX_ALT_GEMM_ROWS 8 #define MAX_ALT_GEMM_ROWS 8
#define THREADS_X 32 #define THREADS_X 32
#define THREADS_Y 32 #define THREADS_Y 32
...@@ -75,6 +79,106 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) ...@@ -75,6 +79,106 @@ __forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
return __half2float(__low2half(result)) + __half2float(__high2half(result)); return __half2float(__low2half(result)) + __half2float(__high2half(result));
} }
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
// Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
float result = {};
#pragma unroll
for (int i = 0; i < 4; i++)
{
half2 w01 = dq[i];
float w0 = __low2float(w01);
float w1 = __high2float(w01);
float x0 = __half2float(*a_ptr++);
float x1 = __half2float(*a_ptr++);
result = fma(w0, x0, result);
result = fma(w1, x1, result);
}
float qs = __half2float(qs_h);
result *= qs;
half result_h = __float2half_rn(result);
return __hadd(result_h, g_result);
}
__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
half result_h = __hadd(__low2half(result), __high2half(result));
return __hfma(result_h, qs_h, g_result);
}
typedef void (*fp_gemm_half_q_half_gptq_kernel) typedef void (*fp_gemm_half_q_half_gptq_kernel)
( (
const half*, const half*,
...@@ -89,8 +193,9 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel) ...@@ -89,8 +193,9 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)
const int* const int*
); );
template <bool first_block, int m_count> template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel __global__ void gemm_half_q_half_gptq_4bit_kernel
( (
const half* __restrict__ a, const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_q_weight,
...@@ -231,80 +336,794 @@ __global__ void gemm_half_q_half_gptq_kernel ...@@ -231,80 +336,794 @@ __global__ void gemm_half_q_half_gptq_kernel
} }
} }
template <bool first_block, int m_count>
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) __global__ void gemm_half_q_half_gptq_2bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{ {
#if BLOCK_M_SIZE_MAX >= 1 MatrixView_half a_(a, size_m, size_k);
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>; MatrixView_half_rw c_(c, size_m, size_n);
#endif MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
#if BLOCK_M_SIZE_MAX >= 2 MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}
int t = threadIdx.x;
void gemm_half_q_half_cuda_part // Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 2);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
b_ptr += size_n;
a_ptr += 16;
}
k += 16;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_3bit_kernel
( (
const half* a, const half* __restrict__ a,
const uint32_t* b_q_weight, const uint32_t* __restrict__ b_q_weight,
const uint32_t* b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* b_gptq_scales, const half* __restrict__ b_gptq_scales,
const int* b_q_perm, half* __restrict__ c,
half* c, const int size_m,
int size_m, const int size_n,
int size_n, const int size_k,
int size_k, const int groups,
int m_count, const int* __restrict__ b_q_perm
int groups
) )
{ {
dim3 blockDim, gridDim; MatrixView_half a_(a, size_m, size_k);
blockDim.x = BLOCK_KN_SIZE; MatrixView_half_rw c_(c, size_m, size_n);
blockDim.y = 1; MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
blockDim.z = 1; MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); int t = threadIdx.x;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Block
kernel<<<gridDim, blockDim, 0, stream>>> int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
( int offset_m = blockIdx.y * m_count;
a, int offset_k = blockIdx.z * BLOCK_KN_SIZE;
b_q_weight,
b_gptq_qzeros, int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
b_gptq_scales, int end_m = min(offset_m + m_count, size_m);
c, int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
size_m,
size_n, int n = offset_n + t * 4;
size_k,
groups, // Preload block_a
b_q_perm __shared__ half block_a[m_count][BLOCK_KN_SIZE];
);
} if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / 32 * 3;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
a_ptr += 32;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_8bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int* __restrict__ b_q_perm
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (blockIdx.z == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 8);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
half scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4(scales, group, n);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]);
block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]);
block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]);
block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]);
}
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(
bool first_block, const int m_count, const int bit)
{
#define SELECT_KERNEL(M_COUNT) \
if (m_count == M_COUNT) { \
if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel<true, M_COUNT>; \
if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel<true, M_COUNT>; \
if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel<true, M_COUNT>; \
if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel<true, M_COUNT>; \
}
#if BLOCK_M_SIZE_MAX >= 1
SELECT_KERNEL(1);
#endif
#if BLOCK_M_SIZE_MAX >= 2
SELECT_KERNEL(2);
#endif
#if BLOCK_M_SIZE_MAX >= 3
SELECT_KERNEL(3);
#endif
#if BLOCK_M_SIZE_MAX >= 4
SELECT_KERNEL(4);
#endif
#if BLOCK_M_SIZE_MAX >= 5
SELECT_KERNEL(5);
#endif
#if BLOCK_M_SIZE_MAX >= 6
SELECT_KERNEL(6);
#endif
#if BLOCK_M_SIZE_MAX >= 7
SELECT_KERNEL(7);
#endif
#if BLOCK_M_SIZE_MAX >= 8
SELECT_KERNEL(8);
#endif
return NULL;
}
void gemm_half_q_half_cuda_part
(
const half* a,
const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales,
const int* b_q_perm,
half* c,
int size_m,
int size_n,
int size_k,
int m_count,
int groups,
int bit
)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
c,
size_m,
size_n,
size_k,
groups,
b_q_perm
);
}
__global__ void reconstruct_exllama_8bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 8);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 4; p++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1);
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_4bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_3bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
const int size_k,
const int size_n,
const int groups,
half* __restrict__ b
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int groupsize = size_k / groups;
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / 32* 3;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 1; p++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1);
if (b_q_perm)
{
for (int j = 0; j < 16; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 16; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
__global__ void reconstruct_exllama_kernel __global__ void reconstruct_exllama_2bit_kernel
( (
const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm, const int* __restrict__ b_q_perm,
...@@ -317,7 +1136,7 @@ __global__ void reconstruct_exllama_kernel ...@@ -317,7 +1136,7 @@ __global__ void reconstruct_exllama_kernel
) )
{ {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y; int offset_k = BLOCK_KN_SIZE * blockIdx.y;
...@@ -345,21 +1164,15 @@ __global__ void reconstruct_exllama_kernel ...@@ -345,21 +1164,15 @@ __global__ void reconstruct_exllama_kernel
int nextgroup = offset_k + groupsize; int nextgroup = offset_k + groupsize;
// b offset // b offset
int qk = offset_k / (32 / 4); int qk = offset_k / (32 / 2);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n; const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale // Initial zeros/scale
int zeros[4]; int zeros[4];
half2 scales[4]; half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
__syncthreads(); __syncthreads();
...@@ -374,28 +1187,24 @@ __global__ void reconstruct_exllama_kernel ...@@ -374,28 +1187,24 @@ __global__ void reconstruct_exllama_kernel
nextgroup += groupsize; nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
} }
for (int p = 0; p < 4; p++) for (int p = 0; p < 2; p++)
{ {
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr; const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4; int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); half2 dq[4][8];
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
b_ptr += size_n; b_ptr += size_n;
//half* dqh = (half*)dq; //half* dqh = (half*)dq;
if (b_q_perm) if (b_q_perm)
{ {
for (int j = 0; j < 4; j++) for (int j = 0; j < 8; j++)
{ {
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
...@@ -404,7 +1213,7 @@ __global__ void reconstruct_exllama_kernel ...@@ -404,7 +1213,7 @@ __global__ void reconstruct_exllama_kernel
} }
else else
{ {
for (int j = 0; j < 4; j++) for (int j = 0; j < 8; j++)
{ {
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
...@@ -416,7 +1225,6 @@ __global__ void reconstruct_exllama_kernel ...@@ -416,7 +1225,6 @@ __global__ void reconstruct_exllama_kernel
} }
} }
void reconstruct_exllama void reconstruct_exllama
( (
const uint32_t* b_q_weight, const uint32_t* b_q_weight,
...@@ -426,7 +1234,8 @@ void reconstruct_exllama ...@@ -426,7 +1234,8 @@ void reconstruct_exllama
half* out, half* out,
int height, int height,
int width, int width,
int groups int groups,
int bit
) )
{ {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
...@@ -435,6 +1244,15 @@ void reconstruct_exllama ...@@ -435,6 +1244,15 @@ void reconstruct_exllama
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel;
if (bit == 2) {
reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel;
} else if (bit == 3) {
reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel;
} else if (bit == 8) {
reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>> reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
( (
...@@ -450,7 +1268,7 @@ void reconstruct_exllama ...@@ -450,7 +1268,7 @@ void reconstruct_exllama
} }
__global__ void gemm_half_q_half_alt_kernel( __global__ void gemm_half_q_half_alt_4bit_kernel(
const half2* __restrict__ vec, const half2* __restrict__ vec,
const uint32_t* __restrict__ mat, const uint32_t* __restrict__ mat,
half* __restrict__ mul, half* __restrict__ mul,
...@@ -548,6 +1366,95 @@ __global__ void gemm_half_q_half_alt_kernel( ...@@ -548,6 +1366,95 @@ __global__ void gemm_half_q_half_alt_kernel(
} }
__global__ void gemm_half_q_half_alt_8bit_kernel(
const half2* __restrict__ vec,
const uint32_t* __restrict__ mat,
half* __restrict__ mul,
const half* __restrict__ scales,
const uint32_t* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int height,
int width
)
{
int zero_width = width / 4;
int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
int b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
int h = BLOCK_KN_SIZE * blockIdx.z / 4;
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
blockvec[m][threadIdx.x] =
vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
threadIdx.x];
}
}
if (blockIdx.z == 0)
{
for (int m = 0; m < b_end; m++)
mul[(b + m) * width + w] = __int2half_rn(0);
}
__syncthreads();
int i = width * h + w;
int g_h = h * 4;
int k = 0;
int z_w = w / 4;
int z_mod = (w % 4) * 8;
half2 res2;
half res[BLOCK_M_SIZE_MAX] = {};
unsigned int tmp;
while (k < h_end) {
tmp = mat[i];
half2 scales_tmp[2];
half2 zeros_tmp[2];
for (int tmp_k = 0; tmp_k < 2; tmp_k++) {
int g = g_idx[g_h + (k + tmp_k) * 2];
int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
half scale_f = scales[g * width + w];
half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2(
__hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
__hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))
);
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
for (int m = 0; m < b_end; m++) {
#ifndef USE_ROCM
res2 = {};
#else
res2.x = __half_as_ushort(__float2half(0));
res2.y = __half_as_ushort(__float2half(0));
#endif
half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF));
res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF));
res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
#ifndef USE_ROCM
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif
}
i += width;
k += 2;
}
for (int m = 0; m < b_end; m++) {
atomicAdd(&mul[(b + m) * width + w], res[m]);
}
}
void gemm_half_q_half_alt void gemm_half_q_half_alt
( (
const half* a, const half* a,
...@@ -558,7 +1465,8 @@ void gemm_half_q_half_alt ...@@ -558,7 +1465,8 @@ void gemm_half_q_half_alt
half* c, half* c,
int size_m, int size_m,
int size_n, int size_n,
int size_k int size_k,
int bit
) )
{ {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
...@@ -569,8 +1477,13 @@ void gemm_half_q_half_alt ...@@ -569,8 +1477,13 @@ void gemm_half_q_half_alt
gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
auto kernel = gemm_half_q_half_alt_4bit_kernel;
if (bit == 8) {
kernel = gemm_half_q_half_alt_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>> kernel<<<gridDim, blockDim, 0, stream>>>
( (
(const half2*) a, (const half2*) a,
b_q_weight, b_q_weight,
...@@ -579,12 +1492,12 @@ void gemm_half_q_half_alt ...@@ -579,12 +1492,12 @@ void gemm_half_q_half_alt
b_gptq_qzeros, b_gptq_qzeros,
b_g_idx, b_g_idx,
size_m, size_m,
size_k / 8, size_k / 32 * bit,
size_n size_n
); );
} }
template<class T, int bit>
__global__ void reconstruct_gptq_kernel __global__ void reconstruct_gptq_kernel
( (
const uint32_t* __restrict__ w, const uint32_t* __restrict__ w,
...@@ -600,30 +1513,79 @@ __global__ void reconstruct_gptq_kernel ...@@ -600,30 +1513,79 @@ __global__ void reconstruct_gptq_kernel
// Start of block // Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 8; int row = blockIdx.y * 32 / bit;
if (column >= width) return; if (column >= width) return;
// Views // Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width); MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, group, width); MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q4_row w_zeros_(w_zeros, group, width); T w_zeros_(w_zeros, group, width);
uint32_t w_read = w_.item_uint32_t(row, column); uint32_t w_read = w[blockIdx.y * width + column];
half* out_ptr = out_.item_ptr(row, column); half* out_ptr = out_.item_ptr(row, column);
#pragma unroll #pragma unroll
for (int s = 0; s < 32; s += 4) for (int s = 0; s < 32; s += bit)
{ {
int group = g_idx[row + s / 4]; int group = g_idx[row + s / bit];
half w_scale = w_scales_.item(group, column); half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1; uint32_t w_zero = w_zeros_.item(group, column) + 1;
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width; *out_ptr = w_item; out_ptr += out_.width;
} }
} }
__global__ void reconstruct_gptq_3bit_kernel
(
const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int* __restrict__ g_idx,
const int height,
const int width,
const int group,
half* __restrict__ out
)
{
// Start of block
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
int row = blockIdx.y * 32;
if (column >= width) return;
// Views
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q3_row w_zeros_(w_zeros, group, width);
uint32_t w1 = w[(blockIdx.y * 3) * width + column];
uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int i = 0; i < 32; i += 1)
{
int group = g_idx[row + i];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
int w_item;
if (i == 10) {
w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
} else if (i == 21) {
w_item = (w2 >> 31) | ((w3 << 1) & 0x6);
} else if (i < 10) {
w_item = ((w1 >> (i * 3)) & 0x7);
} else if (i < 21) {
w_item = ((w2 >> (i * 3 - 32)) & 0x7);
} else {
w_item = ((w3 >> (i * 3 - 64)) & 0x7);
}
*out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale);
out_ptr += out_.width;
}
}
void reconstruct_gptq void reconstruct_gptq
( (
...@@ -634,16 +1596,28 @@ void reconstruct_gptq ...@@ -634,16 +1596,28 @@ void reconstruct_gptq
half* out, half* out,
int height, int height,
int width, int width,
int groups int groups,
int bit
) )
{ {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
gridDim.y = DIVIDE(height, 8); gridDim.y = DIVIDE(height, 32 / bit);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
auto kernel = reconstruct_gptq_kernel<MatrixView_q4_row, 4>;
if (bit == 2) {
kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
} else if (bit == 8) {
kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
} else if (bit == 3) {
kernel = reconstruct_gptq_3bit_kernel;
gridDim.y = DIVIDE(height, 32);
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>> kernel<<<gridDim, blockDim, 0, stream>>>
( (
b_q_weight, b_q_weight,
b_gptq_scales, b_gptq_scales,
...@@ -671,19 +1645,27 @@ void gemm_half_q_half_cuda ...@@ -671,19 +1645,27 @@ void gemm_half_q_half_cuda
int size_n, int size_n,
int size_k, int size_k,
int groups, int groups,
bool use_exllama bool use_exllama,
int bit
) )
{ {
if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) { bool use_reconstruct;
if (use_exllama) {
use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
} else {
// The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now.
use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
}
if (use_reconstruct) {
// Reconstruct FP16 matrix, then cuBLAS // Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) { if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
size_k, size_n, groups); size_k, size_n, groups, bit);
} }
else else
{ {
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups); temp_dq, size_k, size_n, groups, bit);
} }
const half alpha = __float2half(1.0f); const half alpha = __float2half(1.0f);
...@@ -707,7 +1689,7 @@ void gemm_half_q_half_cuda ...@@ -707,7 +1689,7 @@ void gemm_half_q_half_cuda
{ {
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
groups); groups, bit);
} }
if (last_chunk_size) if (last_chunk_size)
...@@ -715,18 +1697,17 @@ void gemm_half_q_half_cuda ...@@ -715,18 +1697,17 @@ void gemm_half_q_half_cuda
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
b_gptq_scales, b_g_idx, c + last_chunk * size_n, b_gptq_scales, b_g_idx, c + last_chunk * size_n,
last_chunk_size, size_n, size_k, last_chunk_size, last_chunk_size, size_n, size_k, last_chunk_size,
groups); groups, bit);
} }
} }
else else
{ {
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k); c, size_m, size_n, size_k, bit);
} }
} }
__global__ void shuffle_4bit_kernel
__global__ void shuffle_kernel
( (
uint32_t* __restrict__ b_q_weight, uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_k,
...@@ -740,13 +1721,53 @@ __global__ void shuffle_kernel ...@@ -740,13 +1721,53 @@ __global__ void shuffle_kernel
while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
} }
__global__ void shuffle_8bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
}
__global__ void shuffle_2bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
__global__ void shuffle_3bit_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
}
__global__ void make_sequential_kernel __global__ void make_sequential_4bit_kernel
( (
const uint32_t* __restrict__ w, const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new, uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm, const int* __restrict__ q_perm,
const int w_height,
const int w_width const int w_width
) )
{ {
...@@ -778,37 +1799,204 @@ __global__ void make_sequential_kernel ...@@ -778,37 +1799,204 @@ __global__ void make_sequential_kernel
w_new2[w_new2_row * w2_stride + w2_column] = dst; w_new2[w_new2_row * w2_stride + w2_column] = dst;
} }
__global__ void make_sequential_2bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 4;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 16; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 4;
int w2_subrow = source_row & 0x0f;
int w2_row_shift = w2_subrow << 1;
int wnew2_row_shift = i << 1;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000300000003;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
__global__ void make_sequential_3bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w_column >= w_width) return;
int w_new_row = blockIdx.y * 3;
int q_perm_idx = blockIdx.y << 5;
uint32_t dst[3] = {0, 0, 0};
#pragma unroll
for (int i = 0; i < 32; i++)
{
int source_row = q_perm[q_perm_idx++];
int z_w = (source_row / 32) * 3;
int z_mod = source_row % 32;
int z_bit;
if (z_mod != 10){
if (z_mod != 21){
z_bit = z_mod;
if (z_bit > 21){
z_bit *= 3;
z_bit -= 64;
z_w += 2;
} else if (z_bit > 10){
z_bit *= 3;
z_bit -= 32;
z_w += 1;
} else {
z_bit *= 3;
}
} else {
z_w += 1;
}
}
uint64_t src;
if (z_mod == 10) {
src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4);
} else if (z_mod == 21){
src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6);
} else {
src = w[z_w * w_width + w_column];
src >>= z_bit;
src &= 0x07;
}
z_w = 0;
if (i != 10){
if (i != 21){
z_bit = i;
if (z_bit > 21){
z_bit *= 3;
z_bit -= 64;
z_w += 2;
} else if (z_bit > 10){
z_bit *= 3;
z_bit -= 32;
z_w += 1;
} else {
z_bit *= 3;
}
} else {
z_w += 1;
}
}
if (i == 10) {
dst[z_w] |= (src & 0x03) << 30;
dst[z_w + 1] |= ((src & 0x4) >> 2);
} else if (i == 21) {
dst[z_w] |= (src & 0x01) << 31;
dst[z_w + 1] |= ((src & 0x6) >> 1);
} else {
dst[z_w] |= (src << z_bit);
}
}
w_new[w_new_row * w_width + w_column] = dst[0];
w_new[(w_new_row + 1) * w_width + w_column] = dst[1];
w_new[(w_new_row + 2) * w_width + w_column] = dst[2];
}
__global__ void make_sequential_8bit_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 2;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 2;
int w2_subrow = source_row & 0x03;
int w2_row_shift = w2_subrow << 3;
int wnew2_row_shift = i << 3;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x000000ff000000ff;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void shuffle_exllama_weight void shuffle_exllama_weight
( (
uint32_t* q_weight, uint32_t* q_weight,
int* q_perm, int* q_perm,
int height, int height,
int width int width,
int bit
) )
{ {
if (q_perm) if (q_perm)
{ {
uint32_t* new_qweight = NULL; uint32_t* new_qweight = NULL;
cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t)); cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t));
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = THREADS_X; blockDim.x = THREADS_X;
blockDim.y = 1; blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X); gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8; gridDim.y = height / 32 * bit;
auto kernel = make_sequential_4bit_kernel;
if (bit == 2) {
kernel = make_sequential_2bit_kernel;
} else if (bit == 3) {
kernel = make_sequential_3bit_kernel;
gridDim.y = height / 32;
} else if (bit == 8) {
kernel = make_sequential_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>> kernel<<<gridDim, blockDim, 0, stream>>>
( (
q_weight, q_weight,
new_qweight, new_qweight,
q_perm, q_perm,
height / 8,
width width
); );
// Replace qweights // Replace qweights
cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup // Cleanup
cudaDeviceSynchronize(); cudaDeviceSynchronize();
cudaFree(new_qweight); cudaFree(new_qweight);
...@@ -818,6 +2006,14 @@ void shuffle_exllama_weight ...@@ -818,6 +2006,14 @@ void shuffle_exllama_weight
blockDim.y = 1; blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X); gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1; gridDim.y = 1;
auto shuffle_kernel = shuffle_4bit_kernel;
if (bit == 2) {
shuffle_kernel = shuffle_2bit_kernel;
} else if (bit == 3) {
shuffle_kernel = shuffle_3bit_kernel;
} else if (bit == 8) {
shuffle_kernel = shuffle_8bit_kernel;
}
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width); shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
} }
...@@ -832,13 +2028,14 @@ torch::Tensor gptq_gemm ...@@ -832,13 +2028,14 @@ torch::Tensor gptq_gemm
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx, torch::Tensor b_g_idx,
bool use_exllama bool use_exllama,
int bit
) )
{ {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options); at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
vllm::gptq::gemm_half_q_half_cuda vllm::gptq::gemm_half_q_half_cuda
( (
...@@ -854,7 +2051,8 @@ torch::Tensor gptq_gemm ...@@ -854,7 +2051,8 @@ torch::Tensor gptq_gemm
c.size(1), // n c.size(1), // n
a.size(1), // k a.size(1), // k
b_gptq_qzeros.size(0), // group number b_gptq_qzeros.size(0), // group number
use_exllama use_exllama,
bit
); );
return c; return c;
} }
...@@ -862,14 +2060,16 @@ torch::Tensor gptq_gemm ...@@ -862,14 +2060,16 @@ torch::Tensor gptq_gemm
void gptq_shuffle void gptq_shuffle
( (
torch::Tensor q_weight, torch::Tensor q_weight,
torch::Tensor q_perm torch::Tensor q_perm,
int bit
) )
{ {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight( vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(), (uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 8, q_weight.size(0) * 32 / bit,
q_weight.size(1) q_weight.size(1),
bit
); );
} }
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z4 = __half2half2(z4_);
const half2 z16 = __half2half2(z16_);
const half2 z64 = __half2half2(z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
} // namespace gptq
} // namespace vllm
#endif
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
} // namespace gptq
} // namespace vllm
#endif
...@@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8 ...@@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8
( (
const uint32_t q_0, const uint32_t q_0,
half2 (&dq)[4], half2 (&dq)[4],
int stride int stride,
const uint32_t zero
) )
{ {
const uint32_t c0 = 0x64006400; const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f); const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_); const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f); const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_, z1_); const half2 z1 = __half2half2(z1_.as_half);
const half2 z16 = __halves2half2(z16_, z16_); const half2 z16 = __half2half2(z16_);
uint32_t qa = q_0; uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
...@@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq ...@@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
} // namespace gptq } // namespace gptq
} // namespace vllm } // namespace vllm
#else
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
} // namespace gptq
} // namespace vllm
#endif #endif
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
} // namespace gptq
} // namespace vllm
#endif
import enum import enum
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fractions import Fraction
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig): ...@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.desc_act = desc_act self.desc_act = desc_act
self.pack_factor = 32 // self.weight_bits self.pack_factor = Fraction(32, self.weight_bits)
# exllama kernel v1 only supports 4 bit if self.weight_bits not in [2, 3, 4, 8]:
if self.weight_bits != 4:
raise ValueError( raise ValueError(
"Currently, only 4-bit weight quantization is supported for " "Currently, only 2/3/4/8-bit weight quantization is supported for "
f"GPTQ, but got {self.weight_bits} bits.") f"GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor != 0: if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
raise ValueError( raise ValueError(
"The output size is not aligned with the quantized " "The output size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
...@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
else: else:
weights["g_idx"] = torch.empty((1, 1), device="meta") weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY weights["exllama_state"] = ExllamaState.READY
ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
self.quant_config.weight_bits)
output = ops.gptq_gemm(reshaped_x, weights["qweight"], output = ops.gptq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"], weights["qzeros"], weights["scales"],
weights["g_idx"], weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY) weights["exllama_state"] == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None: if bias is not None:
output = output + bias output = output + bias
return output.reshape(out_shape) return output.reshape(out_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment