Commit eb8e460c authored by nicodafagood

Rename the myq quantization method to mygq across the CUDA kernels, C++/pybind bindings, build config, and Python quantization layers

parent 23fdbb68
......@@ -92,7 +92,7 @@ if __name__ == '__main__':
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
-choices=['awq', 'gptq','myq', 'squeezellm', None],
+choices=['awq', 'gptq','mygq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
......
......@@ -258,7 +258,7 @@ if __name__ == "__main__":
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
-choices=['awq', 'gptq','myq', 'squeezellm', None],
+choices=['awq', 'gptq','mygq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
......
......@@ -115,16 +115,16 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);
-torch::Tensor myq_gemm(
+torch::Tensor mygq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
-torch::Tensor b_myq_qzeros,
-torch::Tensor b_myq_scales,
+torch::Tensor b_mygq_qzeros,
+torch::Tensor b_mygq_scales,
torch::Tensor b_g_idx,
bool use_exllama,
int bit);
-void myq_shuffle(
+void mygq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm,
int bit);
......
......@@ -61,8 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("myq_gemm", &myq_gemm, "Quantized GEMM for myq");
ops.def("myq_shuffle", &myq_shuffle, "Post processing for GPTQ");
ops.def("mygq_gemm", &mygq_gemm, "Quantized GEMM for mygq");
ops.def("mygq_shuffle", &mygq_shuffle, "Post processing for mygq");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"moe_align_block_size",
......
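Not part of the commit: a minimal sketch of how the renamed bindings above would be invoked from Python, mirroring the MYQLinearMethod call sites further down in this diff. The import path, tensor shapes, and placeholder values are assumptions for illustration only.

```python
# Illustrative only: exercising the renamed mygq_shuffle / mygq_gemm bindings.
# Assumes the extension builds as vllm._C with an ops submodule, as the Python
# layer in this diff does; shapes below are arbitrary 4-bit, group-size-128 values.
import torch
from vllm._C import ops  # assumed extension layout

bit = 4
k, n, groups = 4096, 4096, 32
x = torch.randn(8, k, dtype=torch.float16, device="cuda")
qweight = torch.zeros(k * bit // 32, n, dtype=torch.int32, device="cuda")    # packed weights
qzeros = torch.zeros(groups, n * bit // 32, dtype=torch.int32, device="cuda")
scales = torch.ones(groups, n, dtype=torch.float16, device="cuda")
g_idx = torch.empty((1, 1), device="meta")  # meta tensor means no activation reordering

ops.mygq_shuffle(qweight, g_idx, bit)                       # one-time weight post-processing
out = ops.mygq_gemm(x, qweight, qzeros, scales, g_idx, True, bit)
```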
......@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
#define _compat_cuh
namespace vllm {
-namespace myq {
+namespace mygq {
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
......@@ -59,6 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd
#endif
#endif
-} // namespace myq
+} // namespace mygq
} // namespace vllm
#endif
......@@ -11,7 +11,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
#include "qdq_util.cuh"
namespace vllm {
-namespace myq {
+namespace mygq {
class MatrixView_half
{
......@@ -269,6 +269,6 @@ public:
}
};
-} // namespace myq
+} // namespace mygq
} // namespace vllm
#endif
......@@ -21,7 +21,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
#include "qdq_8.cuh"
namespace vllm {
-namespace myq {
+namespace mygq {
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
......@@ -181,7 +181,7 @@ __forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, co
}
-typedef void (*fp_gemm_half_q_half_myq_kernel)
+typedef void (*fp_gemm_half_q_half_mygq_kernel)
(
const half*,
const uint32_t*,
......@@ -197,12 +197,12 @@ typedef void (*fp_gemm_half_q_half_myq_kernel)
template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_4bit_kernel
+__global__ void gemm_half_q_half_mygq_4bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
......@@ -213,8 +213,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
-MatrixView_q4_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q4_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int t = threadIdx.x;
......@@ -274,8 +274,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
float scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_f(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
......@@ -292,8 +292,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_f(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
......@@ -307,10 +307,10 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
int4 load_int4 = *b_ptr4;
half2 dq[4][4];
-dequant_4bit_8_myq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
-dequant_4bit_8_myq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
-dequant_4bit_8_myq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
-dequant_4bit_8_myq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+dequant_4bit_8_mygq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+dequant_4bit_8_mygq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+dequant_4bit_8_mygq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+dequant_4bit_8_mygq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
#pragma unroll
for (int m = 0; m < m_count; m++)
......@@ -339,12 +339,12 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
}
template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_2bit_kernel
+__global__ void gemm_half_q_half_mygq_2bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
......@@ -355,8 +355,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
-MatrixView_q2_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q2_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int t = threadIdx.x;
......@@ -414,8 +414,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
// Initial group
int zeros[4];
half scales[4];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
......@@ -427,8 +427,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4(scales, group, n);
}
#pragma unroll
......@@ -470,12 +470,12 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
}
template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_3bit_kernel
+__global__ void gemm_half_q_half_mygq_3bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
......@@ -486,8 +486,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
-MatrixView_q3_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q3_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int t = threadIdx.x;
......@@ -545,8 +545,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
// Initial group
int zeros[4];
half scales[4];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
......@@ -558,8 +558,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4(scales, group, n);
}
#pragma unroll
......@@ -601,12 +601,12 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
}
template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_8bit_kernel
+__global__ void gemm_half_q_half_mygq_8bit_kernel
(
const half* __restrict__ a,
const uint32_t* __restrict__ b_q_weight,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
half* __restrict__ c,
const int size_m,
const int size_n,
......@@ -617,8 +617,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
{
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
-MatrixView_q8_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q8_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int t = threadIdx.x;
......@@ -676,8 +676,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
// Initial group
int zeros[4];
half scales[4];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4(scales, group, n);
// Column result
half block_c[m_count][4] = {};
......@@ -689,8 +689,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4(scales, group, n);
}
#pragma unroll
......@@ -728,15 +728,15 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
}
}
-fp_gemm_half_q_half_myq_kernel pick_gemm_half_q_half_myq_kernel(
+fp_gemm_half_q_half_mygq_kernel pick_gemm_half_q_half_mygq_kernel(
bool first_block, const int m_count, const int bit)
{
#define SELECT_KERNEL(M_COUNT) \
if (m_count == M_COUNT) { \
-if (bit == 2) return gemm_half_q_half_myq_2bit_kernel<true, M_COUNT>; \
-if (bit == 3) return gemm_half_q_half_myq_3bit_kernel<true, M_COUNT>; \
-if (bit == 4) return gemm_half_q_half_myq_4bit_kernel<true, M_COUNT>; \
-if (bit == 8) return gemm_half_q_half_myq_8bit_kernel<true, M_COUNT>; \
+if (bit == 2) return gemm_half_q_half_mygq_2bit_kernel<true, M_COUNT>; \
+if (bit == 3) return gemm_half_q_half_mygq_3bit_kernel<true, M_COUNT>; \
+if (bit == 4) return gemm_half_q_half_mygq_4bit_kernel<true, M_COUNT>; \
+if (bit == 8) return gemm_half_q_half_mygq_8bit_kernel<true, M_COUNT>; \
}
#if BLOCK_M_SIZE_MAX >= 1
SELECT_KERNEL(1);
......@@ -770,8 +770,8 @@ void gemm_half_q_half_cuda_part
(
const half* a,
const uint32_t* b_q_weight,
-const uint32_t* b_myq_qzeros,
-const half* b_myq_scales,
+const uint32_t* b_mygq_qzeros,
+const half* b_mygq_scales,
const int* b_q_perm,
half* c,
int size_m,
......@@ -790,15 +790,15 @@ void gemm_half_q_half_cuda_part
gridDim.y = DIVIDE(size_m, m_count);
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
-fp_gemm_half_q_half_myq_kernel kernel = pick_gemm_half_q_half_myq_kernel(true, m_count, bit);
+fp_gemm_half_q_half_mygq_kernel kernel = pick_gemm_half_q_half_mygq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b_q_weight,
-b_myq_qzeros,
-b_myq_scales,
+b_mygq_qzeros,
+b_mygq_scales,
c,
size_m,
size_n,
......@@ -813,8 +813,8 @@ __global__ void reconstruct_exllama_8bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
const int size_k,
const int size_n,
const int groups,
......@@ -822,8 +822,8 @@ __global__ void reconstruct_exllama_8bit_kernel
)
{
MatrixView_half_rw b_(b, size_k, size_n);
-MatrixView_q8_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q8_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -857,8 +857,8 @@ __global__ void reconstruct_exllama_8bit_kernel
// Initial zeros/scale
int zeros[4];
half2 scales[4];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
__syncthreads();
......@@ -871,8 +871,8 @@ __global__ void reconstruct_exllama_8bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 4; p++)
......@@ -915,8 +915,8 @@ __global__ void reconstruct_exllama_4bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
const int size_k,
const int size_n,
const int groups,
......@@ -924,8 +924,8 @@ __global__ void reconstruct_exllama_4bit_kernel
)
{
MatrixView_half_rw b_(b, size_k, size_n);
-MatrixView_q4_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q4_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -961,8 +961,8 @@ __global__ void reconstruct_exllama_4bit_kernel
half2 scales[4];
half2 z1z16[4][2];
half2 y1y16[4][2];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
......@@ -979,8 +979,8 @@ __global__ void reconstruct_exllama_4bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
......@@ -993,10 +993,10 @@ __global__ void reconstruct_exllama_4bit_kernel
const int4* b_ptr4 = (int4*) b_ptr;
int4 load_int4 = *b_ptr4;
-dequant_4bit_8_myq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
-dequant_4bit_8_myq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
-dequant_4bit_8_myq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
-dequant_4bit_8_myq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+dequant_4bit_8_mygq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+dequant_4bit_8_mygq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+dequant_4bit_8_mygq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+dequant_4bit_8_mygq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
b_ptr += size_n;
//half* dqh = (half*)dq;
......@@ -1027,8 +1027,8 @@ __global__ void reconstruct_exllama_3bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
const int size_k,
const int size_n,
const int groups,
......@@ -1036,8 +1036,8 @@ __global__ void reconstruct_exllama_3bit_kernel
)
{
MatrixView_half_rw b_(b, size_k, size_n);
-MatrixView_q3_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q3_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -1071,8 +1071,8 @@ __global__ void reconstruct_exllama_3bit_kernel
// Initial zeros/scale
int zeros[4];
half2 scales[4];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
__syncthreads();
......@@ -1085,8 +1085,8 @@ __global__ void reconstruct_exllama_3bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 1; p++)
......@@ -1129,8 +1129,8 @@ __global__ void reconstruct_exllama_2bit_kernel
(
const uint32_t* __restrict__ b_q_weight,
const int* __restrict__ b_q_perm,
-const uint32_t* __restrict__ b_myq_qzeros,
-const half* __restrict__ b_myq_scales,
+const uint32_t* __restrict__ b_mygq_qzeros,
+const half* __restrict__ b_mygq_scales,
const int size_k,
const int size_n,
const int groups,
......@@ -1138,8 +1138,8 @@ __global__ void reconstruct_exllama_2bit_kernel
)
{
MatrixView_half_rw b_(b, size_k, size_n);
-MatrixView_q2_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+MatrixView_q2_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
......@@ -1173,8 +1173,8 @@ __global__ void reconstruct_exllama_2bit_kernel
// Initial zeros/scale
int zeros[4];
half2 scales[4];
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
__syncthreads();
......@@ -1187,8 +1187,8 @@ __global__ void reconstruct_exllama_2bit_kernel
{
group++;
nextgroup += groupsize;
-b_myq_qzeros_.item4(zeros, group, n);
-b_myq_scales_.item4_h2(scales, group, n);
+b_mygq_qzeros_.item4(zeros, group, n);
+b_mygq_scales_.item4_h2(scales, group, n);
}
for (int p = 0; p < 2; p++)
......@@ -1230,8 +1230,8 @@ __global__ void reconstruct_exllama_2bit_kernel
void reconstruct_exllama
(
const uint32_t* b_q_weight,
-const uint32_t* b_myq_qzeros,
-const half* b_myq_scales,
+const uint32_t* b_mygq_qzeros,
+const half* b_mygq_scales,
const int* b_q_perm,
half* out,
int height,
......@@ -1260,8 +1260,8 @@ void reconstruct_exllama
(
b_q_weight,
b_q_perm,
-b_myq_qzeros,
-b_myq_scales,
+b_mygq_qzeros,
+b_mygq_scales,
height,
width,
groups,
......@@ -1461,8 +1461,8 @@ void gemm_half_q_half_alt
(
const half* a,
const uint32_t* b_q_weight,
-const uint32_t* b_myq_qzeros,
-const half* b_myq_scales,
+const uint32_t* b_mygq_qzeros,
+const half* b_mygq_scales,
const int* b_g_idx,
half* c,
int size_m,
......@@ -1490,8 +1490,8 @@ void gemm_half_q_half_alt
(const half2*) a,
b_q_weight,
c,
-b_myq_scales,
-b_myq_qzeros,
+b_mygq_scales,
+b_mygq_qzeros,
b_g_idx,
size_m,
size_k / 32 * bit,
......@@ -1500,7 +1500,7 @@ void gemm_half_q_half_alt
}
template<class T, int bit>
-__global__ void reconstruct_myq_kernel
+__global__ void reconstruct_mygq_kernel
(
const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
......@@ -1538,7 +1538,7 @@ __global__ void reconstruct_myq_kernel
}
}
-__global__ void reconstruct_myq_3bit_kernel
+__global__ void reconstruct_mygq_3bit_kernel
(
const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
......@@ -1589,11 +1589,11 @@ __global__ void reconstruct_myq_3bit_kernel
}
}
-void reconstruct_myq
+void reconstruct_mygq
(
const uint32_t* b_q_weight,
-const uint32_t* b_myq_qzeros,
-const half* b_myq_scales,
+const uint32_t* b_mygq_qzeros,
+const half* b_mygq_scales,
const int* b_g_idx,
half* out,
int height,
......@@ -1608,13 +1608,13 @@ void reconstruct_myq
gridDim.y = DIVIDE(height, 32 / bit);
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-auto kernel = reconstruct_myq_kernel<MatrixView_q4_row, 4>;
+auto kernel = reconstruct_mygq_kernel<MatrixView_q4_row, 4>;
if (bit == 2) {
-kernel = reconstruct_myq_kernel<MatrixView_q2_row, 2>;
+kernel = reconstruct_mygq_kernel<MatrixView_q2_row, 2>;
} else if (bit == 8) {
-kernel = reconstruct_myq_kernel<MatrixView_q8_row, 8>;
+kernel = reconstruct_mygq_kernel<MatrixView_q8_row, 8>;
} else if (bit == 3) {
-kernel = reconstruct_myq_3bit_kernel;
+kernel = reconstruct_mygq_3bit_kernel;
gridDim.y = DIVIDE(height, 32);
}
......@@ -1622,8 +1622,8 @@ void reconstruct_myq
kernel<<<gridDim, blockDim, 0, stream>>>
(
b_q_weight,
-b_myq_scales,
-b_myq_qzeros,
+b_mygq_scales,
+b_mygq_qzeros,
b_g_idx,
height,
width,
......@@ -1638,8 +1638,8 @@ void gemm_half_q_half_cuda
cublasHandle_t cublas_handle,
const half* a,
const uint32_t* b_q_weight,
-const uint32_t* b_myq_qzeros,
-const half* b_myq_scales,
+const uint32_t* b_mygq_qzeros,
+const half* b_mygq_scales,
const int* b_g_idx,
half* c,
half* temp_dq,
......@@ -1661,12 +1661,12 @@ void gemm_half_q_half_cuda
if (use_reconstruct) {
// Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) {
-reconstruct_exllama(b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx, temp_dq,
+reconstruct_exllama(b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx, temp_dq,
size_k, size_n, groups, bit);
}
else
{
-reconstruct_myq(b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx,
+reconstruct_mygq(b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit);
}
......@@ -1689,22 +1689,22 @@ void gemm_half_q_half_cuda
if (max_chunks)
{
-gemm_half_q_half_cuda_part(a, b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx,
+gemm_half_q_half_cuda_part(a, b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx,
c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
groups, bit);
}
if (last_chunk_size)
{
-gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_myq_qzeros,
-b_myq_scales, b_g_idx, c + last_chunk * size_n,
+gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_mygq_qzeros,
+b_mygq_scales, b_g_idx, c + last_chunk * size_n,
last_chunk_size, size_n, size_k, last_chunk_size,
groups, bit);
}
}
else
{
-gemm_half_q_half_alt(a, b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx,
+gemm_half_q_half_alt(a, b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx,
c, size_m, size_n, size_k, bit);
}
}
......@@ -2020,15 +2020,15 @@ void shuffle_exllama_weight
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
}
-} // namespace myq
+} // namespace mygq
} // namespace vllm
-torch::Tensor myq_gemm
+torch::Tensor mygq_gemm
(
torch::Tensor a,
torch::Tensor b_q_weight,
-torch::Tensor b_myq_qzeros,
-torch::Tensor b_myq_scales,
+torch::Tensor b_mygq_qzeros,
+torch::Tensor b_mygq_scales,
torch::Tensor b_g_idx,
bool use_exllama,
int bit
......@@ -2039,27 +2039,27 @@ torch::Tensor myq_gemm
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
-vllm::myq::gemm_half_q_half_cuda
+vllm::mygq::gemm_half_q_half_cuda
(
at::cuda::getCurrentCUDABlasHandle(),
(const half*) a.data_ptr(),
(const uint32_t*) b_q_weight.data_ptr(),
-(const uint32_t*)b_myq_qzeros.data_ptr(),
-(const half*) b_myq_scales.data_ptr(),
+(const uint32_t*)b_mygq_qzeros.data_ptr(),
+(const half*) b_mygq_scales.data_ptr(),
b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
(half*) c.data_ptr(),
(half*) temp_dq.data_ptr(),
c.size(0), // m
c.size(1), // n
a.size(1), // k
-b_myq_qzeros.size(0), // group number
+b_mygq_qzeros.size(0), // group number
use_exllama,
bit
);
return c;
}
-void myq_shuffle
+void mygq_shuffle
(
torch::Tensor q_weight,
torch::Tensor q_perm,
......@@ -2067,7 +2067,7 @@ void myq_shuffle
)
{
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
-vllm::myq::shuffle_exllama_weight(
+vllm::mygq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 32 / bit,
......
......@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
#include "qdq_util.cuh"
namespace vllm {
-namespace myq {
+namespace mygq {
// Permutation:
//
......@@ -81,7 +81,7 @@ __forceinline__ __device__ void dequant_2bit_16
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
-} // namespace myq
+} // namespace mygq
} // namespace vllm
#endif
......@@ -4,7 +4,7 @@
#include "qdq_util.cuh"
namespace vllm {
-namespace myq {
+namespace mygq {
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
......@@ -135,7 +135,7 @@ __forceinline__ __device__ void dequant_3bit_32
dq[15] = __hadd2(q15.as_half2, z1);
}
-} // namespace myq
+} // namespace mygq
} // namespace vllm
#endif
......@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
#include "qdq_util.cuh"
namespace vllm {
-namespace myq {
+namespace mygq {
// Permutation:
//
// 77775555 33331111 66664444 22220000
......@@ -107,7 +107,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
}
-__forceinline__ __device__ void dequant_4bit_8_myq
+__forceinline__ __device__ void dequant_4bit_8_mygq
(
const uint32_t q_0,
half2 (&dq)[4],
......@@ -141,7 +141,7 @@ __forceinline__ __device__ void dequant_4bit_8_myq
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
-} // namespace myq
+} // namespace mygq
} // namespace vllm
#endif
......@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
#include "qdq_util.cuh"
namespace vllm {
-namespace myq {
+namespace mygq {
__forceinline__ __device__ void shuffle_8bit_4
(
......@@ -34,7 +34,7 @@ __forceinline__ __device__ void dequant_8bit_8
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
-} // namespace myq
+} // namespace mygq
} // namespace vllm
#endif
......@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
#define _qdq_util_cuh
namespace vllm {
-namespace myq {
+namespace mygq {
union half2_uint32
{
......@@ -55,6 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
-} // namespace myq
+} // namespace mygq
} // namespace vllm
#endif
......@@ -339,7 +339,7 @@ vllm_extension_sources = [
"csrc/layernorm_kernels.cu",
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
"csrc/quantization/gptq/q_gemm.cu",
"csrc/quantization/myq/q_gemm.cu",
"csrc/quantization/mygq/q_gemm.cu",
"csrc/cuda_utils_kernels.cu",
"csrc/moe_align_block_size_kernels.cu",
"csrc/pybind.cpp",
......
......@@ -155,7 +155,7 @@ class ModelConfig:
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
-supported_quantization = ["awq", "gptq", "squeezellm", "marlin","myq"]
+supported_quantization = ["awq", "gptq", "squeezellm", "marlin","mygq"]
rocm_not_supported_quantization = ["awq", "marlin"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
......
......@@ -208,7 +208,7 @@ class EngineArgs:
parser.add_argument('--quantization',
'-q',
type=str,
-choices=['awq', 'gptq', 'squeezellm','myq', None],
+choices=['awq', 'gptq', 'squeezellm','mygq', None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
......
......@@ -3,13 +3,13 @@ from typing import Type
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.myq import MYQConfig
+from vllm.model_executor.layers.quantization.mygq import MYQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
_QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig,
"myq": MYQConfig,
"mygq": MYQConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
......
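Aside (not part of the commit): the registry key above and MYQConfig.get_name() in the next hunk have to agree, which is why both change here. A hedged sketch of the lookup path; the helper name and error message are assumptions.

```python
# Sketch only; vllm's actual helper may differ in name and error text.
def get_quantization_config(quantization: str):
    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return _QUANTIZATION_CONFIG_REGISTRY[quantization]

assert get_quantization_config("mygq") is MYQConfig
```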
......@@ -41,7 +41,7 @@ class MYQConfig(QuantizationConfig):
@classmethod
def get_name(cls) -> str:
return "myq"
return "mygq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
......@@ -201,9 +201,9 @@ class MYQLinearMethod(LinearMethodBase):
else:
weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY
-ops.myq_shuffle(weights["qweight"], weights["g_idx"],
+ops.mygq_shuffle(weights["qweight"], weights["g_idx"],
self.quant_config.weight_bits)
-output = ops.myq_gemm(reshaped_x, weights["qweight"],
+output = ops.mygq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"],
weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY,
......
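For completeness (not part of the commit): after the rename, "mygq" is the name the engine expects end to end, from the --quantization CLI choice through ModelConfig validation to the config registry. A hedged usage sketch; the model path is a placeholder and a checkpoint quantized with this method is assumed.

```python
from vllm import LLM, SamplingParams

# "path/to/mygq-model" is a placeholder; quantization="mygq" selects MYQConfig via the registry.
llm = LLM(model="path/to/mygq-model", quantization="mygq")
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```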