Commit 6a583c2f authored by chenych's avatar chenych
Browse files

update dtk to 24.04.1 and modify README

parent 7d576a9a
#ifndef _q_gemm_cuh
#define _q_gemm_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q_matrix.cuh"
void gemm_half_q_half_cuda
(
cublasHandle_t cublas_handle,
const half* a,
QMatrix* b,
half* c,
int size_m,
int size_n,
int size_k,
bool clear = false,
half* reconstruct = NULL,
bool force_cuda = false
);
void clear_tensor_cuda
(
half* c,
int size_m,
int size_n
);
#endif
\ No newline at end of file
#include "compat.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
return fma(result_f, qs_f, g_result);
}
typedef void (*fp_gemm_half_q_half_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const int,
const int,
const int,
const int,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0 = a_ptr[b_q_perm[offset_k + t]];
block_a_ptr[t] = a0;
}
}
// Clear
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
// Preload scales
float scales[MAX_GROUPS_IN_BLOCK][4];
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
for (int g = 0; g < groups_in_block; g++)
{
int qscales[4];
b_q_scale_.item4(qscales, group + g, n);
qscales[0]++;
qscales[1]++;
qscales[2]++;
qscales[3]++;
float maxscale = __half2float(b_q_scale_max[group + g]);
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
}
// a, b offset
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int scales_idx = 0;
float qs_f0 = scales[scales_idx][0];
float qs_f1 = scales[scales_idx][1];
float qs_f2 = scales[scales_idx][2];
float qs_f3 = scales[scales_idx][3];
int nextgroup = offset_k + groupsize;
// Column result
float block_c[m_count][4] = {};
// Dequantize groups
int k = offset_k;
while (k < rows_8 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[2];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[5];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][4];
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 8;
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 1; j++)
{
int4 load_int4[3];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 32;
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup)
{
group++;
scales_idx++;
qs_f0 = scales[scales_idx][0];
qs_f1 = scales[scales_idx][1];
qs_f2 = scales[scales_idx][2];
qs_f3 = scales[scales_idx][3];
nextgroup += groupsize;
}
#pragma unroll
for (int j = 0; j < 2; j++)
{
int4 load_int4[1];
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
half2 dq[4][8];
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
}
a_ptr += 16;
}
k += 32;
}
// Accumulate column sums in c
for (int m = 0; m < m_count; m++)
{
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
#endif
return NULL;
}
#include "compat.cuh"
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __hadd2(result, g_result);
}
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
half2 result = {};
const half2* a2_ptr = (const half2*)a_ptr;
#pragma unroll
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
const half*,
const uint32_t*,
const uint32_t*,
const half*,
half*,
const int,
const int,
const int,
const int,
const int,
const uint16_t*,
const int,
const bool
);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
const int size_k,
const int groups,
const int groupsize,
const uint16_t* __restrict__ b_q_perm,
const int rows_4,
const bool clear
)
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int t = threadIdx.x;
// Block
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
int offset_m = blockIdx.y * m_count;
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
int end_m = min(offset_m + m_count, size_m);
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int n = offset_n + t * 4;
// Preload block_a
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
if (offset_k + t < end_k)
{
for (int m = 0; m < m_count; ++m)
{
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
half* block_a_ptr = block_a[m];
half a0;
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
else a0 = a_ptr[offset_k + t];
block_a_ptr[t] = a0;
}
}
// Zero output
if (n >= size_n) return;
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
{
for (int m = 0; m < m_count; m++)
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
}
__syncthreads();
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// a, b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
const half* a_ptr = &block_a[0][0];
int a_stride = BLOCK_KN_SIZE;
// Initial group
int zeros[4];
float scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
// __syncthreads();
// Column result
float block_c[m_count][4] = {};
// Dequantize and multiply
int k = offset_k;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
}
#pragma unroll
for (int j = 0; j < 4; j++)
{
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
half2 dq[4][4];
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
#pragma unroll
for (int m = 0; m < m_count; m++)
{
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
}
b_ptr += size_n;
a_ptr += 8;
}
k += 32;
}
for (int m = 0; m < m_count; m++)
{
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
atomicAdd(out , result01);
atomicAdd(out + 1, result23);
}
}
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
#if BLOCK_M_SIZE_MAX >= 1
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
#endif
return NULL;
}
#include "q_matrix.cuh"
#include "matrix_view.cuh"
#include "util.cuh"
#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"
#define BLOCK_KN_SIZE 128
#define THREADS_X 32
#define THREADS_Y 32
// Shuffle quantized data on load
__global__ void shuffle_kernel
(
uint32_t* __restrict__ b_q_weight,
const int size_k,
const int size_n,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
int n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
}
// QMatrix constructor
QMatrix::QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
) :
device(_device),
height(_height),
width(_width),
groups(_groups),
temp_dq(_temp_dq)
{
cudaSetDevice(device);
failed = false;
cuda_q_weight = _q_weight;
cuda_q_perm = _q_perm;
cuda_q_invperm = _q_invperm;
cuda_q_scale = _q_scale;
cuda_q_scale_max = _q_scale_max;
cuda_q_groups = _q_groups;
cuda_gptq_qzeros = _gptq_qzeros;
cuda_gptq_scales = _gptq_scales;
is_gptq = (_gptq_qzeros != NULL);
groupsize = 1;
while (groupsize * groups < height) groupsize *= 2;
// Create group map
rows_8 = 0;
rows_6 = 0;
rows_5 = 0;
rows_4 = 0;
rows_3 = 0;
rows_2 = 0;
if (!is_gptq)
{
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
for (int i = 0; i < groups; i++)
{
int bits = cpu_q_groups[i * 2];
if (bits == 8) rows_8 += groupsize;
if (bits == 6) rows_6 += groupsize;
if (bits == 5) rows_5 += groupsize;
if (bits == 4) rows_4 += groupsize;
if (bits == 3) rows_3 += groupsize;
if (bits == 2) rows_2 += groupsize;
}
free(cpu_q_groups);
rows_6 += rows_8;
rows_5 += rows_6;
rows_4 += rows_5;
rows_3 += rows_4;
rows_2 += rows_3;
}
else
{
rows_4 = height;
rows_3 = height;
rows_2 = height;
if (_gptq_g_idx)
{
if (!make_sequential(_gptq_g_idx))
{
failed = true;
//printf("FAIL\n");
return;
}
}
}
// Shuffle quantized data
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}
QMatrix::~QMatrix()
{
}
// Reconstruct b[k,n] (GPTQ)
__global__ void reconstruct_gptq_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_4
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
// Preload remapping table
__shared__ uint16_t perm[BLOCK_KN_SIZE];
int t = threadIdx.x;
if (b_q_perm)
{
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
}
// Column
int n = offset_n + t * 4;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int nextgroup = offset_k + groupsize;
// b offset
int qk = offset_k / (32 / 4);
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
// Initial zeros/scale
int zeros[4];
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
__syncthreads();
int k = offset_k;
int lk = 0;
while (k < end_k)
{
if (k == nextgroup)
{
group++;
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
// Avoid zeros overflow with & 0x0f.
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0f, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0f, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0f, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0f, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)
{
half2 dq[4][4];
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
if (b_q_perm)
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
else
{
for (int j = 0; j < 4; j++)
{
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
}
}
}
k += 32;
}
}
// Reconstruct b[k,n]
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ b_q_weight,
const uint16_t* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_q_scale,
const half* __restrict__ b_q_scale_max,
//const uint16_t* __restrict__ b_q_groups,
const int size_k,
const int size_n,
const int groupsize,
const int groups,
half* __restrict__ b,
const int rows_8,
const int rows_6,
const int rows_5,
const int rows_4,
const int rows_3,
const int rows_2
)
{
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
// Preload remapping table
int t = threadIdx.x;
__shared__ uint16_t perm[BLOCK_KN_SIZE];
if (offset_k + t < size_k)
perm[t] = b_q_perm[offset_k + t];
// Column
int n = offset_n + t;
if (n >= size_n) return;
// Find initial group
int group = offset_k / groupsize;
int pre_rows_8 = min(rows_8, offset_k);
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
int qk = 0;
qk += pre_rows_8 / 32 * 8;
qk += pre_rows_6 / 32 * 6;
qk += pre_rows_5 / 32 * 5;
qk += pre_rows_4 / 32 * 4;
qk += pre_rows_3 / 32 * 3;
qk += pre_rows_2 / 32 * 2;
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
half2 qs_h2 = __halves2half2(qs_h, qs_h);
int nextgroup = offset_k + groupsize;
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
int k = offset_k;
int lk = 0;
__syncthreads();
while (k < rows_8 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
dequant_8bit_8(q_0, q_1, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_6 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_5 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
uint32_t q_3 = *b_ptr; b_ptr += size_n;
uint32_t q_4 = *b_ptr; b_ptr += size_n;
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_4 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 4; p++)
{
half2 dq[4];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_4bit_8(q_0, dq, size_n);
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_3 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 1; p++)
{
half2 dq[16];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
uint32_t q_1 = *b_ptr; b_ptr += size_n;
uint32_t q_2 = *b_ptr; b_ptr += size_n;
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
while (k < rows_2 && k < end_k)
{
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
for (int p = 0; p < 2; p++)
{
half2 dq[8];
uint32_t q_0 = *b_ptr; b_ptr += size_n;
dequant_2bit_16(q_0, dq, size_n);
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
half* dqh = (half*) dq;
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
}
k += 32;
}
}
void QMatrix::reconstruct(half* out)
{
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
if (!is_gptq)
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
reconstruct_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_q_scale,
cuda_q_scale_max,
//cuda_q_groups,
height,
width,
groupsize,
groups,
out,
rows_8,
rows_6,
rows_5,
rows_4,
rows_3,
rows_2
);
}
else
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_q_perm,
cuda_gptq_qzeros,
cuda_gptq_scales,
//const uint16_t* __restrict__ b_q_groups,
height,
width,
groupsize,
groups,
out,
rows_4
);
}
}
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint16_t* __restrict__ q_perm,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
int w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = q_perm[q_perm_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
if (err != cudaSuccess) {
cudaError_t cuda_status = cudaGetLastError(); // Clear error
return false;
}
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Reduce to uint16_t
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
// Move to CUDA
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 blockDim, gridDim;
blockDim.x = THREADS_X;
blockDim.y = 1;
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;
make_sequential_kernel<<<gridDim, blockDim>>>
(
cuda_q_weight,
cuda_new_qweight,
cuda_q_perm,
height / 8,
width
);
// Replace qweights
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
return true;
}
#ifndef _q_matrix_cuh
#define _q_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#define MAX_SUPERGROUPS 16
class QMatrix
{
public:
int device;
bool is_gptq;
int height;
int width;
int groups;
int groupsize;
int rows_8;
int rows_6;
int rows_5;
int rows_4;
int rows_3;
int rows_2;
uint32_t* cuda_q_weight = NULL;
uint16_t* cuda_q_perm = NULL;
uint16_t* cuda_q_invperm = NULL;
uint32_t* cuda_q_scale = NULL;
half* cuda_q_scale_max = NULL;
uint16_t* cuda_q_groups = NULL;
uint32_t* cuda_gptq_qzeros = NULL;
half* cuda_gptq_scales = NULL;
half* temp_dq;
bool failed;
QMatrix
(
const int _device,
const int _height,
const int _width,
const int _groups,
uint32_t* _q_weight,
uint16_t* _q_perm,
uint16_t* _q_invperm,
uint32_t* _q_scale,
half* _q_scale_max,
uint16_t* _q_groups,
uint32_t* _gptq_qzeros,
half* _gptq_scales,
uint32_t* _gptq_g_idx,
half* _temp_dq
);
~QMatrix();
void reconstruct(half* out);
bool make_sequential(const uint32_t* cpu_g_idx);
private:
};
#endif
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_2BIT == 1
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 2.0f);
const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z4 = __halves2half2(z4_, z4_);
const half2 z16 = __halves2half2(z16_, z16_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
#else
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_3BIT == 1
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_4BIT == 1
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1z16)[2],
half2(&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled)
{
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
}
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
#else
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_5_cuh
#define _qdq_5_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_5BIT == 1
// Permutation:
//
// v5555533 33311111 u4444422 22200000 (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
uint32_t qd = q[3 * stride];
uint32_t qe = q[4 * stride];
// qa: 66555554 44443333 32222211 11100000
// qb: ccccbbbb baaaaa99 99988888 77777666
// qc: jiiiiihh hhhggggg fffffeee eedddddc
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
uint32_t qf = qe >> 22;
qe <<= 8;
qe |= qd >> 24;
qd <<= 6;
qd |= qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: 555554 44443333 32222211 11100000
// qb: bbbbba aaaa9999 98888877 77766666
// qc: hhhhhg ggggffff feeeeedd dddccccc
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
// qf: vv vvvuuuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
uint32_t zd = 0;
uint32_t ze = 0;
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
// za: 5555533 33311111 4444422 22200000
// zb: bbbbb99 99977777 aaaaa88 88866666
// zc: hhhhhff fffddddd gggggee eeeccccc
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
// ze: tttttrr rrrppppp sssssqq qqqooooo
// qf: vv vvvuuuuu
za |= ((qf & 0x001) >> 0) << 15;
zb |= ((qf & 0x002) >> 1) << 15;
zc |= ((qf & 0x004) >> 2) << 15;
zd |= ((qf & 0x008) >> 3) << 15;
ze |= ((qf & 0x010) >> 4) << 15;
za |= ((qf & 0x020) >> 5) << 31;
zb |= ((qf & 0x040) >> 6) << 31;
zc |= ((qf & 0x080) >> 7) << 31;
zd |= ((qf & 0x100) >> 8) << 31;
ze |= ((qf & 0x200) >> 9) << 31;
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
// zb: vbbbbb99 99977777 uaaaaa88 88866666
// zc: vhhhhhff fffddddd ugggggee eeeccccc
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
// ze: vtttttrr rrrppppp usssssqq qqqooooo
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
q[3 * stride] = zd;
q[4 * stride] = ze;
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y32_ = __float2half_rn(1.0f / 32.0f);
const half2 y32 = __halves2half2(y32_, y32_);
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z32 = __halves2half2(z32_, z32_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
uint32_t qd = q_3;
uint32_t qe = q_4;
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
qa >>= 10;
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
qa >>= 5;
qa &= 0x00010001;
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
qb >>= 10;
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
qb >>= 4;
qb &= 0x00020002;
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
qc >>= 10;
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
qc >>= 3;
qc &= 0x00040004;
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
qd >>= 10;
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
qd >>= 2;
qd &= 0x00080008;
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
qe >>= 10;
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
qe >>= 1;
qe &= 0x00100010;
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hadd2( q3.as_half2, z1);
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hadd2( q6.as_half2, z1);
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
dq[ 8] = __hadd2( q8.as_half2, z1);
dq[ 9] = __hadd2( q9.as_half2, z1);
dq[10] = __hfma2(q10.as_half2, y32, z32);
dq[11] = __hadd2(q11.as_half2, z1);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y32, z32);
dq[14] = __hadd2(q14.as_half2, z1);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_6_cuh
#define _qdq_6_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_6BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_6bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_6bit_16
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
#include "../../config.h"
#if QMODE_8BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
\ No newline at end of file
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
#endif
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
qs_h = __hmul(qs_h, qs_h);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ float clamp(float x, float a, float b)
{
return fmaxf(a, fminf(b, x));
}
#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"CUDA error: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "config.h"
#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"
#include "cpp/util.h"
// Some decluttering macros
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
// Quant matrix
uintptr_t make_q_matrix
(
torch::Tensor q_weight,
torch::Tensor q_perm,
torch::Tensor q_invperm,
torch::Tensor q_scale,
torch::Tensor q_scale_max,
torch::Tensor q_groups,
torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx,
torch::Tensor temp_dq
)
{
TORCH_CHECK_DTYPE(q_weight, kInt);
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
int device = q_weight.device().index();
int width = q_weight.size(1);
int groups;
int height;
if (!q_scale.device().is_meta())
{
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
groups = q_scale.size(0);
height = q_invperm.size(0);
}
else
{
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
groups = gptq_qzeros.size(0);
height = q_weight.size(0) * 8;
}
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
QMatrix* m = new QMatrix
(
device,
height,
width,
groups,
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
(half*) temp_dq.data_ptr()
);
return reinterpret_cast<uintptr_t> (m);
}
void gemm_half_q_half
(
torch::Tensor a,
uintptr_t b,
torch::Tensor c,
bool force_cuda
)
{
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
TORCH_CHECK_DTYPE(a, kHalf);
TORCH_CHECK_DTYPE(c, kHalf);
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
gemm_half_q_half_cuda
(
at::cuda::getCurrentCUDABlasHandle(),
(const half*) a.data_ptr(),
qm,
(half*) c.data_ptr(),
c.size(0), // m
c.size(1), // n
a.size(1), // k
true,
NULL,
force_cuda
);
}
// Bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}
/*
* Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/python.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include "marlin_cuda_kernel.cuh"
#include "marlin_repack.cuh"
const int ERR_PROB_SHAPE = 1;
const int ERR_KERN_SHAPE = 2;
void mul(
const torch::Tensor& A,
const torch::Tensor& B,
torch::Tensor& C,
const torch::Tensor& s,
torch::Tensor& workspace,
int thread_k = -1,
int thread_n = -1,
int sms = -1,
int max_par = 8
) {
int prob_m = A.size(0);
int prob_n = C.size(1);
int prob_k = A.size(1);
int groupsize = (s.size(0) == 1) ? -1 : prob_k / s.size(0);
if (groupsize != -1 && groupsize * s.size(0) != prob_k)
AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups.");
if (workspace.numel() < prob_n / 128 * max_par)
AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, ".");
int dev = A.get_device();
int err = marlin_cuda(
A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
s.data_ptr(),
prob_m, prob_n, prob_k,
workspace.data_ptr(),
groupsize,
dev,
at::cuda::getCurrentCUDAStream(dev),
thread_k,
thread_n,
sms,
max_par
);
if (err == ERR_PROB_SHAPE) {
AT_ERROR(
"Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
" not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "."
);
} else if (err == ERR_KERN_SHAPE) {
AT_ERROR(
"No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "."
);
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mul", &mul, "Marlin FP16xINT4 matmul.");
m.def("gptq_repack", &gptq_repack, "Repack GPTQ checkpoints for Marlin.");
}
\ No newline at end of file
/*
* Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MARLIN_CUDA_KERNEL_CUH
#define MARLIN_CUDA_KERNEL_CUH
#include <cuda.h>
#include <cuda_fp16.h>
#include <assert.h>
#include <iostream>
#include "marlin_cuda_kernel.cuh"
constexpr int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core
// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee this.
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) {
return elems[i];
}
};
using I4 = Vec<int, 4>;
// Matrix fragments for tensor core instructions; their precise layout is documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that
// are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)
);
#else
assert(0);
#endif
}
// Asynchronous global->shared copy with a chache hint indicating that the values may be evicted immediately; used for
// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need
// for inputs A and outputs C.
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
"}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)
);
#else
assert(0);
#endif
}
// Async copy fence.
__device__ inline void cp_async_fence() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.commit_group;\n" ::);
#else
assert(0);
#endif
}
// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_group %0;\n" :: "n"(n));
#else
assert(0);
#endif
}
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c);
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])
);
#else
assert(0);
#endif
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
);
#else
assert(0);
#endif
}
// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
// automatically recognize it in all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
);
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
// We mostly follow the strategy in the link below, with some small changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
FragB frag_b;
frag_b[0] = __hsub2(
*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB)
);
frag_b[1] = __hfma2(
*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD)
);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible globally.
asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
while (state != count);
}
__syncthreads();
#else
assert(0);
#endif
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
asm volatile ("fence.acq_rel.gpu;\n");
asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
}
#else
assert(0);
#endif
}
template <
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple
// column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs
// for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as
// possible.
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions
int parallel = 1;
if (prob_m > 16 * thread_m_blocks) {
parallel = prob_m / (16 * thread_m_blocks);
prob_m = 16 * thread_m_blocks;
}
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);
// Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case
// where a stripe starts in the middle of group.
if (group_blocks != -1)
iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count = 0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to top
// We can easily implement parallel problem execution by just remapping indices and advancing global pointers
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
}
// Compute all information about the current slice which is required for synchronization.
auto init_slice = [&] () {
slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
slice_iters = 0;
if (slice_iters == 0)
return;
if (slice_row + slice_iters > k_tiles)
slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = ceildiv(k_tiles - col_off, iters);
if (col_off > 0)
slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0)
slice_idx--;
}
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * prob_k / 8;
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
}
};
init_slice();
int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory
// We typically use `constexpr` to indicate that this value is a compile-time constant
constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads
constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile
constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile
int b_gl_stride = 16 * prob_n / 32;
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
constexpr int b_sh_wr_delta = threads;
constexpr int b_sh_rd_delta = threads;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_sh_stage = s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x;
int b_sh_rd = threadIdx.x;
int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
int s_sh_wr = threadIdx.x;
int s_sh_rd;
// We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major
// layout in the former and in row-major in the latter case.
if (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
// Precompute which thread should not read memory in which iterations; this is needed if there are more threads than
// required for a certain tilesize or when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank
// conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of
// the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based
// on NSight-Compute) that each warp must also write a consecutive memory segment?
auto transform_a = [&] (int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory
// accesses are static, we simply precompute both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
}
// Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between
// subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_s = sh_b + (stages * b_sh_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4];
// Zero accumulators.
auto zero_accums = [&] () {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float*>(frag_c)[i] = 0;
};
// Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.
auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]
);
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
B_ptr[i] += b_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
// Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.
cp_async_fence();
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&] () {
// We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when
// it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).
cp_async_wait<stages - 2>();
__syncthreads();
};
// Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.
auto fetch_to_registers = [&] (int k, int pipe) {
// It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a
// significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the
// compiler and correspondingly a noticable drop in performance.
if (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
}
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
frag_b_quant[k % 2] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&] (int k) {
// We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
int b_quant = frag_b_quant[k % 2][j];
int b_quant_shift = b_quant >> 8;
FragB frag_b0 = dequant(b_quant);
// If there are no groups, we can just scale the final output once and can avoid doing so for each weight.
if (group_blocks != -1)
scale(frag_b0, frag_s[k % 2][j], 0);
FragB frag_b1 = dequant(b_quant_shift);
if (group_blocks != -1)
scale(frag_b1, frag_s[k % 2][j], 1);
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n
// dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&] () {
constexpr int red_off = threads / b_sh_stride / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride;
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
constexpr int red_sh_delta = b_sh_stride;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
// Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations,
// e.g., for two warps we write only once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
}
sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over
// the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
auto global_reduce = [&] (bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.
// To do this, we write out results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
int row = (threadIdx.x % 32) / 4;
if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
// hence we also use async-copies even though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
);
}
cp_async_fence();
cp_async_wait<0>();
}
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
reinterpret_cast<__half*>(&c_red)[j]
);
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<__half*>(&c)[j] = __float2half(
reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]
);
}
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step,
// the reduction above is performed in fragment layout.
auto write_result = [&] () {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32);
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final global write patterns
auto write = [&] (int idx, float c0, float c1, FragS& s) {
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
if (group_blocks == -1) // for per-column quantization we finally apply the scale here
res = __hmul2(res, s[0]);
((half2*) sh)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
}
c_sh_wr += 16 * (4 * c_sh_stride);
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
};
// Start global fetch and register load pipelines.
auto start_pipes = [&] () {
#pragma unroll
for (int i = 0; i < stages - 1; i++)
fetch_to_shared(i, i, i < slice_iters);
zero_accums();
wait_for_stage();
fetch_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
};
start_pipes();
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are
// static. Note that both pipelines have even length meaning that the next iteration will always start at index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
pipe++;
wait_for_stage();
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0)
break;
}
a_gl_rd += a_gl_rd_delta_o * stages;
// Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most
// readable, other ways of writing the loop seemed to noticeably worse performance after compliation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before write-out
if (group_blocks == -1 && last) {
if (s_sh_wr_pred)
cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
}
thread_block_reduce();
if (group_blocks == -1 && last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
if (slice_count > 1) { // only globally reduce if there is more than one block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
write_result();
slice_row = 0;
slice_col_par++;
slice_col++;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] -= b_gl_stride;
}
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes();
}
}
}
}
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more
// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.
const int THREADS = 256;
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
else if ( \
thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS \
) { \
cudaFuncSetAttribute( \
Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
SHARED_MEM \
); \
Marlin< \
THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
><<<blocks, THREADS, SHARED_MEM, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, \
prob_m, prob_n, prob_k, \
locks \
); \
}
const int ERR_PROB_SHAPE = 1;
const int ERR_KERN_SHAPE = 2;
int marlin_cuda(
const void* A,
const void* B,
void* C,
void* s,
int prob_m,
int prob_n,
int prob_k,
void* workspace,
int groupsize = -1,
int dev = 0,
cudaStream_t stream = 0,
int thread_k = -1,
int thread_n = -1,
int sms = -1,
int max_par = 16
) {
int tot_m = prob_m;
int tot_m_blocks = ceildiv(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m;
if (sms == -1)
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
if (thread_k == -1 || thread_n == -1) {
if (prob_m <= 16) {
// For small batchizes, better partioning is slightly more important than better compute utilization
thread_k = 128;
thread_n = 128;
} else {
thread_k = 64;
thread_n = 256;
}
}
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
int blocks = sms;
if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
return ERR_PROB_SHAPE;
if (prob_m == 0 || prob_n == 0 || prob_k == 0)
return 0;
const int4* A_ptr = (const int4*) A;
const int4* B_ptr = (const int4*) B;
int4* C_ptr = (int4*) C;
const int4* s_ptr = (const int4*) s;
int cols = prob_n / thread_n;
int* locks = (int*) workspace;
int ret = 0;
for (int i = 0; i < tot_m_blocks; i += 4) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
int par = 1;
if (thread_m_blocks > 4) {
// Note that parallel > 1 currently only works for inputs without any padding
par = (16 * thread_m_blocks - pad) / 64;
if (par > max_par)
par = max_par;
prob_m = 64 * par;
i += 4 * (par - 1);
thread_m_blocks = 4;
}
// For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
// in our testing, however many more are, in principle, possible.
if (false) {}
CALL_IF(1, 8, 8, -1)
CALL_IF(1, 8, 8, 8)
CALL_IF(1, 16, 4, -1)
CALL_IF(1, 16, 4, 8)
CALL_IF(2, 16, 4, -1)
CALL_IF(2, 16, 4, 8)
CALL_IF(3, 16, 4, -1)
CALL_IF(3, 16, 4, 8)
CALL_IF(4, 16, 4, -1)
CALL_IF(4, 16, 4, 8)
else
ret = ERR_KERN_SHAPE;
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
}
return ret;
}
#endif
#include <cuda.h>
#include <cuda_runtime.h>
int marlin_cuda(
const void* A,
const void* B,
void* C,
void* s,
int prob_m,
int prob_n,
int prob_k,
void* workspace,
int groupsize,
int dev,
cudaStream_t stream,
int thread_k,
int thread_n,
int sms,
int max_par
);
#include <cuda_runtime.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "marlin_repack.cuh"
__global__ void gptq_repack_kernel(
uint32_t* in,
uint32_t* out,
int m,
int n
) {
uint32_t row = blockIdx.x * 2;
uint32_t col = blockIdx.y * 64;
uint32_t t = threadIdx.x;
// marlin packs 4 16x16 blocks one time;
const int pad_len = 18;
__shared__ uint8_t block[4][16][pad_len];
// unpack
int block_idx = t / 8;
int block_offset = t % 8;
for (int offset = block_offset; offset < 16; offset += 8) {
uint32_t v1 = in[row * n + col + block_idx * 16 + offset];
uint32_t v2 = in[(row + 1) * n + col + block_idx * 16 + offset];
#pragma unroll
for (int i = 0; i < 8; i += 1) {
block[block_idx][i][offset] = v1 & 0xf;
v1 >>= 4;
block[block_idx][i + 8][offset] = v2 & 0xf;
v2 >>= 4;
}
}
// repack
// ref: _get_perms @ https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py
uint32_t srow = (t % 4) * 2;
uint32_t scol = t / 4;
uint32_t idx[8][2];
idx[0][0] = srow; idx[0][1] = scol;
idx[1][0] = srow + 8; idx[1][1] = scol;
idx[2][0] = srow; idx[2][1] = scol + 8;
idx[3][0] = srow + 8; idx[3][1] = scol + 8;
idx[4][0] = srow + 1; idx[4][1] = scol;
idx[5][0] = srow + 9; idx[5][1] = scol;
idx[6][0] = srow + 1; idx[6][1] = scol + 8;
idx[7][0] = srow + 9; idx[7][1] = scol + 8;
#pragma unroll
for (int i = 0; i < 4; i += 1) {
uint32_t v[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
v[j] = block[i][idx[j][0]][idx[j][1]];
}
uint32_t pack = (v[7] << 28) | (v[6] << 24) | (v[5] << 20) | (v[4] << 16) |
(v[3] << 12) | (v[2] << 8) | (v[1] << 4) | v[0];
out[blockIdx.x * n * 2 + blockIdx.y * 128 + t * 4 + i] = pack;
}
}
torch::Tensor gptq_repack(
torch::Tensor W
) {
int m = W.sizes()[0];
int n = W.sizes()[1];
assert(W.is_contiguous());
assert(W.dtype() == at::kInt);
assert(m % 2 == 0);
assert(n % 64 == 0);
auto result = at::empty(
{m / 2, n * 2}, at::TensorOptions().dtype(at::kInt).device(W.device()));
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
const dim3 threads(32);
// marlin packs 16 x 64 block and gptq packs 8 x 1
const dim3 blocks(m / 2, n / 64);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
gptq_repack_kernel<<<blocks, threads, 0, stream>>>(
(uint32_t*)W.data_ptr(),
(uint32_t*)result.data_ptr(),
m,
n
);
return result;
}
\ No newline at end of file
#include <torch/all.h>
__global__ void gptq_repack_kernel(
uint32_t* in,
uint32_t* out,
int m,
int n
);
torch::Tensor gptq_repack(
torch::Tensor W
);
\ No newline at end of file
import argparse
import subprocess
import time
import numpy as np
import pandas as pd
import template
from gekko import GEKKO
def mem_model(N, M, T, mu, tu, bits, l1, p, gs, verbose=False):
m = GEKKO() # create GEKKO model
# cinfergen if bits==3:
# tu = tu*3
B = m.Const(value=bits)
TP = m.Const(value=T // p)
k = m.Var(1, integer=True, lb=1)
z = m.Var(1, integer=True, lb=1)
w = m.Var(1, integer=True, lb=1)
y = m.Var(1, integer=True, lb=1)
v = m.Var(1, integer=True, lb=1)
mb = m.Var(mu, integer=True, lb=1)
if gs != -1:
gg = m.Var(1, integer=True, lb=1)
tb = m.Var(tu, integer=True, lb=1, ub=int(T / p))
L = m.Var(integer=True, lb=0, ub=l1)
m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
m.Equation(mb * k == M)
if gs != -1:
m.Equation(gs * gg == mb)
# m.Equation(tb * z == T)
m.Equation(tb * z == TP)
m.Equation(mu * w == mb)
m.Equation(tu * y == tb)
# m.Equation(tb * v == tt)
m.Maximize(L)
m.options.SOLVER = 1
m.solver_options = [
"minlp_maximum_iterations 1000", # minlp iterations with integer solution
"minlp_max_iter_with_int_sol 10", # treat minlp as nlp
"minlp_as_nlp 0", # nlp sub-problem max iterations
"nlp_maximum_iterations 100", # 1 = depth first, 2 = breadth first
"minlp_branch_method 2", # maximum deviation from whole number
"minlp_integer_tol 0.00", # covergence tolerance
"minlp_gap_tol 0.01",
]
try:
m.solve(disp=False)
except:
try:
m.solver_options = [
"minlp_maximum_iterations 1000", # minlp iterations with integer solution
"minlp_max_iter_with_int_sol 10", # treat minlp as nlp
"minlp_as_nlp 0", # nlp sub-problem max iterations
"nlp_maximum_iterations 100", # 1 = depth first, 2 = breadth first
"minlp_branch_method 1", # maximum deviation from whole number
"minlp_integer_tol 0.00", # covergence tolerance
"minlp_gap_tol 0.01",
]
m.solve(disp=False)
except:
# mytb = T//p
mytb = tu
if gs != -1:
mymb = gs
while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
mymb += gs
while M % mymb != 0:
mymb -= gs
if verbose:
print("Failed to solve, using heuristic. mb = ", mymb, "tb = ", mytb)
return (int(mymb), int(mytb))
else:
mymb = mu
while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
mymb += mu
while M % mymb != 0:
mymb -= mu
if verbose:
print("Failed to solve, using heuristic. mb = ", mymb, "tb = ", mytb)
return (int(mymb), int(mytb))
if verbose:
print("mb = ", int(mb.value[0]), "tb = ", int(tb.value[0]))
return (int(mb.value[0]), int(tb.value[0]))
def macros():
return "#include<omp.h>\n#include<cstdint>\n#include<immintrin.h>\n#include<fstream>\n\n#define mymin(a,b) ((a)<(b)?(a):(b))\n#define mymax(a,b) ((a)>(b)?(a):(b))\n"
def print_parameters(bits, n, m, t, nb, mb, tb, mu, nu, tu, unroll, p, gs=-1):
res = ""
res += "void print_parameters(){\n"
res += f' std::cout << {bits} << "bits," << {n} << "," << {m} << "," << {t} << "," << {nb} << "," << {mb} << "," << {tb} << "," << {nu} << "," << {mu} << "," << {tu} << "," << {unroll} << "," << {p} << "," << {gs} << ",";\n'
res += "}\n"
return res
def print_parameters_module(bits, mu, nu, tu, unroll, p, gs=-1):
res = ""
res += "void print_parameters(){\n"
res += "std::ofstream outfile;\n"
res += 'outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);\n'
res += f'outfile << {bits} << "," << {nu} << "," << {mu} << "," << {tu} << "," << {unroll} << "," << {p} << "," << {gs} << ",";\n'
res += "}\n"
return res
def pack_in(n, m, nb, mb):
res = ""
res += "inline void pack_input(float* A, float* B){\n"
res += " // copy the full matrix A in blocked format into B\n"
res += " uint64_t idx = 0;\n"
res += f" const int N = {n};\n"
res += f" const int M = {m};\n"
res += f" const int nb = {nb};\n"
res += f" const int mb = {mb};\n"
res += " for(int i = 0; i < N; i+=nb){ \n \
for(int j = 0; j < M; j+=mb){\n \
for(int jj = j; jj < mymin(j+mb, M); jj++){\n \
for(int ii = i; ii < mymin(i+nb, N); ii++){\n \
B[idx] = A[ii*M+jj];\n \
idx++;\n \
}\n \
}\n \
}\n \
}\n \
}\n"
return res
def pack_out(n, t, nb, tb):
res = ""
res += "inline void pack_output(float* A, float* B){\n"
res += " // copy the full matrix A in blocked format into B\n"
res += " uint64_t idx = 0;\n"
res += f" const int N = {n};\n"
res += f" const int M = {t};\n"
res += f" const int nb = {nb};\n"
res += f" const int mb = {tb};\n"
res += " for(int i = 0; i < N; i+=nb){ \n \
for(int j = 0; j < M; j+=mb){\n \
for(int ii = i; ii < mymin(i+nb, N); ii++){\n \
for(int jj = j; jj < mymin(j+mb, M); jj++){\n \
B[idx] = A[ii*M+jj];\n \
idx++;\n \
}\n \
}\n \
}\n \
}\n \
}\n"
return res
def pack_qw(m, t, mb, tb, tb1, bits=4, cutoff=-1):
packed = 32 // bits
res = ""
if cutoff == -1:
cutoff = 65
if bits == 3:
res += "inline void pack_qw_inner(int* A, int* B, int cutoff){\n"
res += " // copy the full matrix A in blocked format into B\n"
res += " uint64_t idx = 0;\n"
res += f" const int N = {m // 32 * 3};\n"
res += f" const int M = {t};\n"
res += f" const int nb = {mb // 32 * 3};\n"
res += f"int mb = {int(tb)};\n"
res += " for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n"
# res += "if(tid==cutoff){\n "
# res += f" mb = {tb1};\n"
# res += "}\n"
res += " for(int i = 0; i < N; i+=nb){\n \
for(int ii = i; ii < mymin(i+nb, N); ii+=3){\n \
for(int jj = j; jj < mymin(j+mb, M); jj+=8){\n \
for(int iii = ii; iii < ii + 3; iii++){\n \
for(int jjj = jj; jjj < jj + 8; jjj++){\n \
B[idx] = A[iii*M+jjj];\n \
idx++;\n \
}\n \
}\n \
}\n \
}\n \
}\n \
}\n \
}\n"
res += "inline void pack_qw(int* A, int* B){\n"
res += f" pack_qw_inner(A, B, {cutoff});\n"
res += "}\n"
return res
else:
# in case i do this for python i can just add the n,m,nb,mb as function parameters
res += "inline void pack_qw_inner(int* A, int* B, int cutoff){\n"
res += " // copy the full matrix A in blocked format into B\n"
res += " uint64_t idx = 0;\n"
res += f" const int N = {m // packed};\n"
res += f" const int M = {t};\n"
res += f" const int nb = {mb // packed};\n"
res += f"int mb = {int(tb)};\n"
res += " for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n"
# res += "if(tid==cutoff){\n "
# res += f" mb = {tb1};\n"
# res += "}\n"
res += " for(int i = 0; i < N; i+=nb){\n \
for(int ii = i; ii < mymin(i+nb, N); ii++){\n \
for(int jj = j; jj < mymin(j+mb, M); jj++){\n \
B[idx] = A[ii*M+jj];\n \
idx++;\n \
}\n \
}\n \
}\n"
res += "}\n"
res += "}\n"
res += "inline void pack_qw(int* A, int* B){\n"
res += f" pack_qw_inner(A, B, {cutoff});\n"
res += "}\n"
return res
def block_gs(nu_iter, mu, tu, rho, packed, unroll, bits):
res = ""
i = 0
# unroll = 4 # number of bcasts and unpacks
if bits == 3:
for j in range(0, tu, 8):
res += f"__m256i w0_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k3*tb/{packed}*3 + jw+{j*3}]);\n"
res += f"__m256i w1_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k3*tb/{packed}*3 + jw+{j*3}+8]);\n"
res += f"__m256i w2_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k3*tb/{packed}*3 + jw+{j*3}+16]);\n"
u = 0
first_off = 3
second_off = 2
wid = 0
shift = 0
while u < 32:
if u == 10:
res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{u})*nb + i1+{i}]);\n"
for j in range(0, tu, 8):
res += f"__m256i ws{j}_10 = _mm256_srli_epi32(w0_{j}, {bits*10});\n"
res += f"__m256i temp0_{j} = _mm256_slli_epi32(w1_{j}, 2);\n"
res += f"temp0_{j} = _mm256_and_si256(temp0_{j}, mask);\n"
res += f"ws{j}_10 = _mm256_or_si256(ws{j}_10, temp0_{j});\n"
res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n"
res += f"__m256 l{j}_{u} = _mm256_cvtepi32_ps(wsa{j}_{u});\n"
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n"
wid = wid + 1
u = u + 1
elif u == 21:
res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{u})*nb + i1+{i}]);\n"
for j in range(0, tu, 8):
res += f"__m256i ws{j}_{u} = _mm256_srli_epi32(w1_{j}, 31);\n"
res += f"__m256i temp1_{j} = _mm256_slli_epi32(w2_{j}, 1);\n"
res += f"temp1_{j} = _mm256_and_si256(temp1_{j}, mask);\n"
res += f"ws{j}_{u} = _mm256_or_si256(ws{j}_{u}, temp1_{j});\n"
res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n"
res += f"__m256 l{j}_{u} = _mm256_cvtepi32_ps(wsa{j}_{u});\n"
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n"
wid = wid + 1
u = u + 1
for k in range(u, u + second_off):
res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{k})*nb + i1+{i}]);\n"
for k in range(u, u + second_off):
for j in range(0, tu, 8):
res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{wid}_{j}, {bits*k-wid*32-shift});\n"
for j in range(0, tu, 8):
res += f"__m256i wsa{j}_{k} = _mm256_and_si256(ws{j}_{k}, mask);\n"
for j in range(0, tu, 8):
res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n"
for j in range(0, tu, 8):
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n"
u = u + 2
return res
else:
for j in range(0, tu, 8):
res += f"__m256i w{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed} + k*mb*tb/{packed} + k3*tb/{packed} + j1+{j}]);\n"
for u in range(packed - unroll, -1, -unroll):
for k in range(u + unroll - 1, u - 1, -1):
res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k3+{k})*nb + i1+{i}]);\n"
for k in range(u, u + unroll):
for j in range(0, tu, 8):
res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{j}, {bits*k});\n"
for j in range(0, tu, 8):
res += f"__m256i wsa{j}_{k}= _mm256_and_si256(ws{j}_{k}, mask);\n"
for j in range(0, tu, 8):
res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n"
for j in range(0, tu, 8):
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n"
return res
def block(nu_iter, mu, tu, rho, packed, unroll, bits):
res = ""
i = 0
# unroll = 4 # number of bcasts and unpacks
if bits == 3:
for j in range(0, tu, 8):
res += f"__m256i w0_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k2*tb/{packed}*3 + jw+{j*3}]);\n"
res += f"__m256i w1_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k2*tb/{packed}*3 + jw+{j*3}+8]);\n"
res += f"__m256i w2_{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed}*3 + k*mb*tb/{packed}*3 + k2*tb/{packed}*3 + jw+{j*3}+16]);\n"
u = 0
first_off = 3
second_off = 2
wid = 0
shift = 0
while u < 32:
if u == 10:
res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{u})*nb + i1+{i}]);\n"
for j in range(0, tu, 8):
res += f"__m256i ws{j}_10 = _mm256_srli_epi32(w0_{j}, {bits*10});\n"
res += f"__m256i temp0_{j} = _mm256_slli_epi32(w1_{j}, 2);\n"
res += f"temp0_{j} = _mm256_and_si256(temp0_{j}, mask);\n"
res += f"ws{j}_10 = _mm256_or_si256(ws{j}_10, temp0_{j});\n"
res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n"
res += f"__m256 l{j}_{u} = _mm256_cvtepi32_ps(wsa{j}_{u});\n"
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n"
wid = wid + 1
u = u + 1
elif u == 21:
res += f"__m256 v{i}_{u} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{u})*nb + i1+{i}]);\n"
for j in range(0, tu, 8):
res += f"__m256i ws{j}_{u} = _mm256_srli_epi32(w1_{j}, 31);\n"
res += f"__m256i temp1_{j} = _mm256_slli_epi32(w2_{j}, 1);\n"
res += f"temp1_{j} = _mm256_and_si256(temp1_{j}, mask);\n"
res += f"ws{j}_{u} = _mm256_or_si256(ws{j}_{u}, temp1_{j});\n"
res += f"__m256i wsa{j}_{u} = _mm256_and_si256(ws{j}_{u}, mask);\n"
res += f"__m256 l{j}_{u} = _mm256_cvtepi32_ps(wsa{j}_{u});\n"
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{u}, l{j}_{u}, acc{i}_{j});\n"
wid = wid + 1
u = u + 1
for k in range(u, u + second_off):
res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{k})*nb + i1+{i}]);\n"
for k in range(u, u + second_off):
for j in range(0, tu, 8):
res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{wid}_{j}, {bits*k-wid*32-shift});\n"
for j in range(0, tu, 8):
res += f"__m256i wsa{j}_{k} = _mm256_and_si256(ws{j}_{k}, mask);\n"
for j in range(0, tu, 8):
res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n"
for j in range(0, tu, 8):
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n"
u = u + 2
return res
else:
for j in range(0, tu, 8):
res += f"__m256i w{j} = _mm256_loadu_si256((__m256i*)&W[base_W + j*m/{packed} + k*mb*tb/{packed} + k2*tb/{packed} + j1+{j}]);\n"
for u in range(packed - unroll, -1, -unroll):
for k in range(u + unroll - 1, u - 1, -1):
res += f"__m256 v{i}_{k} = _mm256_set1_ps(input[(i*om+k)*mb*nb + (k2+{k})*nb + i1+{i}]);\n"
for k in range(u, u + unroll):
for j in range(0, tu, 8):
res += f"__m256i ws{j}_{k} = _mm256_srli_epi32(w{j}, {bits*k});\n"
for j in range(0, tu, 8):
res += f"__m256i wsa{j}_{k}= _mm256_and_si256(ws{j}_{k}, mask);\n"
for j in range(0, tu, 8):
res += f"__m256 l{j}_{k} = _mm256_cvtepi32_ps(wsa{j}_{k});\n"
for j in range(0, tu, 8):
res += f"acc{i}_{j} = _mm256_fmadd_ps(v{i}_{k}, l{j}_{k}, acc{i}_{j});\n"
return res
def accumulators_f(nu, tu, gs=False):
accumulators = ""
for i in range(nu):
for j in range(0, tu, 8):
if gs:
accumulators += f"__m256 acc{i}_{j} = _mm256_setzero_ps();\n"
else:
accumulators += (
f"__m256 acc{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n"
)
return accumulators
def stores_f(nu, tu, gs=False):
store = ""
if gs:
for i in range(nu):
for j in range(0, tu, 8):
store += f"__m256 o{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n"
for i in range(nu):
for j in range(0, tu, 8):
store += f"__m256 s{i}_{j} = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+{j}]);\n"
for i in range(nu):
for j in range(0, tu, 8):
store += f"__m256 f{i}_{j} = _mm256_fmadd_ps(acc{i}_{j}, s{i}_{j}, o{i}_{j});\n"
for i in range(nu):
for j in range(0, tu, 8):
store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], f{i}_{j});\n"
else:
for i in range(nu):
for j in range(0, tu, 8):
store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], acc{i}_{j});\n"
return store
def qforward(
nu,
mu,
tu,
p,
unroll,
bits,
n=0,
m=0,
t=0,
nb=0,
mb=0,
tb=0,
tt=0,
cutoff=-1,
gs=False,
gs_val=-1,
module=True,
):
assert module or (gs and gs_val != -1) or (not gs and gs_val == -1)
if cutoff == -1:
cutoff = p + 1
# packed = 32 // bits
if bits == 3:
packed = 32
loopguard = packed
else:
packed = 32 // bits
loopguard = packed
# compute the parameters from the model
accumulators = accumulators_f(nu, tu, gs)
store = stores_f(nu, tu, gs)
ugemm = ""
if gs:
ugemm += "int j1 = 0;\n"
if bits == 3:
ugemm += "int jw = 0;\n"
ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})"
ugemm += "{\n"
else:
ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n"
ugemm += "for(int k1 = 0; k1 < mb; k1+=gs) {\n"
ugemm += accumulators
ugemm += f"for(int k2 = k1; k2 < k1+gs; k2+={loopguard})\n"
ugemm += "{\n"
ugemm += block(nu, mu, tu, 16, packed, unroll, bits)
ugemm += "}\n"
ugemm += store
ugemm += "}\n"
ugemm += "}\n"
else:
ugemm += "int j1 = 0;\n"
if bits == 3:
ugemm += "int jw = 0;\n"
ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})"
ugemm += "{\n"
else:
ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n"
ugemm += accumulators
ugemm += "for(int k1 = 0; k1 < mb; k1+=mu) {\n"
ugemm += f"for(int k2 = k1; k2 < k1+mu; k2+={loopguard})"
ugemm += "{\n"
ugemm += block(nu, mu, tu, 16, packed, unroll, bits)
ugemm += "}\n"
ugemm += "}\n"
ugemm += store
ugemm += "}\n"
res = ""
res += "inline\n"
if gs:
res += f"void q{bits}gemm_gs(const float* __restrict__ input, \n"
else:
res += f"void q{bits}gemm(const float* __restrict__ input, \n"
res += "const int* __restrict__ W, \n"
res += "const float* __restrict__ scales, \n"
res += "const float* __restrict__ zeros, \n"
res += "const float* __restrict__ bias, \n "
res += "const float* __restrict__ sums, \n "
res += "float* __restrict__ output,\n\
const int n,\n\
const int m,\n\
const int t,\n\
const int nb,\n\
const int mb,\n\
const int tb,\n\
int ogtt,\n"
if gs:
res += "const int gs,\n"
res += "const int cutoff){\n"
res += f"#pragma omp parallel num_threads({p})\n"
res += "{\n"
res += "int tid;\n"
res += f"const int mu = {mu};\n"
res += f"const int nu = {nu};\n"
res += f"const int tu = {tu};\n"
res += "const int on = n / nb;\n"
res += "const int om = m / mb;\n"
mask = (2**bits) - 1
res += f"const __m256i mask = _mm256_set1_epi32({mask});\n"
if bits == 3:
res += "const __m256i mask4 = _mm256_set1_epi32(4);\n"
res += "const __m256i mask6 = _mm256_set1_epi32(6);\n"
res += "tid = omp_get_thread_num();\n"
res += "int tt = ogtt;\n"
res += "if(tid >= cutoff){\n"
res += "tt -= tb;\n"
res += "}\n"
res += "const int base_output = tid >= cutoff ?\n \
(tid-cutoff)*tt + (tt+tb)*cutoff: \n \
tid*tt;\n" # is this >= cutoff or > cutoff?
if bits != 3:
res += f"const int base_W = tid >= cutoff ?\n \
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}: \n \
tid*tt*m/{packed};\n"
else:
res += f"const int base_W = tid >= cutoff ?\n \
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}*3: \n \
tid*tt*m/{packed}*3;\n"
res += "for(int j = 0; j < tt; j+=tb){\n"
res += "for(int i = 0; i < on; i++) {\n"
res += "for(int k = 0; k < om; k++) {\n"
res += "for(int i1 = 0; i1 < nb; i1+=nu) {\n"
res += ugemm
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "#pragma omp barrier\n"
# res += "#pragma omp for\n"
if gs:
res += "const int ngs = m/gs;\n"
res += "for (int i = 0; i < n; i++) {\n"
res += f"for (int j = 0; j < tt; j+={tu})"
res += "{\n"
for i in range(0, tu, 8):
res += f"__m256 acc{i} = _mm256_setzero_ps();\n"
res += "for (int i1 = 0; i1 < ngs; i1++){\n"
res += "__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);\n"
for i in range(0, tu, 8):
res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + i1* t + j + {i}]);\n"
# if not module:
if bits != 3 or not module:
for i in range(0, tu, 8):
res += f"__m256 s{i} = _mm256_loadu_ps(&scales[base_output + i1 * t + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 zs{i} = _mm256_mul_ps(z{i}, s{i});\n"
for i in range(0, tu, 8):
# if module:
if bits == 3 and module:
res += f"acc{i} = _mm256_fmadd_ps(z{i}, r, acc{i});\n"
else:
res += f"acc{i} = _mm256_fmadd_ps(zs{i}, r, acc{i});\n"
res += "}\n"
for i in range(0, tu, 8):
res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 b{i} = _mm256_loadu_ps(&bias[base_output + j + {i}]);\n"
for i in range(0, tu, 8):
if module:
res += f"__m256 o1{i} = _mm256_sub_ps(o{i}, acc{i});\n"
else:
res += f"__m256 o1{i} = _mm256_add_ps(o{i}, acc{i});\n"
for i in range(0, tu, 8):
res += f"__m256 o2{i} = _mm256_add_ps(o1{i}, b{i});\n"
for i in range(0, tu, 8):
res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o2{i});\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
else:
res += "for (int i = 0; i < n; i++) {\n"
res += "__m256 r = _mm256_set1_ps(sums[i]);\n"
res += f"for (int j = 0; j < tt; j+={tu})"
res += "{\n"
for i in range(0, tu, 8):
res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 b{i} = _mm256_loadu_ps(&bias[base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 s{i} = _mm256_loadu_ps(&scales[base_output + j + {i}]);\n"
if bits == 3 and module:
for i in range(0, tu, 8):
res += f"__m256 os{i} = _mm256_mul_ps(o{i}, s{i});\n"
for i in range(0, tu, 8):
if module:
if bits == 3:
res += f"__m256 zr{i} = _mm256_fnmadd_ps(z{i}, r, os{i});\n"
else:
res += f"__m256 zr{i} = _mm256_fnmadd_ps(z{i}, r, o{i});\n"
else:
res += f"__m256 zr{i} = _mm256_fmadd_ps(z{i}, r, o{i});\n"
for i in range(0, tu, 8):
# j res += f"__m256 o2{i} = _mm256_mul_ps(zr{i}, s{i});\n"
if bits == 3 and module:
res += f"__m256 o2{i} = _mm256_add_ps(zr{i}, b{i});\n"
else:
res += f"__m256 o2{i} = _mm256_fmadd_ps(zr{i}, s{i}, b{i});\n"
for i in range(0, tu, 8):
res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o2{i});\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
# wrapper for qgemm if we call from cpp
if module:
if gs:
res += f"inline void forward{bits}_gs_cpu(\n"
else:
res += f"inline void forward{bits}_cpu(\n"
res += "torch::Tensor in, torch::Tensor weight, torch::Tensor out,\n"
res += "torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,\n"
if gs:
res += "int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){\n"
else:
res += "int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){\n"
res += "int* W = weight.data_ptr<int>();\n"
res += "float* input = in.data_ptr<float>();\n"
res += "float* b = bias.data_ptr<float>();\n"
res += "float* s = scales.data_ptr<float>();\n"
res += "float* z = zeros.data_ptr<float>();\n"
res += "float* r = sums.data_ptr<float>();\n"
res += "float* O = out.data_ptr<float>();\n"
res += "\n"
if gs:
res += f"q{bits}gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff);\n"
else:
res += f"q{bits}gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff);\n"
res += "}\n"
else:
res += "inline void qforward(const float* __restrict__ input, \n \
const int* __restrict__ W, \n\
const float* __restrict__ scales, \n\
const float* __restrict__ zeros, \n\
const float* __restrict__ bias, \n\
const float* __restrict__ sums, \n\
float* __restrict__ output, \n\
int n, \n \
int m, \n \
int t) {\n"
if gs:
res += f"q{bits}gemm_gs(input, W, scales, zeros, bias, sums, output, n, m, t, {nb}, {mb}, {tb}, {tt}, {gs_val}, {cutoff});\n"
else:
res += f"q{bits}gemm(input, W, scales, zeros, bias, sums, output, n, m, t, {nb}, {mb}, {tb}, {tt}, {cutoff});\n"
res += "}\n"
return res
def gen_model(n, m, t, bits, p, gs):
# get parameters
if bits == 3:
packed = 32
unroll = 3
nu = 1 # args.n
mu = 32
tu = 32
else:
packed = 32 // bits
unroll = 2
nu = 1 # args.n
mu = 16
tu = 32
# compute the parameters from the model
nb = n # it's always small for transformers
mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, gs)
split = np.ones(p)
split = split * tb
while np.sum(split) < t:
split = split + tb
idx = p - 1
while np.sum(split) > t:
split[idx] = split[idx] - tb
idx = idx - 1
assert np.sum(split) == t
split = split.astype(int)
tt = int(split[0])
if split[0] == split[-1]:
cutoff = int(p + 1)
else:
cutoff = int(idx + 1)
if gs == -1:
code = qforward(
nu,
mu,
tu,
p,
unroll,
n=n,
m=m,
t=t,
nb=nb,
mb=mb,
tb=tb,
tt=tt,
bits=bits,
cutoff=cutoff,
module=False,
)
else:
code = qforward(
nu,
mu,
tu,
p,
unroll,
n=n,
m=m,
t=t,
nb=nb,
mb=mb,
tb=tb,
tt=tt,
bits=bits,
cutoff=cutoff,
gs=True,
gs_val=gs,
module=False,
)
code += pack_in(n, m, nb, mb)
# code += pack_qw(m, t, mb, tb, tb, bits=bits)#, cutoff=cutoff)
code += pack_qw(m, t, mb, tb, tu, bits=bits)
code += pack_out(n, t, nb, tb)
code += print_parameters(bits, n, m, t, nb, mb, tb, mu, nu, tu, unroll, p)
with open("./autogptq_extension/qigen/forward.h", "w") as f:
f.write(macros())
f.write(code)
def gen_and_compile(n, m, t, nb, mb, tb, nu, mu, tu, p, unroll, bits=4, gs=-1, module=False):
split = np.ones(p)
split = split * tb
while np.sum(split) < t:
split = split + tb
idx = p - 1
while np.sum(split) > t:
split[idx] = split[idx] - tb
idx = idx - 1
assert np.sum(split) == t
split = split.astype(int)
tt = int(split[0])
if split[0] == split[-1]:
cutoff = int(p + 1)
else:
cutoff = int(idx + 1)
if gs == -1:
code = qforward(
nu,
mu,
tu,
p,
unroll,
n=n,
m=m,
t=t,
nb=nb,
mb=mb,
tb=tb,
tt=tt,
bits=bits,
cutoff=cutoff,
module=False,
)
else:
code = qforward(
nu,
mu,
tu,
p,
unroll,
n=n,
m=m,
t=t,
nb=nb,
mb=mb,
tb=tb,
tt=tt,
bits=bits,
cutoff=cutoff,
gs=True,
gs_val=gs,
module=False,
)
code += pack_in(n, m, nb, mb)
code += pack_qw(m, t, mb, tb, tu, bits=bits)
code += pack_out(n, t, nb, tb)
if module:
code += print_parameters_module(bits, mu, nu, tu, unroll, p, gs=gs)
else:
code += print_parameters(bits, n, m, t, nb, mb, tb, mu, nu, tu, unroll, p, gs=gs)
# write the code to a file called forward.h
with open("./autogptq_extension/qigen/forward.h", "w") as f:
f.write(macros())
f.write(code)
# g++ mmm_test.cpp -O3 -ftree-vectorize -mfma -mavx -mavx2 -fno-signaling-nans -fno-trapping-math -fopenmp -o mmm_test
start = time.time()
if not module:
subprocess.check_output(
[
"g++",
"-O3",
"-o",
"./autogptq_extension/qigen/mmm_test",
"./autogptq_extension/qigen/mmm_test.cpp",
"-mavx",
"-mfma",
"-mavx2",
"-ftree-vectorize",
"-fno-signaling-nans",
"-fno-trapping-math",
"-march=native",
"-fopenmp",
]
)
subprocess.check_output(
[
"./autogptq_extension/qigen/mmm_test",
f"{n}",
f"{m}",
f"{t}",
f"{bits}",
f"{gs}",
]
)
else:
subprocess.check_output(
[
"g++",
"-O3",
"-o",
"./autogptq_extension/qigen/mmm",
"./autogptq_extension/qigen/mmm.cpp",
"-mavx",
"-mfma",
"-mavx2",
"-ftree-vectorize",
"-fno-signaling-nans",
"-fno-trapping-math",
"-march=native",
"-fopenmp",
]
)
subprocess.check_output(
[
"./autogptq_extension/qigen/mmm",
f"{n}",
f"{m}",
f"{t}",
f"{bits}",
f"{gs}",
]
)
end = time.time() - start
return end
def grid():
tt = 64
for p in [32]:
# for n in [1, 10]:
for n in [1]:
for m in [4096]:
for t in [4096]:
# for mb in range(1,m):
# for mb in range(32,512,32):
# for mb in [64, 128, 256, 512, 1024, 2048]:
for mb in [512, 1024, 2048]:
if m % mb == 0:
# for tb in range(8,t,8):
# for tb in range(32,512,32):
# for tb in [16, 32, 64]:#, 128, 192, 256]:
# for tb in [32]:#, 128, 192, 256]:
for tb in [128, 256]:
if t % tb == 0:
# for mu in range(32,mb,32):
for mu in [16, 32]:
if mb % mu == 0:
# for tu in range(8,tb,8):
# for tu in [16, 32]:
for tu in [16, 32, 64, 128]:
if tb % tu == 0:
for gs in [-1, 128, 64, 32, 16]:
# for bits in [2, 3, 4]:
for bits in [4, 3, 2]:
if bits == 3:
for u in [5]:
gen_and_compile(
n,
m,
t,
n,
mb,
tb,
1,
mu,
tu,
p,
u,
bits=bits,
gs=gs,
)
else:
for u in [1, 2, 4, 8]:
gen_and_compile(
n,
m,
t,
n,
mb,
tb,
1,
mu,
tu,
p,
u,
bits=bits,
gs=gs,
)
def forward_module_gs(nu, mu, tu, p, unroll, bits):
# packed = 32 // bits
if bits == 3:
packed = 32
loopguard = packed
else:
packed = 32 // bits
loopguard = packed
# compute the parameters from the model
accumulators = ""
for i in range(nu):
for j in range(0, tu, 8):
accumulators += f"__m256 acc{i}_{j} = _mm256_setzero_ps();\n"
store = ""
for i in range(nu):
for j in range(0, tu, 8):
store += f"__m256 o{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n"
for i in range(nu):
for j in range(0, tu, 8):
store += f"__m256 s{i}_{j} = _mm256_loadu_ps(&scales[(k*mb+k1)/gs * t + base_output + j + j1+{j}]);\n"
for i in range(nu):
for j in range(0, tu, 8):
store += f"__m256 f{i}_{j} = _mm256_fmadd_ps(acc{i}_{j}, s{i}_{j}, o{i}_{j});\n"
for i in range(nu):
for j in range(0, tu, 8):
store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], f{i}_{j});\n"
ugemm = ""
if bits == 3:
ugemm += "int j1 = 0;\n"
ugemm += "int jw = 0;\n"
ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})"
ugemm += "{\n"
else:
ugemm += "int j1 = 0;\n"
ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n"
ugemm += "for(int k1 = 0; k1 < mb; k1+=gs) {\n"
ugemm += accumulators
ugemm += f"for(int k2 = k1; k2 < k1+gs; k2+={loopguard})\n"
ugemm += "{\n"
ugemm += block(nu, mu, tu, 16, packed, unroll, bits)
ugemm += "}\n"
ugemm += store
ugemm += "}\n"
ugemm += "}\n"
res = ""
res += "inline\n"
res += f"void q{bits}gemm_gs(const float* __restrict__ input, \n"
res += " const int* __restrict__ W, \n \
const float* __restrict__ scales, \n"
res += "const float* __restrict__ zeros, \n"
res += " const float* __restrict__ bias, \n "
res += " const float* __restrict__ sums,\n"
res += " float* __restrict__ output,\n \
const int n,\n \
const int m,\n \
const int t,\n \
const int nb,\n \
const int mb,\n \
const int tb,\n \
int ogtt,\n \
const int gs,\n\
const int cutoff){\n"
res += f"#pragma omp parallel num_threads({p})\n"
res += "{\n"
res += " int tid;\n"
res += f" const int mu = {mu};\n"
res += f" const int nu = {nu};\n"
res += f" const int tu = {tu};\n"
res += " const int on = n / nb;\n"
res += " const int om = m / mb;\n"
mask = (2**bits) - 1
res += f"const __m256i mask = _mm256_set1_epi32({mask});\n"
if bits == 3:
res += "const __m256i mask4 = _mm256_set1_epi32(4);\n"
res += "const __m256i mask6 = _mm256_set1_epi32(6);\n"
res += "tid = omp_get_thread_num();\n"
res += "int tt = ogtt;\n"
res += "if(tid >= cutoff){\n"
res += "tt -= tb;\n"
res += "}\n"
res += "const int base_output = tid >= cutoff ?\n \
(tid-cutoff)*tt + (tt+tb)*cutoff: \n \
tid*tt;\n" # is this >= cutoff or > cutoff?
if bits != 3:
res += f"const int base_W = tid >= cutoff ?\n \
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}: \n \
tid*tt*m/{packed};\n"
else:
res += f"const int base_W = tid >= cutoff ?\n \
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}*3: \n \
tid*tt*m/{packed}*3;\n"
res += "for(int j = 0; j < tt; j+=tb){\n"
res += "for(int i = 0; i < on; i++) {\n"
res += "for(int k = 0; k < om; k++) {\n"
res += "for(int i1 = 0; i1 < nb; i1+=nu) {\n"
res += ugemm
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "const int ngs = m/gs;\n"
res += "#pragma omp barrier\n"
# res += "#pragma omp for collapse(2)\n"
res += "for (int i = 0; i < n; i++) {\n"
# res += f" for (int j = 0; j < t; j+={tu})"
res += f"for (int j = 0; j < tt; j+={tu})"
res += "{\n"
# for i in range(0,tu,8):
# res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 acc{i} = _mm256_setzero_ps();\n"
res += "for (int i1 = 0; i1 < ngs; i1++){\n"
res += "__m256 r = _mm256_set1_ps(sums[i*ngs + i1]);\n"
for i in range(0, tu, 8):
# res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[i1 * t + j + {i}]);\n"
res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + i1* t + j + {i}]);\n"
# for i in range(0,tu,8):
# res += f"__m256 s{i} = _mm256_loadu_ps(&scales[i1 * t + j + {i}]);\n"
# for i in range(0,tu,8):
# res += f"__m256 zr{i} = _mm256_mul_ps(z{i}, r);\n"
# for i in range(0,tu,8):
# res += f"acc{i} = _mm256_fmadd_ps(zr{i}, s{i}, acc{i});\n"
for i in range(0, tu, 8):
res += f"acc{i} = _mm256_fmadd_ps(z{i}, r, acc{i});\n"
# for i in range(0,tu,8):
# res += f"__m256 zr{i} = _mm256_mul_ps(z{i}, r);\n"
# for i in range(0,tu,8):
# res += f"o{i} = _mm256_fnmadd_ps(zr{i}, s{i}, o{i});\n"
res += "}\n"
for i in range(0, tu, 8):
# res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n"
res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 o1{i} = _mm256_sub_ps(o{i}, acc{i});\n"
for i in range(0, tu, 8):
# res += f"_mm256_storeu_ps(&output[i*t + j + {i}], o1{i});\n"
res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o1{i});\n"
res += " }\n"
res += "}\n"
res += "}\n"
res += "}\n"
# wrapper for qgemm if we call from cpp
res += f"inline void forward{bits}_gs_cpu(\n"
res += "torch::Tensor in, torch::Tensor weight, torch::Tensor out,\n"
res += "torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,\n"
res += "int N, int M, int T, int nb, int mb, int tb, int tt, int groupsize, int cutoff){\n"
res += "int* W = weight.data_ptr<int>();\n"
res += "float* input = in.data_ptr<float>();\n"
res += "float* b = bias.data_ptr<float>();\n"
res += "float* s = scales.data_ptr<float>();\n"
# res += "int* z = zeros.data_ptr<int>();\n"
res += "float* z = zeros.data_ptr<float>();\n"
res += "float* r = sums.data_ptr<float>();\n"
res += "float* O = out.data_ptr<float>();\n"
res += "\n"
res += f"q{bits}gemm_gs(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, groupsize, cutoff);\n"
res += "}\n"
return res
def forward_module(nu, mu, tu, p, unroll, bits):
# packed = 32 // bits
if bits == 3:
packed = 32
loopguard = packed
else:
packed = 32 // bits
loopguard = packed
# compute the parameters from the model
accumulators = ""
for i in range(nu):
for j in range(0, tu, 8):
accumulators += f"__m256 acc{i}_{j} = _mm256_loadu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}]);\n"
store = ""
for i in range(nu):
for j in range(0, tu, 8):
store += f"_mm256_storeu_ps(&output[base_output + j + (i1+{i})*t + j1+{j}], acc{i}_{j});\n"
ugemm = ""
if bits == 3:
ugemm += "int jw = 0;\n"
ugemm += f"for(; j1 < tb-tu+1; j1+=tu, jw+={tu*3})"
ugemm += "{\n"
else:
ugemm += "for(; j1 < tb-tu+1; j1+=tu) {\n"
ugemm += accumulators
ugemm += "for(int k1 = 0; k1 < mb; k1+=mu) {\n"
ugemm += f"for(int k2 = k1; k2 < k1+mu; k2+={loopguard})"
ugemm += "{\n"
ugemm += block(nu, mu, tu, 16, packed, unroll, bits)
ugemm += "}\n"
ugemm += "}\n"
ugemm += store
ugemm += "}\n"
res = ""
res += "inline\n"
res += f"void q{bits}gemm(const float* __restrict__ input, \n"
res += "const int* __restrict__ W, \n"
res += "const float* __restrict__ scales, \n"
# res += "const int* __restrict__ zeros, \n"
res += "const float* __restrict__ zeros, \n"
res += "const float* __restrict__ bias, \n "
res += "const float* __restrict__ sums,"
res += "float* __restrict__ output,\n \
const int n,\n \
const int m,\n \
const int t,\n \
const int nb,\n \
const int mb,\n \
const int tb,\n \
int ogtt,\n \
const int cutoff){\n"
res += f"#pragma omp parallel num_threads({p})\n"
res += "{\n"
res += "int tid, nthreads;\n"
res += f"const int mu = {mu};\n"
res += f"const int nu = {nu};\n"
res += f"const int tu = {tu};\n"
res += "const int on = n / nb;\n"
res += "const int om = m / mb;\n"
mask = (2**bits) - 1
res += f"const __m256i mask = _mm256_set1_epi32({mask});\n"
if bits == 3:
res += "const __m256i mask4 = _mm256_set1_epi32(4);\n"
res += "const __m256i mask6 = _mm256_set1_epi32(6);\n"
res += "tid = omp_get_thread_num();\n"
# res += " std::cout << \"thread \" << tid << \" started\" << std::endl;\n"
res += "nthreads = omp_get_num_threads();\n"
res += "int tt = ogtt;\n"
res += "if(tid >= cutoff){\n"
res += "tt -= tb;\n"
res += "}\n"
res += "const int base_output = tid >= cutoff ?\n \
(tid-cutoff)*tt + (tt+tb)*cutoff: \n \
tid*tt;\n" # is this >= cutoff or > cutoff?
if bits != 3:
res += f"const int base_W = tid >= cutoff ?\n \
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}: \n \
tid*tt*m/{packed};\n"
else:
res += f"const int base_W = tid >= cutoff ?\n \
((tid-cutoff)*tt + (tt+tb)*cutoff)*m/{packed}*3: \n \
tid*tt*m/{packed}*3;\n"
res += "for(int j = 0; j < tt; j+=tb){\n"
res += "for(int i = 0; i < on; i++) {\n"
res += "for(int k = 0; k < om; k++) {\n"
res += "for(int i1 = 0; i1 < nb; i1+=nu) {\n"
res += "int j1 = 0;\n"
res += ugemm
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
# res += "#pragma omp barrier\n"
# res += "#pragma omp for\n"
res += "for (int i = 0; i < n; i++) {\n"
res += "__m256 r = _mm256_set1_ps(sums[i]);\n"
# res += f"for (int j = 0; j < t; j+={tu})"
res += f"for (int j = 0; j < tt; j+={tu})"
res += "{\n"
for i in range(0, tu, 8):
# res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + j + {i}]);\n"
res += f"__m256 o{i} = _mm256_loadu_ps(&output[i*t + base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 z{i} = _mm256_loadu_ps(&zeros[base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 s{i} = _mm256_loadu_ps(&scales[base_output + j + {i}]);\n"
for i in range(0, tu, 8):
res += f"__m256 zr{i} = _mm256_fnmadd_ps(z{i}, r, o{i});\n"
for i in range(0, tu, 8):
res += f"__m256 o2{i} = _mm256_mul_ps(zr{i}, s{i});\n"
for i in range(0, tu, 8):
res += f"_mm256_storeu_ps(&output[i*t + base_output + j + {i}], o2{i});\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
# wrapper for qgemm if we call from cpp
res += f"inline void forward{bits}_cpu(\n"
res += "torch::Tensor in, torch::Tensor weight, torch::Tensor out,\n"
res += "torch::Tensor bias, torch::Tensor scales, torch::Tensor zeros, torch::Tensor sums,\n"
res += "int N, int M, int T, int nb, int mb, int tb, int tt, int cutoff){\n"
res += "int* W = weight.data_ptr<int>();\n"
res += "float* input = in.data_ptr<float>();\n"
res += "float* b = bias.data_ptr<float>();\n"
res += "float* s = scales.data_ptr<float>();\n"
# res += "int* z = zeros.data_ptr<int>();\n"
res += "float* z = zeros.data_ptr<float>();\n"
res += "float* r = sums.data_ptr<float>();\n"
res += "float* O = out.data_ptr<float>();\n"
res += "\n"
res += f"q{bits}gemm(input, W, s, z, b, r, O, N, M, T, nb, mb, tb, tt, cutoff);\n"
res += "}\n"
return res
def unpack_zeros(bits):
res = ""
res += f"void unpack_zeros{bits}_cpu(const int* zv, float* ov, int n, int m)"
packed = 32 // bits
mask = (2**bits) - 1
res += "{\nconst __m256i ones = _mm256_set1_epi32(1);\n"
res += f"const __m256i mask = _mm256_set1_epi32({mask});\n"
if bits == 4:
res += "const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0);\n"
elif bits == 3:
pass
elif bits == 2:
res += "const __m256i shift0 = _mm256_set_epi32(30,28,26,24,22,20,18,16);\n"
res += "const __m256i shift1 = _mm256_set_epi32(14,12,10,8,6,4,2,0);\n"
else:
print("ERROR")
res += "for(int i = 0; i < n; i++){\n"
if bits == 4:
res += "for(int j = 0; j < m; j+=8){\n"
res += "__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]);\n"
res += "__m256i z0 = _mm256_srlv_epi32(z, shift);\n"
res += "__m256i z1 = _mm256_and_si256(z0, mask);\n"
res += "__m256i z2 = _mm256_add_epi32(z1, ones);\n"
res += "__m256 z3 = _mm256_cvtepi32_ps(z2);\n"
res += "_mm256_storeu_ps(&ov[i*m +j], z3);\n"
elif bits == 2:
res += f"for (int j = 0; j < m; j+={packed})"
res += "{\n"
res += f"for (int k = 0; k < {packed}; k++)"
res += "{\n"
res += f"ov[i*m + j+k] = (((zv[j/{packed}] >> ({bits}*k)) & {mask})+1);\n"
res += "}\n"
# res += "for(int j = 0; j < m; j+=16){\n"
# res += "__m256i z = _mm256_set1_epi32(zv[i*m/16 + j/16]);\n"
# res += "__m256i z00 = _mm256_srlv_epi32(z, shift0);\n"
# res += "__m256i z01 = _mm256_srlv_epi32(z, shift1);\n"
# res += "__m256i z10 = _mm256_and_si256(z00, mask);\n"
# res += "__m256i z11 = _mm256_and_si256(z01, mask);\n"
# res += "__m256i z20 = _mm256_add_epi32(z10, ones);\n"
# res += "__m256i z21 = _mm256_add_epi32(z11, ones);\n"
# res += "__m256 z30 = _mm256_cvtepi32_ps(z20);\n"
# res += "__m256 z31 = _mm256_cvtepi32_ps(z21);\n"
# res += "_mm256_storeu_ps(&ov[i*m +j], z30);\n"
# res += "_mm256_storeu_ps(&ov[i*m +j+8], z31);\n"
elif bits == 3:
# pass
res += "for(int j = 0; j < m; j+=32){\n"
res += 'std::cout<<"not yet implemented"<<std::endl;\n'
# res += "unsigned int z0 = zv[i*m+j/32*3];\n"
# res += "unsigned int z1 = zv[i*m+j/32*3+1];\n"
# res += "unsigned int z2 = zv[i*m+j/32*3+2];\n"
# for i in range(10):
# res += f"unsigned int z0{i} = ((z0 >> {29 - i*3}) & 7) + 1;\n"
# for i in range(10):
# res += f"ov[i*m + j + {i}] = z0{i} * sv[i*m + j + {i}];\n"
# res += "unsigned int t0 = ((z0<<1 & 6) | (z1>>31)) + 1;\n"
# res += "ov[i*m + j + 10] = t0 * sv[i*m + j + 10];\n"
# for i in range(10):
# res += f"unsigned int z1{i} = ((z1 >> {28 - i*3}) & 7) + 1;\n"
# for i in range(10):
# res += f"ov[i*m + j + {11 + i}] = z1{i} * sv[i*m + j + {11 + i}];\n"
# res += "unsigned int t1 = ((z1<<2 & 6) | (z2>>30)) + 1;\n"
# res += "ov[i*m + j + 21] = t1 * sv[i*m + j + 21];\n"
# for i in range(10):
# res += f"unsigned int z2{i} = ((z2 >> {27 - i*3}) & 7) + 1;\n"
# for i in range(10):
# res += f"ov[i*m + j + {22 + i}] = z2{i} * sv[i*m + j + {22 + i}];\n"
res += "}\n"
res += "}\n"
res += "}\n"
# write the pybind interface
res += f"void unpack_zeros{bits}(torch::Tensor zeros, torch::Tensor out, int N, int M)"
res += "{\nint* Z = zeros.data_ptr<int>();\n"
res += "float* O = out.data_ptr<float>();\n"
res += f"unpack_zeros{bits}_cpu(Z, O, N, M);\n"
res += "}\n"
return res
def gen_module(r, p, bits_list=[2, 3, 4]):
code = ""
for bits in bits_list:
if bits == 3:
unroll = 3
nu = 1 # args.n
mu = 32
tu = 32
else:
unroll = 2
nu = 1 # args.n
mu = 16
# mu = 32
tu = 32
code += qforward(nu, mu, tu, p, unroll, bits=bits, module=True, gs=False)
code += qforward(nu, mu, tu, p, unroll, bits=bits, module=True, gs=True)
code += pack_qw_module(bits)
code += unpack_zeros(bits)
with open("./autogptq_extension/qigen/backend.cpp", "w") as f:
f.write(template.includes())
f.write(template.quant_scalar())
f.write(compute_reduction(p))
f.write(unquantize_sim(p))
f.write(code)
f.write(template.module(bits_list))
def compute_reduction(p):
res = ""
res += "void compute_reduction_cpu(const float* in, float* out, int n, int m, int gs){\n"
res += f"#pragma omp parallel num_threads({p})\n"
res += "{\n"
res += "#pragma omp for collapse(2)\n"
res += "for(int i = 0; i < n; i++){\n"
res += "for(int j0 = 0; j0 < m; j0+=gs){\n"
res += "__m256 acc = _mm256_setzero_ps();\n"
res += "for(int j1 = j0; j1 < j0+gs; j1+=8){\n"
res += "__m256 x = _mm256_loadu_ps(&in[i*m + j1]);\n"
res += "acc = _mm256_add_ps(acc, x);\n"
res += "}\n"
# compute simd add reduction
res += "const __m128 hiQuad = _mm256_extractf128_ps(acc, 1);\n"
res += "const __m128 loQuad = _mm256_castps256_ps128(acc);\n"
res += "const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);\n"
res += "const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);\n"
res += "const __m128 sumDual = _mm_add_ps(sumQuad, hiDual);\n"
res += "const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);\n"
res += "const __m128 sum = _mm_add_ss(hi, sumDual);\n"
res += "out[(i*m + j0)/gs] = _mm_cvtss_f32(sum);\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
# write the pybind interface
res += "void compute_reduction(torch::Tensor in, torch::Tensor out, int N, int M, int gs)"
res += "{\nfloat* I = in.data_ptr<float>();\n"
res += "float* O = out.data_ptr<float>();\n"
res += "compute_reduction_cpu(I, O, N, M, gs);\n"
res += "}\n"
return res
def unquantize_sim(p):
res = ""
res += "void unquantize_sim_cpu(const int* in, float* out, float* s, float* z, int n, int m, int bits, int gs){\n"
res += f"#pragma omp parallel num_threads({p})\n"
res += "{\n"
res += "int packed = 32/bits;\n"
res += "int mask = (1<<bits) - 1;\n"
res += "#pragma omp for\n"
res += "for(int i0 = 0; i0 < n; i0+=gs){\n"
res += "int row = i0 / gs;\n"
res += "for(int i1 = i0; i1 < i0+gs; i1+=packed){\n"
res += "for(int j0 = 0; j0 < m; j0++){\n"
res += "for(int k = 0; k < packed; k++){\n"
res += "out[(i1+k)*m + j0] = ((float)((in[i1*m/packed + j0] >> (bits*k)) & mask) - z[(row)*m + j0]) * s[(row)*m + j0];\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
res += "}\n"
# write the pybind interface
res += "void unquantize_sim(torch::Tensor in, torch::Tensor out, torch::Tensor s, torch::Tensor z, int N, int M, int bits, int gs)"
res += "{\nint* I = in.data_ptr<int>();\n"
res += "float* O = out.data_ptr<float>();\n"
res += "float* S = s.data_ptr<float>();\n"
res += "float* Z = z.data_ptr<float>();\n"
res += "unquantize_sim_cpu(I, O, S, Z, N, M, bits, gs);\n"
res += "}\n"
return res
def pack_qw_module(bits):
packed = 32 // bits
res = ""
if bits == 3:
res += f"inline void pack{bits}_qw_inner(int* A, int* B, const int N, const int M, const int nb, const int mb, int cutoff)"
res += "{\n"
res += "// copy the full matrix A in blocked format into B\n"
res += "uint64_t idx = 0;\n"
# res += f" const {int(tb)};\n"
res += "for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n"
res += "for(int i = 0; i < N; i+=nb){\n \
for(int ii = i; ii < mymin(i+nb, N); ii+=3){\n \
for(int jj = j; jj < mymin(j+mb, M); jj+=8){\n \
for(int iii = ii; iii < ii + 3; iii++){\n \
for(int jjj = jj; jjj < jj + 8; jjj++){\n \
B[idx] = A[iii*M+jjj];\n \
idx++;\n \
}\n \
}\n \
}\n \
}\n \
}\n \
}\n \
}\n"
res += f"inline void pack{bits}_w_cpu(\n"
res += "torch::Tensor in, torch::Tensor out,\n"
res += "int N, int M, int nb, int mb, int cutoff){\n"
res += "int* input = in.data_ptr<int>();\n"
res += "int* O = out.data_ptr<int>();\n"
res += f"pack{bits}_qw_inner(input, O, N, M, nb, mb, cutoff);\n"
res += "}\n"
return res
else:
# in case i do this for python i can just add the n,m,nb,mb as function parameters
res += f"inline void pack{bits}_qw_inner(int* A, int* B, const int N, const int M, const int nb, int mb, int cutoff)"
res += "{\n"
res += "// copy the full matrix A in blocked format into B\n"
res += "uint64_t idx = 0;\n"
res += "for(int j = 0, tid = 0; j < M; j+=mb, tid++){\n"
res += "for(int i = 0; i < N; i+=nb){\n \
for(int ii = i; ii < mymin(i+nb, N); ii++){\n \
for(int jj = j; jj < mymin(j+mb, M); jj++){\n \
B[idx] = A[ii*M+jj];\n \
idx++;\n \
}\n \
}\n \
}\n"
res += "}\n"
res += "}\n"
res += f"inline void pack{bits}_w_cpu(\n"
res += "torch::Tensor in, torch::Tensor out,\n"
res += "int N, int M, int nb, int mb, int cutoff){\n"
res += "int* input = in.data_ptr<int>();\n"
res += "int* O = out.data_ptr<int>();\n"
res += f" pack{bits}_qw_inner(input, O, N, M, nb, mb, cutoff);\n"
res += "}\n"
return res
def gen_module_search(r, p, bits_list=[2, 3, 4]):
# print measurements to a tmp file and read back best micro parameters
code = ""
# Opening in 'w' mode overwrites tmp.csv.
with open("./autogptq_extension/qigen/tmp.csv", "w") as f:
f.write("bits,nu,mu,tu,unroll,p,gs,time\n")
n, m, t, nb, mb, tb = 1, 4096, 4096, 1, 1024, 32
for mu in [16]:
for tu in [16, 32, 64]:
if tb % tu == 0:
for gs in [-1, 64]:
for bits in [4, 3, 2]:
if bits == 3:
for u in [5]:
print(
n,
m,
t,
n,
mb,
tb,
1,
mu,
tu,
p,
u,
bits,
gs,
end="\r",
flush=True,
)
gen_and_compile(
n,
m,
t,
n,
mb,
tb,
1,
mu,
tu,
p,
u,
bits=bits,
gs=gs,
module=True,
)
else:
for u in [1, 2, 4, 8]:
print(
n,
m,
t,
n,
mb,
tb,
1,
mu,
tu,
p,
u,
bits,
gs,
end="\r",
flush=True,
)
gen_and_compile(
n,
m,
t,
n,
mb,
tb,
1,
mu,
tu,
p,
u,
bits=bits,
gs=gs,
module=True,
)
df = pd.read_csv("./autogptq_extension/qigen/tmp.csv")
for bits in bits_list:
bits_df = df[df["bits"] == bits]
bits_nogs = bits_df[bits_df["gs"] == -1]
best = bits_nogs[bits_nogs["time"] == bits_nogs["time"].min()]
nu = int(best["nu"].values[0])
mu = int(best["mu"].values[0])
tu = int(best["tu"].values[0])
unroll = int(best["unroll"].values[0])
code += qforward(nu, mu, tu, p, unroll, bits=bits, module=True, gs=False)
bits_gs = bits_df[bits_df["gs"] != -1]
best = bits_gs[bits_gs["time"] == bits_gs["time"].min()]
nu_gs = int(best["nu"].values[0])
mu_gs = int(best["mu"].values[0])
tu_gs = int(best["tu"].values[0])
unroll_gs = int(best["unroll"].values[0])
code += qforward(nu_gs, mu_gs, tu_gs, p, unroll_gs, bits=bits, module=True, gs=True)
code += pack_qw_module(bits)
code += unpack_zeros(bits)
with open("./autogptq_extension/qigen/backend.cpp", "w") as f:
f.write(template.includes())
f.write(template.quant_scalar())
f.write(compute_reduction(p))
f.write(unquantize_sim(p))
f.write(code)
f.write(template.module(bits_list))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--t", type=int, default=1024)
parser.add_argument("--nb", type=int, default=128)
parser.add_argument("--mb", type=int, default=128)
parser.add_argument("--tb", type=int, default=128)
parser.add_argument("--mu", type=int, default=4)
parser.add_argument("--nu", type=int, default=4)
parser.add_argument("--tu", type=int, default=8)
parser.add_argument("--bits", type=int, default=4)
parser.add_argument("--module", action="store_true")
parser.add_argument("--search", action="store_true")
parser.add_argument("--model", action="store_true")
parser.add_argument("--r", type=int, default=16)
parser.add_argument("--p", type=int, default=8)
parser.add_argument("--gs", type=int, default=-1)
args = parser.parse_args()
if args.module and args.search:
gen_module_search(args.r, args.p, [2, 3, 4])
if args.module and not args.search:
gen_module(args.r, args.p, [2, 3, 4])
if args.search and not args.module:
grid()
if args.model:
gen_model(args.n, args.m, args.t, args.bits, args.p, args.gs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment