Unverified Commit 89cdaa83 authored by Szymon Ożóg's avatar Szymon Ożóg Committed by GitHub
Browse files

[Kernel] Add more dtype support for GGUF kernels (#14043)


Signed-off-by: default avatarSzymonOzog <szymon.ozog@aleph-alpha.com>
Signed-off-by: default avatarSzymonOzog <szymon.ozog@gmail.com>
parent b0746fae
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h"
#include "ggml-common.h" #include "ggml-common.h"
#include "vecdotq.cuh" #include "vecdotq.cuh"
...@@ -13,7 +14,8 @@ ...@@ -13,7 +14,8 @@
#include "mmq.cuh" #include "mmq.cuh"
// Q8 gemv // Q8 gemv
static __global__ void quantize_q8_1(const half* __restrict__ x, template <typename scalar_t>
static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
void* __restrict__ vy, const int kx, void* __restrict__ vy, const int kx,
const int kx_padded) { const int kx_padded) {
const int ix = blockDim.x * blockIdx.x + threadIdx.x; const int ix = blockDim.x * blockIdx.x + threadIdx.x;
...@@ -28,7 +30,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x, ...@@ -28,7 +30,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
const int ib = i_padded / QK8_1; // block index const int ib = i_padded / QK8_1; // block index
const int iqs = i_padded % QK8_1; // quant index const int iqs = i_padded % QK8_1; // quant index
const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f; const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
float amax = fabsf(xi); float amax = fabsf(xi);
float sum = xi; float sum = xi;
...@@ -51,14 +53,16 @@ static __global__ void quantize_q8_1(const half* __restrict__ x, ...@@ -51,14 +53,16 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
y[ib].ds.y = __float2half(sum); y[ib].ds.y = __float2half(sum);
} }
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx, template <typename scalar_t>
static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
const int ky, cudaStream_t stream) { const int ky, cudaStream_t stream) {
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512; const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
const int block_num_x = const int block_num_x =
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ky, 1); const dim3 num_blocks(block_num_x, ky, 1);
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded); quantize_q8_1<scalar_t>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
} }
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
...@@ -79,101 +83,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight ...@@ -79,101 +83,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
int col = X.sizes()[1]; int col = X.sizes()[1];
const int padded = (col + 512 - 1) / 512 * 512; const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
at::Tensor Y = torch::empty({1, row}, options); at::Tensor Y = torch::empty({1, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options); at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1, VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
stream); quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
switch (type) { (void*)quant_X.data_ptr(), col, 1, stream);
case 2: switch (type) {
mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), case 2:
(half*)Y.data_ptr(), col, row, stream); mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
break; (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
case 3: (scalar_t*)Y.data_ptr(), col, row, stream);
mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), break;
(half*)Y.data_ptr(), col, row, stream); case 3:
break; mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
case 6: (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, stream);
(half*)Y.data_ptr(), col, row, stream); break;
break; case 6:
case 7: mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream); (scalar_t*)Y.data_ptr(), col, row, stream);
break; break;
case 8: case 7:
mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 10: break;
mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), case 8:
(half*)Y.data_ptr(), col, row, stream); mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
break; (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
case 11: (scalar_t*)Y.data_ptr(), col, row, stream);
mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), break;
(half*)Y.data_ptr(), col, row, stream); case 10:
break; mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
case 12: (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, stream);
(half*)Y.data_ptr(), col, row, stream); break;
break; case 11:
case 13: mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream); (scalar_t*)Y.data_ptr(), col, row, stream);
break; break;
case 14: case 12:
mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(), mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 16: break;
mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(), case 13:
(void*)quant_X.data_ptr(), mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 17: break;
mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(), case 14:
(void*)quant_X.data_ptr(), mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 18: break;
mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(), case 16:
(void*)quant_X.data_ptr(), mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 19: break;
mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(), case 17:
(void*)quant_X.data_ptr(), mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 20: break;
mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(), case 18:
(void*)quant_X.data_ptr(), mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 21: break;
mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(), case 19:
(void*)quant_X.data_ptr(), mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 22: break;
mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(), case 20:
(void*)quant_X.data_ptr(), mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 23: break;
mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(), case 21:
(void*)quant_X.data_ptr(), mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
case 29: break;
mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(), case 22:
(void*)quant_X.data_ptr(), mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
(half*)Y.data_ptr(), col, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, stream);
} break;
case 23:
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
break;
case 29:
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), col, row, stream);
break;
}
});
return Y; return Y;
} }
...@@ -184,66 +199,67 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight ...@@ -184,66 +199,67 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
int padded = (col + 512 - 1) / 512 * 512; int padded = (col + 512 - 1) / 512 * 512;
int batch = X.sizes()[0]; int batch = X.sizes()[0];
const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
at::Tensor Y = torch::empty({batch, row}, options); at::Tensor Y = torch::empty({batch, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options); at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
batch, stream); quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
col, batch, stream);
switch (type) {
case 2: switch (type) {
ggml_mul_mat_q4_0_q8_1_cuda( case 2:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q4_0_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 3: break;
ggml_mul_mat_q4_1_q8_1_cuda( case 3:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q4_1_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 6: break;
ggml_mul_mat_q5_0_q8_1_cuda( case 6:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q5_0_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 7: break;
ggml_mul_mat_q5_1_q8_1_cuda( case 7:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q5_1_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 8: break;
ggml_mul_mat_q8_0_q8_1_cuda( case 8:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q8_0_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 10: break;
ggml_mul_mat_q2_K_q8_1_cuda( case 10:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q2_K_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 11: break;
ggml_mul_mat_q3_K_q8_1_cuda( case 11:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q3_K_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 12: break;
ggml_mul_mat_q4_K_q8_1_cuda( case 12:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q4_K_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 13: break;
ggml_mul_mat_q5_K_q8_1_cuda( case 13:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q5_K_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
case 14: break;
ggml_mul_mat_q6_K_q8_1_cuda( case 14:
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), ggml_mul_mat_q6_K_q8_1_cuda(
col, row, batch, padded, row, stream); (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
break; (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
} break;
}
});
return Y; return Y;
} }
This diff is collapsed.
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda> template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) { static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row >= nrows) { if (row >= nrows) {
...@@ -33,158 +33,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * ...@@ -33,158 +33,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
} }
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
dst[row] = __float2half(tmp); dst[row] = tmp;
} }
} }
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1> mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { template<typename scalar_t>
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1> mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
} }
...@@ -22,7 +22,7 @@ def get_gguf_sample_tensors( ...@@ -22,7 +22,7 @@ def get_gguf_sample_tensors(
return GGUFReader(sample_file).tensors return GGUFReader(sample_file).tensors
DTYPES = [torch.half] DTYPES = [torch.half, torch.bfloat16, torch.float32]
# Hidden_size for testing, must match the sample file in HF repo, # Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently. # we have `hidden_size = 256, 1024` for test in HF repo currently.
HIDDEN_SIZES = [256, 1024] HIDDEN_SIZES = [256, 1024]
...@@ -52,7 +52,7 @@ QUANT_TYPES = [ ...@@ -52,7 +52,7 @@ QUANT_TYPES = [
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", [torch.half])
@pytest.mark.parametrize("quant_type", QUANT_TYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode() @torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype, def test_dequantize(hidden_size: int, dtype: torch.dtype,
...@@ -122,7 +122,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, ...@@ -122,7 +122,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
ref_output = x @ weight.T ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda") qweight = torch.tensor(tensor.data, device="cuda")
output = ops.ggml_mul_mat_a8(qweight, x, quant_type, output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
qweight.shape[0]).to(dtype) atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
# test matrix has inputs centered around 0 and lower precision from
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) # bfloat16 tends to accumulate and can greatly inflate rtol
# since outputs are also very close to 0
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
torch.testing.assert_close(output,
ref_output,
atol=atols[dtype],
rtol=rtols[dtype])
...@@ -436,7 +436,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"): ...@@ -436,7 +436,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
quant_type: int, quant_type: int,
row: torch.SymInt, row: torch.SymInt,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device) return torch.empty((1, row), dtype=X.dtype, device=W.device)
@register_fake("_C::ggml_mul_mat_a8") @register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake( def _ggml_mul_mat_a8_fake(
...@@ -446,7 +446,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"): ...@@ -446,7 +446,7 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
row: torch.SymInt, row: torch.SymInt,
) -> torch.Tensor: ) -> torch.Tensor:
batch = X.size(0) batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device) return torch.empty((batch, row), dtype=X.dtype, device=W.device)
# cutlass # cutlass
......
...@@ -32,7 +32,7 @@ class GGUFConfig(QuantizationConfig): ...@@ -32,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
return "gguf" return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half] return [torch.half, torch.bfloat16, torch.float32]
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -134,6 +134,7 @@ class GGUFLinearMethod(LinearMethodBase): ...@@ -134,6 +134,7 @@ class GGUFLinearMethod(LinearMethodBase):
output_partition_sizes: List[int], input_size: int, output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
self.params_dtype = params_dtype
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
tensor_shape = (output_size_per_partition, input_size_per_partition) tensor_shape = (output_size_per_partition, input_size_per_partition)
...@@ -326,7 +327,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod): ...@@ -326,7 +327,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x_flat = x.flatten() x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat) quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0]) x_flat.shape[0]).to(self.params_dtype)
return dequant.view(*x.shape, hidden_size) return dequant.view(*x.shape, hidden_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment