"src/include/ConstantMatrixDescriptor.hpp" did not exist on "766b0a9eafe29a5d2a75c350345e54165ceaf405"
Commit eb8e460c authored by nicodafagood

update mygq

parent 23fdbb68
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq','myq', 'squeezellm', None],
+                        choices=['awq', 'gptq','mygq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
...
@@ -258,7 +258,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq','myq', 'squeezellm', None],
+                        choices=['awq', 'gptq','mygq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
...
@@ -115,16 +115,16 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
 
-torch::Tensor myq_gemm(
+torch::Tensor mygq_gemm(
   torch::Tensor a,
   torch::Tensor b_q_weight,
-  torch::Tensor b_myq_qzeros,
-  torch::Tensor b_myq_scales,
+  torch::Tensor b_mygq_qzeros,
+  torch::Tensor b_mygq_scales,
   torch::Tensor b_g_idx,
   bool use_exllama,
   int bit);
 
-void myq_shuffle(
+void mygq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm,
   int bit);
...
@@ -61,8 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
-  ops.def("myq_gemm", &myq_gemm, "Quantized GEMM for myq");
-  ops.def("myq_shuffle", &myq_shuffle, "Post processing for GPTQ");
+  ops.def("mygq_gemm", &mygq_gemm, "Quantized GEMM for mygq");
+  ops.def("mygq_shuffle", &mygq_shuffle, "Post processing for mygq");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def(
     "moe_align_block_size",
...
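A quick way to check that the rename actually landed in the compiled extension is to import the ops table and probe for the new symbols. This is a minimal sketch, assuming the extension builds as vllm._C with an ops submodule as in upstream vLLM:

# Hedged sketch: confirm the renamed bindings exist after a rebuild.
from vllm._C import ops

for name in ("mygq_gemm", "mygq_shuffle"):
    assert hasattr(ops, name), f"missing binding: {name}"
assert not hasattr(ops, "myq_gemm")  # the old name should be gone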
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
 #define _compat_cuh
 namespace vllm {
-namespace myq {
+namespace mygq {
 // atomicAdd for half types, to support CC < 7.x
 __device__ __forceinline__ void atomicAdd_half(half* address, half val)
@@ -59,6 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd
 #endif
 #endif
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -11,7 +11,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 class MatrixView_half
 {
@@ -269,6 +269,6 @@ public:
     }
 };
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -21,7 +21,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
 #include "qdq_8.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 #define BLOCK_KN_SIZE 128
 #define BLOCK_M_SIZE_MAX 8
@@ -181,7 +181,7 @@ __forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, co
 }
-typedef void (*fp_gemm_half_q_half_myq_kernel)
+typedef void (*fp_gemm_half_q_half_mygq_kernel)
 (
     const half*,
     const uint32_t*,
@@ -197,12 +197,12 @@ typedef void (*fp_gemm_half_q_half_myq_kernel)
 template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_4bit_kernel
+__global__ void gemm_half_q_half_mygq_4bit_kernel
 (
     const half* __restrict__ a,
     const uint32_t* __restrict__ b_q_weight,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     half* __restrict__ c,
     const int size_m,
     const int size_n,
@@ -213,8 +213,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
 {
     MatrixView_half a_(a, size_m, size_k);
     MatrixView_half_rw c_(c, size_m, size_n);
-    MatrixView_q4_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q4_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int t = threadIdx.x;
@@ -274,8 +274,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
     float scales[4];
     half2 z1z16[4][2];
     half2 y1y16[4][2];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4_f(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4_f(scales, group, n);
     dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
     dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
     dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@@ -292,8 +292,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4_f(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4_f(scales, group, n);
         dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
         dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
         dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@@ -307,10 +307,10 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
         int4 load_int4 = *b_ptr4;
         half2 dq[4][4];
-        dequant_4bit_8_myq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
-        dequant_4bit_8_myq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
-        dequant_4bit_8_myq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
-        dequant_4bit_8_myq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+        dequant_4bit_8_mygq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+        dequant_4bit_8_mygq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+        dequant_4bit_8_mygq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+        dequant_4bit_8_mygq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
 #pragma unroll
         for (int m = 0; m < m_count; m++)
@@ -339,12 +339,12 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
 }
 template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_2bit_kernel
+__global__ void gemm_half_q_half_mygq_2bit_kernel
 (
     const half* __restrict__ a,
     const uint32_t* __restrict__ b_q_weight,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     half* __restrict__ c,
     const int size_m,
     const int size_n,
@@ -355,8 +355,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
 {
     MatrixView_half a_(a, size_m, size_k);
     MatrixView_half_rw c_(c, size_m, size_n);
-    MatrixView_q2_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q2_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int t = threadIdx.x;
@@ -414,8 +414,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
     // Initial group
     int zeros[4];
     half scales[4];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4(scales, group, n);
     // Column result
     half block_c[m_count][4] = {};
@@ -427,8 +427,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4(scales, group, n);
     }
 #pragma unroll
@@ -470,12 +470,12 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
 }
 template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_3bit_kernel
+__global__ void gemm_half_q_half_mygq_3bit_kernel
 (
     const half* __restrict__ a,
     const uint32_t* __restrict__ b_q_weight,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     half* __restrict__ c,
     const int size_m,
     const int size_n,
@@ -486,8 +486,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
 {
     MatrixView_half a_(a, size_m, size_k);
     MatrixView_half_rw c_(c, size_m, size_n);
-    MatrixView_q3_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q3_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int t = threadIdx.x;
@@ -545,8 +545,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
     // Initial group
     int zeros[4];
     half scales[4];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4(scales, group, n);
     // Column result
     half block_c[m_count][4] = {};
@@ -558,8 +558,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4(scales, group, n);
     }
 #pragma unroll
@@ -601,12 +601,12 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
 }
 template <bool first_block, int m_count>
-__global__ void gemm_half_q_half_myq_8bit_kernel
+__global__ void gemm_half_q_half_mygq_8bit_kernel
 (
     const half* __restrict__ a,
     const uint32_t* __restrict__ b_q_weight,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     half* __restrict__ c,
     const int size_m,
     const int size_n,
@@ -617,8 +617,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
 {
     MatrixView_half a_(a, size_m, size_k);
     MatrixView_half_rw c_(c, size_m, size_n);
-    MatrixView_q8_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q8_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int t = threadIdx.x;
@@ -676,8 +676,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
     // Initial group
     int zeros[4];
     half scales[4];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4(scales, group, n);
     // Column result
     half block_c[m_count][4] = {};
@@ -689,8 +689,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4(scales, group, n);
     }
 #pragma unroll
@@ -728,15 +728,15 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
 }
 }
-fp_gemm_half_q_half_myq_kernel pick_gemm_half_q_half_myq_kernel(
+fp_gemm_half_q_half_mygq_kernel pick_gemm_half_q_half_mygq_kernel(
     bool first_block, const int m_count, const int bit)
 {
 #define SELECT_KERNEL(M_COUNT) \
     if (m_count == M_COUNT) { \
-        if (bit == 2) return gemm_half_q_half_myq_2bit_kernel<true, M_COUNT>; \
-        if (bit == 3) return gemm_half_q_half_myq_3bit_kernel<true, M_COUNT>; \
-        if (bit == 4) return gemm_half_q_half_myq_4bit_kernel<true, M_COUNT>; \
-        if (bit == 8) return gemm_half_q_half_myq_8bit_kernel<true, M_COUNT>; \
+        if (bit == 2) return gemm_half_q_half_mygq_2bit_kernel<true, M_COUNT>; \
+        if (bit == 3) return gemm_half_q_half_mygq_3bit_kernel<true, M_COUNT>; \
+        if (bit == 4) return gemm_half_q_half_mygq_4bit_kernel<true, M_COUNT>; \
+        if (bit == 8) return gemm_half_q_half_mygq_8bit_kernel<true, M_COUNT>; \
     }
 #if BLOCK_M_SIZE_MAX >= 1
     SELECT_KERNEL(1);
@@ -770,8 +770,8 @@ void gemm_half_q_half_cuda_part
 (
     const half* a,
     const uint32_t* b_q_weight,
-    const uint32_t* b_myq_qzeros,
-    const half* b_myq_scales,
+    const uint32_t* b_mygq_qzeros,
+    const half* b_mygq_scales,
     const int* b_q_perm,
     half* c,
     int size_m,
@@ -790,15 +790,15 @@ void gemm_half_q_half_cuda_part
     gridDim.y = DIVIDE(size_m, m_count);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
-    fp_gemm_half_q_half_myq_kernel kernel = pick_gemm_half_q_half_myq_kernel(true, m_count, bit);
+    fp_gemm_half_q_half_mygq_kernel kernel = pick_gemm_half_q_half_mygq_kernel(true, m_count, bit);
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
-        b_myq_qzeros,
-        b_myq_scales,
+        b_mygq_qzeros,
+        b_mygq_scales,
         c,
         size_m,
         size_n,
@@ -813,8 +813,8 @@ __global__ void reconstruct_exllama_8bit_kernel
 (
     const uint32_t* __restrict__ b_q_weight,
     const int* __restrict__ b_q_perm,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     const int size_k,
     const int size_n,
     const int groups,
@@ -822,8 +822,8 @@ __global__ void reconstruct_exllama_8bit_kernel
 )
 {
     MatrixView_half_rw b_(b, size_k, size_n);
-    MatrixView_q8_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q8_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int offset_k = BLOCK_KN_SIZE * blockIdx.y;
     int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -857,8 +857,8 @@ __global__ void reconstruct_exllama_8bit_kernel
     // Initial zeros/scale
     int zeros[4];
    half2 scales[4];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4_h2(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4_h2(scales, group, n);
     __syncthreads();
@@ -871,8 +871,8 @@ __global__ void reconstruct_exllama_8bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4_h2(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4_h2(scales, group, n);
     }
     for (int p = 0; p < 4; p++)
@@ -915,8 +915,8 @@ __global__ void reconstruct_exllama_4bit_kernel
 (
     const uint32_t* __restrict__ b_q_weight,
     const int* __restrict__ b_q_perm,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     const int size_k,
     const int size_n,
     const int groups,
@@ -924,8 +924,8 @@ __global__ void reconstruct_exllama_4bit_kernel
 )
 {
     MatrixView_half_rw b_(b, size_k, size_n);
-    MatrixView_q4_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q4_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int offset_k = BLOCK_KN_SIZE * blockIdx.y;
     int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -961,8 +961,8 @@ __global__ void reconstruct_exllama_4bit_kernel
     half2 scales[4];
     half2 z1z16[4][2];
     half2 y1y16[4][2];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4_h2(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4_h2(scales, group, n);
     dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
     dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
     dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@@ -979,8 +979,8 @@ __global__ void reconstruct_exllama_4bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4_h2(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4_h2(scales, group, n);
         dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
         dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
         dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
@@ -993,10 +993,10 @@ __global__ void reconstruct_exllama_4bit_kernel
         const int4* b_ptr4 = (int4*) b_ptr;
         int4 load_int4 = *b_ptr4;
-        dequant_4bit_8_myq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
-        dequant_4bit_8_myq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
-        dequant_4bit_8_myq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
-        dequant_4bit_8_myq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
+        dequant_4bit_8_mygq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
+        dequant_4bit_8_mygq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
+        dequant_4bit_8_mygq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
+        dequant_4bit_8_mygq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
         b_ptr += size_n;
         //half* dqh = (half*)dq;
@@ -1027,8 +1027,8 @@ __global__ void reconstruct_exllama_3bit_kernel
 (
     const uint32_t* __restrict__ b_q_weight,
     const int* __restrict__ b_q_perm,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     const int size_k,
     const int size_n,
     const int groups,
@@ -1036,8 +1036,8 @@ __global__ void reconstruct_exllama_3bit_kernel
 )
 {
     MatrixView_half_rw b_(b, size_k, size_n);
-    MatrixView_q3_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q3_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int offset_k = BLOCK_KN_SIZE * blockIdx.y;
     int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -1071,8 +1071,8 @@ __global__ void reconstruct_exllama_3bit_kernel
     // Initial zeros/scale
     int zeros[4];
     half2 scales[4];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4_h2(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4_h2(scales, group, n);
     __syncthreads();
@@ -1085,8 +1085,8 @@ __global__ void reconstruct_exllama_3bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4_h2(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4_h2(scales, group, n);
     }
     for (int p = 0; p < 1; p++)
@@ -1129,8 +1129,8 @@ __global__ void reconstruct_exllama_2bit_kernel
 (
     const uint32_t* __restrict__ b_q_weight,
     const int* __restrict__ b_q_perm,
-    const uint32_t* __restrict__ b_myq_qzeros,
-    const half* __restrict__ b_myq_scales,
+    const uint32_t* __restrict__ b_mygq_qzeros,
+    const half* __restrict__ b_mygq_scales,
     const int size_k,
     const int size_n,
     const int groups,
@@ -1138,8 +1138,8 @@ __global__ void reconstruct_exllama_2bit_kernel
 )
 {
     MatrixView_half_rw b_(b, size_k, size_n);
-    MatrixView_q2_row b_myq_qzeros_(b_myq_qzeros, groups, size_n);
-    MatrixView_half b_myq_scales_(b_myq_scales, groups, size_n);
+    MatrixView_q2_row b_mygq_qzeros_(b_mygq_qzeros, groups, size_n);
+    MatrixView_half b_mygq_scales_(b_mygq_scales, groups, size_n);
     int offset_k = BLOCK_KN_SIZE * blockIdx.y;
     int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -1173,8 +1173,8 @@ __global__ void reconstruct_exllama_2bit_kernel
     // Initial zeros/scale
     int zeros[4];
     half2 scales[4];
-    b_myq_qzeros_.item4(zeros, group, n);
-    b_myq_scales_.item4_h2(scales, group, n);
+    b_mygq_qzeros_.item4(zeros, group, n);
+    b_mygq_scales_.item4_h2(scales, group, n);
     __syncthreads();
@@ -1187,8 +1187,8 @@ __global__ void reconstruct_exllama_2bit_kernel
     {
         group++;
         nextgroup += groupsize;
-        b_myq_qzeros_.item4(zeros, group, n);
-        b_myq_scales_.item4_h2(scales, group, n);
+        b_mygq_qzeros_.item4(zeros, group, n);
+        b_mygq_scales_.item4_h2(scales, group, n);
     }
     for (int p = 0; p < 2; p++)
@@ -1230,8 +1230,8 @@ __global__ void reconstruct_exllama_2bit_kernel
 void reconstruct_exllama
 (
     const uint32_t* b_q_weight,
-    const uint32_t* b_myq_qzeros,
-    const half* b_myq_scales,
+    const uint32_t* b_mygq_qzeros,
+    const half* b_mygq_scales,
     const int* b_q_perm,
     half* out,
     int height,
@@ -1260,8 +1260,8 @@ void reconstruct_exllama
     (
         b_q_weight,
         b_q_perm,
-        b_myq_qzeros,
-        b_myq_scales,
+        b_mygq_qzeros,
+        b_mygq_scales,
         height,
         width,
         groups,
@@ -1461,8 +1461,8 @@ void gemm_half_q_half_alt
 (
     const half* a,
     const uint32_t* b_q_weight,
-    const uint32_t* b_myq_qzeros,
-    const half* b_myq_scales,
+    const uint32_t* b_mygq_qzeros,
+    const half* b_mygq_scales,
     const int* b_g_idx,
     half* c,
     int size_m,
@@ -1490,8 +1490,8 @@ void gemm_half_q_half_alt
         (const half2*) a,
         b_q_weight,
         c,
-        b_myq_scales,
-        b_myq_qzeros,
+        b_mygq_scales,
+        b_mygq_qzeros,
         b_g_idx,
         size_m,
         size_k / 32 * bit,
@@ -1500,7 +1500,7 @@ void gemm_half_q_half_alt
 }
 template<class T, int bit>
-__global__ void reconstruct_myq_kernel
+__global__ void reconstruct_mygq_kernel
 (
     const uint32_t* __restrict__ w,
     const half* __restrict__ w_scales,
@@ -1538,7 +1538,7 @@ __global__ void reconstruct_myq_kernel
     }
 }
-__global__ void reconstruct_myq_3bit_kernel
+__global__ void reconstruct_mygq_3bit_kernel
 (
     const uint32_t* __restrict__ w,
     const half* __restrict__ w_scales,
@@ -1589,11 +1589,11 @@ __global__ void reconstruct_myq_3bit_kernel
     }
 }
-void reconstruct_myq
+void reconstruct_mygq
 (
     const uint32_t* b_q_weight,
-    const uint32_t* b_myq_qzeros,
-    const half* b_myq_scales,
+    const uint32_t* b_mygq_qzeros,
+    const half* b_mygq_scales,
     const int* b_g_idx,
     half* out,
     int height,
@@ -1608,13 +1608,13 @@ void reconstruct_myq
     gridDim.y = DIVIDE(height, 32 / bit);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    auto kernel = reconstruct_myq_kernel<MatrixView_q4_row, 4>;
+    auto kernel = reconstruct_mygq_kernel<MatrixView_q4_row, 4>;
     if (bit == 2) {
-        kernel = reconstruct_myq_kernel<MatrixView_q2_row, 2>;
+        kernel = reconstruct_mygq_kernel<MatrixView_q2_row, 2>;
     } else if (bit == 8) {
-        kernel = reconstruct_myq_kernel<MatrixView_q8_row, 8>;
+        kernel = reconstruct_mygq_kernel<MatrixView_q8_row, 8>;
     } else if (bit == 3) {
-        kernel = reconstruct_myq_3bit_kernel;
+        kernel = reconstruct_mygq_3bit_kernel;
         gridDim.y = DIVIDE(height, 32);
     }
@@ -1622,8 +1622,8 @@ void reconstruct_myq
     kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
-        b_myq_scales,
-        b_myq_qzeros,
+        b_mygq_scales,
+        b_mygq_qzeros,
         b_g_idx,
         height,
         width,
@@ -1638,8 +1638,8 @@ void gemm_half_q_half_cuda
     cublasHandle_t cublas_handle,
     const half* a,
     const uint32_t* b_q_weight,
-    const uint32_t* b_myq_qzeros,
-    const half* b_myq_scales,
+    const uint32_t* b_mygq_qzeros,
+    const half* b_mygq_scales,
     const int* b_g_idx,
     half* c,
     half* temp_dq,
@@ -1661,12 +1661,12 @@ void gemm_half_q_half_cuda
     if (use_reconstruct) {
         // Reconstruct FP16 matrix, then cuBLAS
         if (use_exllama) {
-            reconstruct_exllama(b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx, temp_dq,
+            reconstruct_exllama(b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx, temp_dq,
                                 size_k, size_n, groups, bit);
         }
         else
         {
-            reconstruct_myq(b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx,
+            reconstruct_mygq(b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx,
                             temp_dq, size_k, size_n, groups, bit);
         }
@@ -1689,22 +1689,22 @@
         if (max_chunks)
         {
-            gemm_half_q_half_cuda_part(a, b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx,
+            gemm_half_q_half_cuda_part(a, b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx,
                                        c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
                                        groups, bit);
         }
         if (last_chunk_size)
        {
-            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_myq_qzeros,
-                                       b_myq_scales, b_g_idx, c + last_chunk * size_n,
+            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_mygq_qzeros,
+                                       b_mygq_scales, b_g_idx, c + last_chunk * size_n,
                                        last_chunk_size, size_n, size_k, last_chunk_size,
                                        groups, bit);
         }
     }
     else
     {
-        gemm_half_q_half_alt(a, b_q_weight, b_myq_qzeros, b_myq_scales, b_g_idx,
+        gemm_half_q_half_alt(a, b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx,
                              c, size_m, size_n, size_k, bit);
     }
 }
@@ -2020,15 +2020,15 @@ void shuffle_exllama_weight
     shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
-torch::Tensor myq_gemm
+torch::Tensor mygq_gemm
 (
     torch::Tensor a,
     torch::Tensor b_q_weight,
-    torch::Tensor b_myq_qzeros,
-    torch::Tensor b_myq_scales,
+    torch::Tensor b_mygq_qzeros,
+    torch::Tensor b_mygq_scales,
     torch::Tensor b_g_idx,
     bool use_exllama,
     int bit
@@ -2039,27 +2039,27 @@ torch::Tensor myq_gemm
     at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
     at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
-    vllm::myq::gemm_half_q_half_cuda
+    vllm::mygq::gemm_half_q_half_cuda
     (
         at::cuda::getCurrentCUDABlasHandle(),
         (const half*) a.data_ptr(),
         (const uint32_t*) b_q_weight.data_ptr(),
-        (const uint32_t*)b_myq_qzeros.data_ptr(),
-        (const half*) b_myq_scales.data_ptr(),
+        (const uint32_t*)b_mygq_qzeros.data_ptr(),
+        (const half*) b_mygq_scales.data_ptr(),
         b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
         (half*) c.data_ptr(),
         (half*) temp_dq.data_ptr(),
         c.size(0), // m
         c.size(1), // n
         a.size(1), // k
-        b_myq_qzeros.size(0), // group number
+        b_mygq_qzeros.size(0), // group number
         use_exllama,
         bit
     );
     return c;
 }
-void myq_shuffle
+void mygq_shuffle
 (
     torch::Tensor q_weight,
     torch::Tensor q_perm,
@@ -2067,7 +2067,7 @@ void myq_shuffle
 )
 {
     const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
-    vllm::myq::shuffle_exllama_weight(
+    vllm::mygq::shuffle_exllama_weight(
         (uint32_t*) q_weight.data_ptr(),
         q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
         q_weight.size(0) * 32 / bit,
...
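One detail worth keeping in mind from the host wrapper above: temp_dq is sized to hold the fully dequantized weight matrix, since each packed int32 row of qweight expands to 32/bit fp16 rows. A small sketch of that arithmetic, with illustrative numbers that are not from the commit:

# Hedged sketch: mirrors the temp_dq sizing in mygq_gemm above.
def temp_dq_shape(qweight_rows: int, qweight_cols: int, bit: int) -> tuple:
    # Each packed int32 row holds 32 // bit quantized values per column.
    assert bit in (2, 3, 4, 8)
    return (qweight_rows * 32 // bit, qweight_cols)

print(temp_dq_shape(512, 4096, 4))  # (4096, 4096): 512 packed rows -> 4096 fp16 rows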
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
@@ -81,7 +81,7 @@ __forceinline__ __device__ void dequant_2bit_16
     dq[7] = __hfma2(q7.as_half2, y64, z64);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -4,7 +4,7 @@
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
 // v9997775 55333111 u8886664 44222000 (u, v lsb)
@@ -135,7 +135,7 @@ __forceinline__ __device__ void dequant_3bit_32
     dq[15] = __hadd2(q15.as_half2, z1);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
 // 77775555 33331111 66664444 22220000
@@ -107,7 +107,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
 }
-__forceinline__ __device__ void dequant_4bit_8_myq
+__forceinline__ __device__ void dequant_4bit_8_mygq
 (
     const uint32_t q_0,
     half2 (&dq)[4],
@@ -141,7 +141,7 @@ __forceinline__ __device__ void dequant_4bit_8_myq
         dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
     }
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 __forceinline__ __device__ void shuffle_8bit_4
 (
@@ -34,7 +34,7 @@ __forceinline__ __device__ void dequant_8bit_8
     for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
 #define _qdq_util_cuh
 namespace vllm {
-namespace myq {
+namespace mygq {
 union half2_uint32
 {
@@ -55,6 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i
     return (int)(__funnelshift_rc(q0, q1, shift) & mask);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -339,7 +339,7 @@ vllm_extension_sources = [
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     "csrc/quantization/gptq/q_gemm.cu",
-    "csrc/quantization/myq/q_gemm.cu",
+    "csrc/quantization/mygq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/moe_align_block_size_kernels.cu",
     "csrc/pybind.cpp",
...
@@ -155,7 +155,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm", "marlin","myq"]
+        supported_quantization = ["awq", "gptq", "squeezellm", "marlin","mygq"]
         rocm_not_supported_quantization = ["awq", "marlin"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
...
@@ -208,7 +208,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'gptq', 'squeezellm','myq', None],
+                            choices=['awq', 'gptq', 'squeezellm','mygq', None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
                             'None, we first check the `quantization_config` '
...
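With the engine flag renamed, the method can also be requested through the Python entry point. A minimal sketch, assuming a checkpoint quantized in this format exists; the model path below is a placeholder, not a real repository:

# Hedged sketch: select the renamed method via the public API.
from vllm import LLM

llm = LLM(model="my-org/model-mygq-4bit", quantization="mygq")
outputs = llm.generate("Hello, world!")
print(outputs[0].outputs[0].text)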
@@ -3,13 +3,13 @@ from typing import Type
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.myq import MYQConfig
+from vllm.model_executor.layers.quantization.mygq import MYQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
-    "myq": MYQConfig,
+    "mygq": MYQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
     "marlin": MarlinConfig,
...
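The registry is what makes the new string resolvable wherever a config class is looked up; get_name() on the class has to agree with the key, which is exactly what the next file changes. A sketch of the lookup, assuming direct access to the registry dict defined above:

# Hedged sketch: how the renamed key resolves to its config class.
from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

config_cls = _QUANTIZATION_CONFIG_REGISTRY["mygq"]  # -> MYQConfig
assert config_cls.get_name() == "mygq"  # matches the get_name() change below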
@@ -41,7 +41,7 @@ class MYQConfig(QuantizationConfig):
     @classmethod
     def get_name(cls) -> str:
-        return "myq"
+        return "mygq"
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -201,9 +201,9 @@ class MYQLinearMethod(LinearMethodBase):
             else:
                 weights["g_idx"] = torch.empty((1, 1), device="meta")
             weights["exllama_state"] = ExllamaState.READY
-            ops.myq_shuffle(weights["qweight"], weights["g_idx"],
+            ops.mygq_shuffle(weights["qweight"], weights["g_idx"],
                             self.quant_config.weight_bits)
-        output = ops.myq_gemm(reshaped_x, weights["qweight"],
+        output = ops.mygq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY,
...
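Taken together, the renamed linear method boils down to a one-time exllama repack followed by a quantized GEMM on each forward pass. A self-contained sketch at the ops level, assuming GPTQ-style packing and purely illustrative shapes (m=8, k=n=4096, 4-bit, 32 groups):

# Hedged sketch: exercises ops.mygq_shuffle / ops.mygq_gemm directly.
import torch
from vllm._C import ops

bit, m, k, n, groups = 4, 8, 4096, 4096, 32
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (k * bit // 32, n), dtype=torch.int32, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (groups, n * bit // 32), dtype=torch.int32, device="cuda")
scales = torch.rand(groups, n, dtype=torch.float16, device="cuda") * 0.01
g_idx = torch.empty((1, 1), device="meta")  # no act-order permutation, as in the diff

ops.mygq_shuffle(qweight, g_idx, bit)  # one-time repack for the exllama path
out = ops.mygq_gemm(a, qweight, qzeros, scales, g_idx, True, bit)
print(out.shape)  # torch.Size([8, 4096])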