Commit 5ec33d04 authored by Atream's avatar Atream
Browse files

optimize gguf dequant, save mem, support Q2_K

use marlin for lm_head, lm_head only calc last token for prefill
extend context window to 19K for DeepSeek-V3/R1 within 24GB VRAM
parent 7e1fe256
/** /**
* @Description : * @Description :
* @Author : Azure-Tang * @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30 * @Date : 2024-07-25 13:38:30
* @Version : 1.0.0 * @Version : 0.2.2
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 03:05:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/ **/
...@@ -19,22 +17,44 @@ ...@@ -19,22 +17,44 @@
// namespace py = pybind11; // namespace py = pybind11;
PYBIND11_MODULE(KTransformersOps, m) { PYBIND11_MODULE(KTransformersOps, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); }, "Function to dequantize q8_0 data.",
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", }, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype")); }, "Function to dequantize q5_k data.",
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
} }
#include "ops.h"
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/library.h>
#include <torch/extension.h>
#include <torch/torch.h>
// namespace py = pybind11;
int test(){
return 5;
}
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
PYBIND11_MODULE(cudaops, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
m.def("test", &test, "Function to test.");
}
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
* @Description : * @Description :
* @Author : Azure-Tang, Boxin Zhang * @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30 * @Date : 2024-07-25 13:38:30
* @Version : 1.0.0 * @Version : 0.2.2
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 04:18:04
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved. * Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
...@@ -18,45 +16,42 @@ ...@@ -18,45 +16,42 @@
#include <cstdint> #include <cstdint>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const int8_t* cur_block = data + block_id * blk_size; const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block)); float scale = __half2float(*((half*)cur_block));
cur_block += 2; cur_block += 2;
for (int i = 0; i < 32; i++){ for (int i = 0; i < ele_per_blk; i++){
output_blk[i] = scale * cur_block[i]; output_blk[i] = scale * cur_block[i];
} }
output_blk += 32;
} }
} }
__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) { for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
__half* __restrict__ output_blk = (__half*)(output + block_id * 256); __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const int8_t* cur_block = data + block_id * blk_size; const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block)); float scale = __half2float(*((half*)cur_block));
cur_block += 2; cur_block += 2;
for (int i = 0; i < 32; i++) { for (int i = 0; i < ele_per_blk; i++) {
output_blk[i] = __float2half(scale * cur_block[i]); output_blk[i] = __float2half(scale * cur_block[i]);
} }
output_blk += 32;
} }
} }
__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) { for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const int8_t* cur_block = data + block_id * blk_size; const int8_t* cur_block = data + block_id * blk_size;
float scale = __half2float(*((half*)cur_block)); float scale = __half2float(*((half*)cur_block));
cur_block += 2; cur_block += 2;
for (int i = 0; i < 32; i++) { for (int i = 0; i < ele_per_blk; i++) {
output_blk[i] = __float2bfloat16(scale * cur_block[i]); output_blk[i] = __float2bfloat16(scale * cur_block[i]);
} }
output_blk += 32;
} }
} }
...@@ -70,10 +65,10 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_ ...@@ -70,10 +65,10 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
} }
} }
__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
...@@ -104,10 +99,10 @@ __global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, c ...@@ -104,10 +99,10 @@ __global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, c
} }
} }
__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256); __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
...@@ -138,10 +133,10 @@ __global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, ...@@ -138,10 +133,10 @@ __global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output,
} }
} }
__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));
...@@ -172,13 +167,13 @@ __global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* out ...@@ -172,13 +167,13 @@ __global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* out
} }
} }
__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303; const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f; const uint32_t kmask2 = 0x0f0f0f0f;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
uint32_t aux[4]; uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux; const int8_t * scales = (const int8_t*)aux;
...@@ -228,13 +223,13 @@ __global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, c ...@@ -228,13 +223,13 @@ __global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, c
} }
} }
__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303; const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f; const uint32_t kmask2 = 0x0f0f0f0f;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256); __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
uint32_t aux[4]; uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux; const int8_t * scales = (const int8_t*)aux;
...@@ -284,13 +279,13 @@ __global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, ...@@ -284,13 +279,13 @@ __global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output,
} }
} }
__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303; const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f; const uint32_t kmask2 = 0x0f0f0f0f;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
uint32_t aux[4]; uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux; const int8_t * scales = (const int8_t*)aux;
...@@ -341,10 +336,10 @@ __global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* out ...@@ -341,10 +336,10 @@ __global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* out
} }
__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
// const uint8_t * q = data[i].qs; // const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16); const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
...@@ -352,7 +347,7 @@ __global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, c ...@@ -352,7 +347,7 @@ __global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, c
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0; int is = 0;
uint8_t sc, m; uint8_t sc, m;
for (int j = 0; j < blk_size; j += 64) { for (int j = 0; j < ele_per_blk; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m); get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m; const float d1 = d * sc; const float m1 = min * m;
...@@ -365,10 +360,10 @@ __global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, c ...@@ -365,10 +360,10 @@ __global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, c
} }
} }
__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256); __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
// const uint8_t * q = data[i].qs; // const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16); const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
...@@ -376,7 +371,7 @@ __global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, ...@@ -376,7 +371,7 @@ __global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output,
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0; int is = 0;
uint8_t sc, m; uint8_t sc, m;
for (int j = 0; j < blk_size; j += 64) { for (int j = 0; j < ele_per_blk; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m); get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m; const float d1 = d * sc; const float m1 = min * m;
...@@ -389,10 +384,10 @@ __global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, ...@@ -389,10 +384,10 @@ __global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output,
} }
} }
__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
// const uint8_t * q = data[i].qs; // const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16); const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
...@@ -400,7 +395,7 @@ __global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* out ...@@ -400,7 +395,7 @@ __global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* out
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));
int is = 0; int is = 0;
uint8_t sc, m; uint8_t sc, m;
for (int j = 0; j < blk_size; j += 64) { for (int j = 0; j < ele_per_blk; j += 64) {
uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4); uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);
get_scale_min_k4(is + 0, scales, &sc, &m); get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m; const float d1 = d * sc; const float m1 = min * m;
...@@ -413,10 +408,10 @@ __global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* out ...@@ -413,10 +408,10 @@ __global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* out
} }
} }
__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
...@@ -442,10 +437,10 @@ __global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, c ...@@ -442,10 +437,10 @@ __global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, c
} }
} }
__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256); __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
...@@ -471,10 +466,10 @@ __global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, ...@@ -471,10 +466,10 @@ __global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output,
} }
} }
__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){ for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2))); const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));
...@@ -500,10 +495,10 @@ __global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* out ...@@ -500,10 +495,10 @@ __global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* out
} }
} }
__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
...@@ -511,31 +506,30 @@ __global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, c ...@@ -511,31 +506,30 @@ __global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, c
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
//if (blk_size == 256){ for (int n = 0; n < ele_per_blk; n += 128) {
for (int n = 0; n < blk_size; n += 128) { for (int l = 0; l < 32; ++l) {
for (int l = 0; l < 32; ++l) { int is = l/16;
int is = l/16; const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; output_blk[l + 0] = d * sc[is + 0] * q1;
output_blk[l + 0] = d * sc[is + 0] * q1; output_blk[l + 32] = d * sc[is + 2] * q2;
output_blk[l + 32] = d * sc[is + 2] * q2; output_blk[l + 64] = d * sc[is + 4] * q3;
output_blk[l + 64] = d * sc[is + 4] * q3; output_blk[l + 96] = d * sc[is + 6] * q4;
output_blk[l + 96] = d * sc[is + 6] * q4;
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
} }
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
} }
} }
__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
__half* __restrict__ output_blk = (__half*)(output + block_id * 256); __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
...@@ -543,31 +537,30 @@ __global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, ...@@ -543,31 +537,30 @@ __global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output,
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
//if (blk_size == 256){ for (int n = 0; n < ele_per_blk; n += 128) {
for (int n = 0; n < blk_size; n += 128) { for (int l = 0; l < 32; ++l) {
for (int l = 0; l < 32; ++l) { int is = l/16;
int is = l/16; const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; output_blk[l + 0] = __float2half(d * sc[is + 0] * q1);
output_blk[l + 0] = __float2half(d * sc[is + 0] * q1); output_blk[l + 32] = __float2half(d * sc[is + 2] * q2);
output_blk[l + 32] = __float2half(d * sc[is + 2] * q2); output_blk[l + 64] = __float2half(d * sc[is + 4] * q3);
output_blk[l + 64] = __float2half(d * sc[is + 4] * q3); output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);
output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
} }
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
} }
} }
__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { __global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){ for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size); const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);
...@@ -575,33 +568,32 @@ __global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* out ...@@ -575,33 +568,32 @@ __global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* out
const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192); const int8_t * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);
//if (blk_size == 256){ for (int n = 0; n < ele_per_blk; n += 128) {
for (int n = 0; n < blk_size; n += 128) { for (int l = 0; l < 32; ++l) {
for (int l = 0; l < 32; ++l) { int is = l/16;
int is = l/16; const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1);
output_blk[l + 0] = __float2bfloat16(d * sc[is + 0] * q1); output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2);
output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2); output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3);
output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3); output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4);
output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4);
}
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
} }
output_blk += 128;
ql += 64;
qh += 32;
sc += 8;
}
} }
} }
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int num_blocks) { __global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) { for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
float* __restrict__ output_blk = (float*)(output + block_id * 256); float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2)); const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
...@@ -620,10 +612,10 @@ __global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, ...@@ -620,10 +612,10 @@ __global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output,
} }
} }
__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int num_blocks) { __global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) { for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
__half* __restrict__ output_blk = (__half*)(output + block_id * 256); __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2)); const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
...@@ -642,10 +634,10 @@ __global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output ...@@ -642,10 +634,10 @@ __global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output
} }
} }
__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int num_blocks) { __global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x; long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) { for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * 256); nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);
const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size))); const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2)); const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));
const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2); const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);
...@@ -664,7 +656,7 @@ __global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* o ...@@ -664,7 +656,7 @@ __global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* o
} }
} }
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size; int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device); const at::cuda::OptionalCUDAGuard device_guard(device);
...@@ -679,13 +671,13 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int ...@@ -679,13 +671,13 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int
switch (target_dtype) { switch (target_dtype) {
case torch::kFloat16: case torch::kFloat16:
dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks); dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks); dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kFloat32: case torch::kFloat32:
dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks); dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break; break;
default: default:
printf("target type not support\n"); printf("target type not support\n");
...@@ -697,7 +689,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int ...@@ -697,7 +689,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int
} }
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
// data.numel%blk_size should be 0, else raise err // data.numel%blk_size should be 0, else raise err
int num_blocks = num_bytes / blk_size; int num_blocks = num_bytes / blk_size;
...@@ -713,13 +705,13 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int ...@@ -713,13 +705,13 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int
switch (target_dtype) { switch (target_dtype) {
case torch::kFloat16: case torch::kFloat16:
dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks); dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks); dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kFloat32: case torch::kFloat32:
dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks); dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break; break;
default: default:
printf("target type not support\n"); printf("target type not support\n");
...@@ -729,7 +721,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int ...@@ -729,7 +721,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int
return output; return output;
} }
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size; int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device); const at::cuda::OptionalCUDAGuard device_guard(device);
...@@ -744,13 +736,13 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int ...@@ -744,13 +736,13 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int
switch (target_dtype) { switch (target_dtype) {
case torch::kFloat16: case torch::kFloat16:
dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks); dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks); dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kFloat32: case torch::kFloat32:
dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks); dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break; break;
default: default:
printf("target type not support\n"); printf("target type not support\n");
...@@ -760,7 +752,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int ...@@ -760,7 +752,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int
return output; return output;
} }
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
// data.numel%blk_size should be 0, else raise err // data.numel%blk_size should be 0, else raise err
int num_blocks = num_bytes / blk_size; int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device); const at::cuda::OptionalCUDAGuard device_guard(device);
...@@ -776,13 +768,13 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int ...@@ -776,13 +768,13 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int
switch (target_dtype) { switch (target_dtype) {
case torch::kFloat16: case torch::kFloat16:
dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks); dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks); dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kFloat32: case torch::kFloat32:
dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks); dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break; break;
default: default:
printf("target type not support\n"); printf("target type not support\n");
...@@ -792,7 +784,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int ...@@ -792,7 +784,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int
return output; return output;
} }
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size; int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device); const at::cuda::OptionalCUDAGuard device_guard(device);
...@@ -807,13 +799,13 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int ...@@ -807,13 +799,13 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int
switch (target_dtype) { switch (target_dtype) {
case torch::kFloat16: case torch::kFloat16:
dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks); dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks); dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kFloat32: case torch::kFloat32:
dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks); dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break; break;
default: default:
printf("target type not support\n"); printf("target type not support\n");
...@@ -823,7 +815,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int ...@@ -823,7 +815,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int
return output; return output;
} }
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size; int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device); const at::cuda::OptionalCUDAGuard device_guard(device);
...@@ -838,13 +830,13 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int ...@@ -838,13 +830,13 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int
switch (target_dtype) { switch (target_dtype) {
case torch::kFloat16: case torch::kFloat16:
dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks); dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks); dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kFloat32: case torch::kFloat32:
dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks); dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break; break;
default: default:
printf("target type not support\n"); printf("target type not support\n");
...@@ -854,7 +846,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int ...@@ -854,7 +846,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int
return output; return output;
} }
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype) { torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {
int num_blocks = num_bytes / blk_size; int num_blocks = num_bytes / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device); const at::cuda::OptionalCUDAGuard device_guard(device);
...@@ -869,13 +861,13 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i ...@@ -869,13 +861,13 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
switch (target_dtype) { switch (target_dtype) {
case torch::kFloat16: case torch::kFloat16:
dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, num_blocks); dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kBFloat16: case torch::kBFloat16:
dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<nv_bfloat16>(), blk_size, num_blocks); dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);
break; break;
case torch::kFloat32: case torch::kFloat32:
dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks); dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);
break; break;
default: default:
printf("target type not support\n"); printf("target type not support\n");
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <torch/torch.h> #include <torch/torch.h>
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype); torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
import os
import sys
sys.path.insert(0,"/home/zbx/ktransformers")
from ktransformers.util.custom_gguf import GGUFLoader
import torch
gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
gguf_loader_2 = GGUFLoader("/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/")
torch.set_default_dtype(torch.bfloat16)
tensor_1 = gguf_loader_1.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")
tensor_2 = gguf_loader_2.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")
print(tensor_1[0, -64:])
print(tensor_2[0, -64:])
\ No newline at end of file
...@@ -90,7 +90,7 @@ def marlin_quantize( ...@@ -90,7 +90,7 @@ def marlin_quantize(
assert group_size <= size_k assert group_size <= size_k
# Quantize (and apply act_order if provided) # Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
act_order) act_order)
# For act_order, sort the "weights" and "g_idx" so that group ids are # For act_order, sort the "weights" and "g_idx" so that group ids are
...@@ -107,7 +107,7 @@ def marlin_quantize( ...@@ -107,7 +107,7 @@ def marlin_quantize(
marlin_scale_perm_single[num_bits]) marlin_scale_perm_single[num_bits])
# Create result # Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
for i in range(len(res_list)): for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device) res_list[i] = res_list[i].to(w.device)
......
...@@ -11,8 +11,7 @@ def get_pack_factor(num_bits): ...@@ -11,8 +11,7 @@ def get_pack_factor(num_bits):
return 32 // num_bits return 32 // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): def permute_rows(q_w: torch.Tensor, group_size: int):
assert q_w.shape == w_ref.shape
orig_device = q_w.device orig_device = q_w.device
k_size, _ = q_w.shape k_size, _ = q_w.shape
...@@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): ...@@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx = g_idx[rand_perm].contiguous() g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous() q_w = q_w[rand_perm, :].contiguous()
w_ref = w_ref[rand_perm, :].contiguous()
return ( return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device), q_w.to(device=orig_device),
g_idx.to(device=orig_device), g_idx.to(device=orig_device),
rand_perm.to(device=orig_device), rand_perm.to(device=orig_device),
...@@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
q_w += half_q_val q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val) q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s
# Restore original shapes # Restore original shapes
if group_size < size_k: if group_size < size_k:
...@@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
return w return w
q_w = reshape_w(q_w) q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous() s = s.reshape((-1, size_n)).contiguous()
...@@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
), "For act_order, groupsize = {} must be less than size_k = {}".format( ), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k) group_size, size_k)
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) q_w, g_idx, rand_perm = permute_rows(q_w, group_size)
return ( return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device), q_w.to(device=orig_device),
s.to(device=orig_device), s.to(device=orig_device),
g_idx.to(device=orig_device), g_idx.to(device=orig_device),
......
...@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): ...@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states[:,-1:,:]).float()
logits = logits[:,-1,:].unsqueeze(0).float()
loss = None loss = None
if labels is not None: if labels is not None:
......
...@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): ...@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states.to(self.lm_head.weight.device)) logits = self.lm_head(hidden_states[:,-1:,:])
logits = logits.float() logits = logits.float()
loss = None loss = None
......
...@@ -9,7 +9,7 @@ flashinfer_enabled = False ...@@ -9,7 +9,7 @@ flashinfer_enabled = False
try: try:
import flashinfer import flashinfer
flashinfer_enabled = True flashinfer_enabled = False
print("found flashinfer") print("found flashinfer")
except ImportError: except ImportError:
......
...@@ -21,6 +21,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl ...@@ -21,6 +21,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
MarlinWorkspace, MarlinWorkspace,
marlin_quantize, marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MAX_PARALLEL,
) )
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
...@@ -64,6 +65,8 @@ class KLinearBase(ABC): ...@@ -64,6 +65,8 @@ class KLinearBase(ABC):
self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0] self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1] self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill.
@abstractmethod @abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
pass pass
...@@ -134,6 +137,7 @@ class KLinearTorch(KLinearBase): ...@@ -134,6 +137,7 @@ class KLinearTorch(KLinearBase):
return x return x
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device if device is None: device = self.device
if w is None: w = self.load_weight(device=device) if w is None: w = self.load_weight(device=device)
# else: self.out_features = w.shape[0], self.in_features = w.shape[1] # else: self.out_features = w.shape[0], self.in_features = w.shape[1]
...@@ -157,6 +161,7 @@ class KLinearTorch(KLinearBase): ...@@ -157,6 +161,7 @@ class KLinearTorch(KLinearBase):
self.weight = self.weight.to(device) self.weight = self.weight.to(device)
if self.has_bias: if self.has_bias:
self.bias = self.bias.to(device) self.bias = self.bias.to(device)
self.loaded = True
def unload(self): def unload(self):
if self.weight is not None: if self.weight is not None:
...@@ -190,20 +195,36 @@ class KLinearMarlin(KLinearBase): ...@@ -190,20 +195,36 @@ class KLinearMarlin(KLinearBase):
self.group_size = group_size self.group_size = group_size
self.act_order = act_order self.act_order = act_order
self.is_k_full = is_k_full self.is_k_full = is_k_full
self.padding = False
self.orin_in_features = self.in_features
self.orin_out_features = self.out_features
if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
#print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
self.padding = True
self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
#print(f"After padding: in_features={in_features}, out_features={out_features}")
self.k = self.in_features
self.n = self.out_features
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None): def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device" assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
#if self.in_features * self.out_features:
if w is None: if w is None:
w = self.load_weight(device=device) w = self.load_weight(device=device)
if isinstance(w, nn.Parameter): if isinstance(w, nn.Parameter):
# pad weight # pad weight
weight = w.view(self.out_features, self.in_features).T weight = w.view(self.orin_out_features, self.orin_in_features).T
self.has_bias = False self.has_bias = False
elif isinstance(w, tuple): elif isinstance(w, tuple):
w = list(w) w = list(w)
weight = w[0].view(self.out_features, self.in_features).T weight = w[0].view(self.orin_out_features, self.orin_in_features).T
self.bias = w[1].view(self.orin_out_features)
self.bias = w[1] self.bias = w[1]
self.has_bias = True self.has_bias = True
else: else:
...@@ -211,8 +232,14 @@ class KLinearMarlin(KLinearBase): ...@@ -211,8 +232,14 @@ class KLinearMarlin(KLinearBase):
weight = weight.to(device) weight = weight.to(device)
if self.has_bias: if self.has_bias:
self.bias = self.bias.to(device) self.bias = self.bias.to(device)
if self.padding:
padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
weight = padded_weight
# Pack Marlin linear # Pack Marlin linear
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
weight, self.num_bits, self.group_size, self.act_order weight, self.num_bits, self.group_size, self.act_order
) )
self.workspace = MarlinWorkspace( self.workspace = MarlinWorkspace(
...@@ -225,6 +252,7 @@ class KLinearMarlin(KLinearBase): ...@@ -225,6 +252,7 @@ class KLinearMarlin(KLinearBase):
self.sort_indices = sort_indices self.sort_indices = sort_indices
self.k = weight.shape[0] self.k = weight.shape[0]
self.n = weight.shape[1] self.n = weight.shape[1]
self.loaded = True
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# Only support input x as BF16 and FP16 # Only support input x as BF16 and FP16
...@@ -232,6 +260,11 @@ class KLinearMarlin(KLinearBase): ...@@ -232,6 +260,11 @@ class KLinearMarlin(KLinearBase):
orig_shape = list(x.shape) orig_shape = list(x.shape)
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1]) x = x.reshape(-1, orig_shape[-1])
x = x.reshape(-1, x.shape[-1])
if self.padding:
padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)
padding_input[:,:self.orin_in_features] = x
x = padding_input
marlin_s = self.marlin_s.to(x.dtype) marlin_s = self.marlin_s.to(x.dtype)
x = KTransformersOps.gptq_marlin_gemm( x = KTransformersOps.gptq_marlin_gemm(
x, x,
...@@ -246,6 +279,11 @@ class KLinearMarlin(KLinearBase): ...@@ -246,6 +279,11 @@ class KLinearMarlin(KLinearBase):
x.shape[-1], x.shape[-1],
self.is_k_full, self.is_k_full,
) )
if self.padding:
x = x[:,:self.orin_out_features]
orig_shape[-1] = self.orin_out_features
else:
orig_shape[-1] = self.out_features
if self.has_bias: if self.has_bias:
x = x + self.bias x = x + self.bias
orig_shape[-1] = self.n orig_shape[-1] = self.n
...@@ -388,24 +426,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase): ...@@ -388,24 +426,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
# build all the linear operators # build all the linear operators
if prefill_op is not None: if prefill_op is not None:
assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported" assert prefill_op in LINEAR_MAP, f"linear_type {prefill_op} not supported"
if prefill_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0): self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
print(f"module info: key:{key} orig_module:{orig_module}")
self.prefill_linear = KLinearTorch(key, gguf_loader, config, orig_module, prefill_device, **kwargs)
else:
self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)
else: else:
self.prefill_linear = None self.prefill_linear = None
if generate_op is not None: if generate_op is not None:
assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported" assert generate_op in LINEAR_MAP, f"linear_type {generate_op} not supported"
if generate_op == "KLinearMarlin" and (orig_module.in_features%GPTQ_MARLIN_MIN_THREAD_N!=0 or orig_module.out_features%GPTQ_MARLIN_MIN_THREAD_N!=0): self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using KLinearTorch instead.")
print(f"module info: key:{key} orig_module:{orig_module}")
self.generate_op = "KLinearTorch"
self.generate_linear = KLinearTorch(key, gguf_loader, config, orig_module, generate_device, **kwargs)
else:
self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)
else: else:
self.generate_linear = None self.generate_linear = None
self.mode = InferenceState.UNLOAD self.mode = InferenceState.UNLOAD
......
...@@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo ...@@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
gguf_loader=GGUFLoader(gguf_path) gguf_loader=GGUFLoader(gguf_path)
with torch.device("meta"): with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader) inject(module, optimize_config, model_config, gguf_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, gguf_loader, "lm_head.")
load_weights(module, gguf_loader) load_weights(module, gguf_loader)
module.gguf_loader = gguf_loader module.gguf_loader = gguf_loader
del_meta(module) del_meta(module)
......
...@@ -219,8 +219,20 @@ ...@@ -219,8 +219,20 @@
kwargs: kwargs:
generate_device: "cuda:2" generate_device: "cuda:2"
prefill_device: "cuda:2" prefill_device: "cuda:2"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match: - match:
name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)" name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)"
replace: replace:
class: "default" class: "default"
kwargs: kwargs:
......
...@@ -118,7 +118,18 @@ ...@@ -118,7 +118,18 @@
prefill_device: "cuda:0" prefill_device: "cuda:0"
- match: - match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)" name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)"
replace: replace:
class: "default" class: "default"
kwargs: kwargs:
......
...@@ -15,6 +15,18 @@ ...@@ -15,6 +15,18 @@
prefill_device: "cuda" prefill_device: "cuda"
generate_op: "KLinearMarlin" generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch" prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match: - match:
name: "^model\\.layers\\..*\\.mlp$" name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
......
...@@ -118,7 +118,18 @@ ...@@ -118,7 +118,18 @@
prefill_device: "cuda:0" prefill_device: "cuda:0"
- match: - match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)" name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
replace: replace:
class: "default" class: "default"
kwargs: kwargs:
......
...@@ -15,6 +15,18 @@ ...@@ -15,6 +15,18 @@
prefill_device: "cuda" prefill_device: "cuda"
generate_op: "KLinearMarlin" generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch" prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match: - match:
name: "^model\\.layers\\..*\\.mlp$" name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
......
...@@ -188,7 +188,7 @@ ...@@ -188,7 +188,7 @@
# !!!Do remember 'close' cuda graph if you are using marlin expert.!!! # !!!Do remember 'close' cuda graph if you are using marlin expert.!!!
# !!!KExpertsTorch is untested, we don't have enough VRAM.!!! # !!!KExpertsTorch is untested, we don't have enough VRAM.!!!
# # GPU 0: layers 3–4 # GPU 0: layers 3–4
# - match: # - match:
# name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$" # name: "^model\\.layers\\.([3-4])\\.mlp\\.experts$"
# replace: # replace:
...@@ -363,11 +363,20 @@ ...@@ -363,11 +363,20 @@
generate_device: "cuda:2" generate_device: "cuda:2"
prefill_device: "cuda:2" prefill_device: "cuda:2"
# don't inject lm_head if already inject marlin experts - match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
# For final modules (model.norm and lm_head), ensure they are on GPU 3 (as in your original config) # For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
- match: - match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
replace: replace:
class: "default" class: "default"
kwargs: kwargs:
......
...@@ -713,11 +713,20 @@ ...@@ -713,11 +713,20 @@
generate_device: "cuda:7" generate_device: "cuda:7"
prefill_device: "cuda:7" prefill_device: "cuda:7"
# don't inject lm_head if already inject marlin experts - match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:7"
prefill_device: "cuda:7"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
# For final modules (model.norm and lm_head), ensure they are on GPU 7 (as in your original config) # For final modules (model.norm), ensure they are on GPU 7 (as in your original config)
- match: - match:
name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)|(^lm_head)" name: "(^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.)|(^model\\.norm)"
replace: replace:
class: "default" class: "default"
kwargs: kwargs:
......
...@@ -153,7 +153,18 @@ ...@@ -153,7 +153,18 @@
prefill_device: "cuda:0" prefill_device: "cuda:0"
- match: - match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)" name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace: replace:
class: "default" class: "default"
kwargs: kwargs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment