Unverified Commit 360bd67c authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Core] Support loading GGUF model (#5191)


Co-authored-by: default avatarMichael Goin <michael@neuralmagic.com>
parent ef527be0
...@@ -30,6 +30,11 @@ jobs: ...@@ -30,6 +30,11 @@ jobs:
run: | run: |
EXCLUDES=( EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu' 'csrc/moe/topk_softmax_kernels.cu'
'csrc/quantization/gguf/ggml-common.h'
'csrc/quantization/gguf/dequantize.cuh'
'csrc/quantization/gguf/vecdotq.cuh'
'csrc/quantization/gguf/mmq.cuh'
'csrc/quantization/gguf/mmvq.cuh'
) )
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
......
...@@ -208,6 +208,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -208,6 +208,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu" "csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu" "csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
......
...@@ -107,6 +107,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, ...@@ -107,6 +107,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits); int64_t size_n, int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m,
int64_t n);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type,
int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type,
int64_t row);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace, torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n, int64_t num_bits, int64_t size_m, int64_t size_n,
......
This diff is collapsed.
This diff is collapsed.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
// Q8 gemv
static __global__ void quantize_q8_1(const half* __restrict__ x,
void* __restrict__ vy, const int kx,
const int kx_padded) {
const int ix = blockDim.x * blockIdx.x + threadIdx.x;
if (ix >= kx_padded) {
return;
}
const int iy = blockDim.y * blockIdx.y + threadIdx.y;
const int i_padded = iy * kx_padded + ix;
block_q8_1* y = (block_q8_1*)vy;
const int ib = i_padded / QK8_1; // block index
const int iqs = i_padded % QK8_1; // quant index
const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
float amax = fabsf(xi);
float sum = xi;
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
}
const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q;
if (iqs > 0) {
return;
}
y[ib].ds.x = __float2half(d);
y[ib].ds.y = __float2half(sum);
}
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
const int ky, cudaStream_t stream) {
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
const int block_num_x =
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ky, 1);
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
int8_t type, int64_t m, int64_t n) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto options =
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
at::Tensor DW = torch::empty({m, n}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);
return DW;
}
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input
int8_t type, int64_t row) {
int col = X.sizes()[1];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options =
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
at::Tensor Y = torch::empty({1, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
stream);
switch (type) {
case 2:
mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 3:
mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 6:
mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 7:
mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 8:
mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 10:
mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 11:
mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 12:
mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 13:
mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 14:
mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 16:
mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 17:
mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 18:
mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 19:
mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 20:
mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 21:
mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 22:
mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
case 23:
mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(half*)Y.data_ptr(), col, row, stream);
break;
}
return Y;
}
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
torch::Tensor X, // input
int8_t type, int64_t row) {
int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512;
int batch = X.sizes()[0];
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options =
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
at::Tensor Y = torch::empty({batch, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
batch, stream);
switch (type) {
case 2:
ggml_mul_mat_q4_0_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 3:
ggml_mul_mat_q4_1_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 6:
ggml_mul_mat_q5_0_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 7:
ggml_mul_mat_q5_1_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 8:
ggml_mul_mat_q8_0_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 10:
ggml_mul_mat_q2_K_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 11:
ggml_mul_mat_q3_K_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 12:
ggml_mul_mat_q4_K_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 13:
ggml_mul_mat_q5_K_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
case 14:
ggml_mul_mat_q6_K_q8_1_cuda(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
col, row, batch, padded, row, stream);
break;
}
return Y;
}
\ No newline at end of file
This diff is collapsed.
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
// partial sum for each thread
float tmp = 0.0f;
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (threadIdx.x == 0) {
dst[row] = __float2half(tmp);
}
}
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
This diff is collapsed.
...@@ -145,6 +145,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -145,6 +145,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("awq_marlin_repack", &awq_marlin_repack); ops.def("awq_marlin_repack", &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
// Dequantization for GGML.
ops.def("ggml_dequantize", &ggml_dequantize);
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
// mmvq kernel for GGML.
ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
// mmq kernel for GGML.
ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only. // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
......
from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams
def run_gguf_inference(model_path):
PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n" # noqa: E501
system_message = "You are a friendly chatbot who always responds in the style of a pirate." # noqa: E501
# Sample prompts.
prompts = [
"How many helicopters can a human eat in one sitting?",
"What's the future of AI?",
]
prompts = [
PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
for prompt in prompts
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0, max_tokens=128)
# Create an LLM.
llm = LLM(model=model_path,
tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
gpu_memory_utilization=0.95)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == "__main__":
repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
model = hf_hub_download(repo_id, filename=filename)
run_gguf_inference(model)
...@@ -242,6 +242,11 @@ echo 'vLLM isort: Done' ...@@ -242,6 +242,11 @@ echo 'vLLM isort: Done'
# NOTE: Keep up to date with .github/workflows/clang-format.yml # NOTE: Keep up to date with .github/workflows/clang-format.yml
CLANG_FORMAT_EXCLUDES=( CLANG_FORMAT_EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu' 'csrc/moe/topk_softmax_kernels.cu'
'csrc/quantization/gguf/ggml-common.h'
'csrc/quantization/gguf/dequantize.cuh'
'csrc/quantization/gguf/vecdotq.cuh'
'csrc/quantization/gguf/mmq.cuh'
'csrc/quantization/gguf/mmvq.cuh'
) )
# Format specified files with clang-format # Format specified files with clang-format
......
...@@ -22,3 +22,4 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 ...@@ -22,3 +22,4 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq pyzmq
gguf == 0.9.1
"""
Tests gguf models against unquantized models generations
Note: To pass the test, quantization higher than Q4 should be used
"""
import os
import pytest
from huggingface_hub import hf_hub_download
from tests.quantization.utils import is_quant_method_supported
from .utils import check_logprobs_close
os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024
# FIXME: Move this to confest
MODELS = [
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")),
("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
("Qwen/Qwen2-1.5B-Instruct",
hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
("Qwen/Qwen2-1.5B-Instruct",
hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")),
]
@pytest.mark.skipif(not is_quant_method_supported("gguf"),
reason="gguf is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
original_model, gguf_model = model
# Run unquantized model.
with vllm_runner(model_name=original_model,
dtype=dtype,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
tensor_parallel_size=1) as original_model:
original_outputs = original_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)
# Run gguf model.
with vllm_runner(model_name=gguf_model,
dtype=dtype,
max_model_len=MAX_MODEL_LEN,
enforce_eager=True,
tensor_parallel_size=1) as gguf_model:
gguf_outputs = gguf_model.generate_greedy_logprobs(
example_prompts[:-1], max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=original_outputs,
outputs_1_lst=gguf_outputs,
name_0="original",
name_1="gguf",
)
...@@ -7,11 +7,12 @@ from typing import Tuple ...@@ -7,11 +7,12 @@ from typing import Tuple
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod) GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod)
PROMPT = "On the surface of Mars, we found" PROMPT = "On the surface of Mars, we found"
...@@ -37,7 +38,8 @@ def test_lm_head( ...@@ -37,7 +38,8 @@ def test_lm_head(
lm_head_layer.linear_method, lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod)) (GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
else: else:
assert isinstance(lm_head_layer.linear_method, UnquantizedLinearMethod) assert isinstance(lm_head_layer.linear_method,
UnquantizedEmbeddingMethod)
print( print(
vllm_model.generate_greedy(prompts=["Hello my name is"], vllm_model.generate_greedy(prompts=["Hello my name is"],
......
...@@ -404,6 +404,38 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -404,6 +404,38 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
workspace, size_m, size_n, size_k) workspace, size_m, size_n, size_k)
# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
def ggml_mul_mat_vec(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
):
return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
def ggml_mul_mat_vec_a8(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
):
return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
def ggml_mul_mat_a8(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
):
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
# moe # moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor, block_size: int, sorted_token_ids: torch.Tensor,
......
...@@ -582,6 +582,7 @@ class LoadFormat(str, enum.Enum): ...@@ -582,6 +582,7 @@ class LoadFormat(str, enum.Enum):
DUMMY = "dummy" DUMMY = "dummy"
TENSORIZER = "tensorizer" TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state" SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes" BITSANDBYTES = "bitsandbytes"
......
...@@ -672,6 +672,9 @@ class EngineArgs: ...@@ -672,6 +672,9 @@ class EngineArgs:
return engine_args return engine_args
def create_engine_config(self, ) -> EngineConfig: def create_engine_config(self, ) -> EngineConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if self.model.endswith(".gguf"):
self.quantization = self.load_format = "gguf"
# bitsandbytes quantization needs a specific model loader # bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent # so we make sure the quant method and the load format are consistent
......
...@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple ...@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -311,6 +311,17 @@ class ColumnParallelLinear(LinearBase): ...@@ -311,6 +311,17 @@ class ColumnParallelLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
if output_dim is not None: if output_dim is not None:
shard_size = param_data.shape[output_dim] shard_size = param_data.shape[output_dim]
...@@ -398,6 +409,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -398,6 +409,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None): loaded_shard_id: Optional[int] = None):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.data[loaded_shard_id].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return
if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks. # Special case for AQLM codebooks.
...@@ -460,6 +492,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -460,6 +492,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if is_gguf_weight:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = loaded_weight.shape
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
...@@ -563,6 +602,29 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -563,6 +602,29 @@ class QKVParallelLinear(ColumnParallelLinear):
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type and loaded_shard_id is not None:
idx_map = {"q": 0, "k": 1, "v": 2}
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return
if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES
ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks. # Special case for AQLM codebooks.
...@@ -650,6 +712,13 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -650,6 +712,13 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_shard( shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight:
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = loaded_weight.shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
...@@ -755,6 +824,17 @@ class RowParallelLinear(LinearBase): ...@@ -755,6 +824,17 @@ class RowParallelLinear(LinearBase):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None) input_dim = getattr(param, "input_dim", None)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
param_data = param.data param_data = param.data
if input_dim is not None: if input_dim is not None:
shard_size = param_data.shape[input_dim] shard_size = param_data.shape[input_dim]
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig) DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig) GPTQMarlinConfig)
...@@ -31,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -31,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
# The order of gptq methods is important for config.py iteration over # The order of gptq methods is important for config.py iteration over
# override_quantization_method(..) # override_quantization_method(..)
"marlin": MarlinConfig, "marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig, "gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig, "awq_marlin": AWQMarlinConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment