Unverified Commit 8fdcd98e authored by PGFLMG's avatar PGFLMG Committed by GitHub
Browse files

[7/n] decouple quantization impl from vllm dependency - gguf kernel (#11019)

parent b5dcfd41
......@@ -271,6 +271,8 @@ set(SOURCES
"csrc/elementwise/topk.cu"
"csrc/common_extension.cc"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
......@@ -306,6 +308,7 @@ set(SOURCES
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_sum.cu"
"csrc/moe/moe_sum_reduce.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/moe/nvfp4_blockwise_moe.cu"
......
......@@ -114,6 +114,37 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"cu_seqlens_q) -> ()");
m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
/*
* From gguf quantiztion
*/
m.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
m.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
m.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
m.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
m.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
/*
* From csrc/gemm
*/
......@@ -226,17 +257,23 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
m.impl("moe_sum_reduce", torch::kCUDA, &moe_sum_reduce);
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
m.def(
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
"stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
"expert_offsets, Tensor workspace) -> ()");
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
m.def(
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
" Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <ATen/cuda/Atomic.cuh>
#include <cub/cub.cuh>
#include "utils.h"
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., topk, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
scalar_t x = 0.0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += SGLANG_LDG(&input[token_idx * TOPK * d + k * d + idx]);
}
out[token_idx * d + idx] = x;
}
}
void moe_sum(
torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
const int hidden_size = input.size(-1);
const auto num_tokens = output.numel() / hidden_size;
const int topk = input.size(1);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (topk) {
case 2:
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
moe_sum_kernel<scalar_t, 2>
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
});
break;
case 3:
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
moe_sum_kernel<scalar_t, 3>
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
});
break;
case 4:
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
moe_sum_kernel<scalar_t, 4>
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
});
break;
default:
at::sum_out(output, input, 1);
break;
}
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/dequantize.cuh
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/convert.cu
// Dequant functions
static __device__ __forceinline__ void dequantize_q4_0(const void* vx, const int ib, const int iqs, dfloat2& v) {
const block_q4_0* x = (const block_q4_0*)vx;
const dfloat d = x[ib].d;
const int vui = x[ib].qs[iqs];
v.x = __int2half_rn(vui & 0xF);
v.y = __int2half_rn(vui >> 4);
v = __hsub2(v, __floats2half2_rn(8.0f, 8.0f));
v = __hmul2(v, {d, d});
}
static __device__ __forceinline__ void dequantize_q4_1(const void* vx, const int ib, const int iqs, dfloat2& v) {
const block_q4_1* x = (const block_q4_1*)vx;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
const int vui = x[ib].qs[iqs];
v.x = __int2half_rn(vui & 0xF);
v.y = __int2half_rn(vui >> 4);
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
}
static __device__ __forceinline__ void dequantize_q5_0(const void* vx, const int ib, const int iqs, dfloat2& v) {
const block_q5_0* x = (const block_q5_0*)vx;
const dfloat d = x[ib].d;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const int xh_1 = ((qh >> (iqs + 12))) & 0x10;
v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
v = __hsub2(v, __floats2half2_rn(16.0f, 16.0f));
v = __hmul2(v, {d, d});
}
static __device__ __forceinline__ void dequantize_q5_1(const void* vx, const int ib, const int iqs, dfloat2& v) {
const block_q5_1* x = (const block_q5_1*)vx;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const int xh_1 = ((qh >> (iqs + 12))) & 0x10;
v.x = __int2half_rn((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = __int2half_rn((x[ib].qs[iqs] >> 4) | xh_1);
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
}
static __device__ __forceinline__ void dequantize_q8_0(const void* vx, const int ib, const int iqs, dfloat2& v) {
const block_q8_0* x = (const block_q8_0*)vx;
const dfloat d = x[ib].d;
v.x = __int2half_rn(x[ib].qs[iqs + 0]);
v.y = __int2half_rn(x[ib].qs[iqs + 1]);
v = __hmul2(v, {d, d});
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void* __restrict__ vx, dst_t* __restrict__ y, const int k) {
const int i = 2 * (blockDim.x * blockIdx.x + threadIdx.x);
if (i >= k) {
return;
}
const int ib = i / qk; // block index
const int iqs = (i % qk) / qr; // quant index
const int iybs = i - i % qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk / 2;
// dequantize
dfloat2 v;
dequantize_kernel(vx, ib, iqs, v);
y[iybs + iqs + 0] = convert_from_half<dst_t>(v.x);
y[iybs + iqs + y_offset] = convert_from_half<dst_t>(v.y);
}
template <typename dst_t>
static __global__ void dequantize_block_q2_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_q2_K* x = (const block_q2_K*)vx;
const auto tid = threadIdx.x;
const int n = tid / 32;
const int l = tid - 32 * n;
const int is = 8 * n + l / 16;
const uint8_t q = x[i].qs[32 * n + l];
dst_t* y = yy + i * QK_K + 128 * n;
half dall = __low2half(x[i].dm);
half dmin = __high2half(x[i].dm);
y[l + 0] = convert_from_half<dst_t>(__hsub(
__hmul(dall, __int2half_rn((x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3))),
__hmul(dmin, __int2half_rn(x[i].scales[is + 0] >> 4))));
y[l + 32] = convert_from_half<dst_t>(__hsub(
__hmul(dall, __int2half_rn((x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3))),
__hmul(dmin, __int2half_rn(x[i].scales[is + 2] >> 4))));
y[l + 64] = convert_from_half<dst_t>(__hsub(
__hmul(dall, __int2half_rn((x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3))),
__hmul(dmin, __int2half_rn(x[i].scales[is + 4] >> 4))));
y[l + 96] = convert_from_half<dst_t>(__hsub(
__hmul(dall, __int2half_rn((x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3))),
__hmul(dmin, __int2half_rn(x[i].scales[is + 6] >> 4))));
}
template <typename dst_t>
static __global__ void dequantize_block_q3_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_q3_K* x = (const block_q3_K*)vx;
const auto r = threadIdx.x / 4;
const int tid = r / 2;
const int is0 = r % 2;
const int l0 = 16 * is0 + 4 * (threadIdx.x % 4);
const int n = tid / 4;
const int j = tid - 4 * n;
uint8_t m = 1 << (4 * n + j);
int is = 8 * n + 2 * j + is0;
int shift = 2 * j;
int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4)
: is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4)
: is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4)
: (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
half d_all = x[i].d;
half dl = __hmul(d_all, __int2half_rn(us - 32));
dst_t* y = yy + i * QK_K + 128 * n + 32 * j;
const uint8_t* q = x[i].qs + 32 * n;
const uint8_t* hm = x[i].hmask;
for (int l = l0; l < l0 + 4; ++l) {
y[l] = convert_from_half<dst_t>(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
}
}
static inline __device__ void get_scale_min_k4(int j, const uint8_t* q, uint8_t& d, uint8_t& m) {
if (j < 4) {
d = q[j] & 63;
m = q[j + 4] & 63;
} else {
d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
}
}
template <typename dst_t>
static __global__ void dequantize_block_q4_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const block_q4_K* x = (const block_q4_K*)vx;
const auto i = blockIdx.x;
// assume 32 threads
const auto tid = threadIdx.x;
const int il = tid / 8;
const int ir = tid % 8;
const int is = 2 * il;
const int n = 4;
dst_t* y = yy + i * QK_K + 64 * il + n * ir;
const half dall = __low2half(x[i].dm);
const half dmin = __high2half(x[i].dm);
const uint8_t* q = x[i].qs + 32 * il + n * ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const half d1 = __hmul(dall, __int2half_rn(sc));
const half m1 = __hmul(dmin, __int2half_rn(m));
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const half d2 = __hmul(dall, __int2half_rn(sc));
const half m2 = __hmul(dmin, __int2half_rn(m));
for (int l = 0; l < n; ++l) {
y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
y[l + 32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
}
}
template <typename dst_t>
static __global__ void dequantize_block_q5_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const block_q5_K* x = (const block_q5_K*)vx;
const auto i = blockIdx.x;
// assume 64 threads - this is very slightly better than the one below
const auto tid = threadIdx.x;
const int il = tid / 16; // il is in 0...3
const int ir = tid % 16; // ir is in 0...15
const int is = 2 * il; // is is in 0...6
dst_t* y = yy + i * QK_K + 64 * il + 2 * ir;
const half dall = __low2half(x[i].dm);
const half dmin = __high2half(x[i].dm);
const uint8_t* ql = x[i].qs + 32 * il + 2 * ir;
const uint8_t* qh = x[i].qh + 2 * ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const half d1 = __hmul(dall, __int2half_rn(sc));
const half m1 = __hmul(dmin, __int2half_rn(m));
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const half d2 = __hmul(dall, __int2half_rn(sc));
const half m2 = __hmul(dmin, __int2half_rn(m));
uint8_t hm = 1 << (2 * il);
y[0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
y[1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
hm <<= 1;
y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
}
template <typename dst_t>
static __global__ void dequantize_block_q6_K(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const block_q6_K* x = (const block_q6_K*)vx;
const auto i = blockIdx.x;
// assume 64 threads - this is very slightly better than the one below
const auto tid = threadIdx.x;
const int ip = tid / 32; // ip is 0 or 1
const int il = tid - 32 * ip; // 0...32
const int is = 8 * ip + il / 16;
dst_t* y = yy + i * QK_K + 128 * ip + il;
const half d = x[i].d;
const uint8_t* ql = x[i].ql + 64 * ip + il;
const uint8_t qh = x[i].qh[32 * ip + il];
const int8_t* sc = x[i].scales + is;
y[0] = convert_from_half<dst_t>(
__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
y[32] = convert_from_half<dst_t>(
__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
y[64] = convert_from_half<dst_t>(
__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))));
y[96] = convert_from_half<dst_t>(
__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))));
}
template <typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_iq2_xxs* x = (const block_iq2_xxs*)vx;
const auto tid = threadIdx.x;
const int il = tid / 8; // 0...3
const int ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
const uint16_t* q2 = x[i].qs + 4 * ib;
const uint8_t* aux8 = (const uint8_t*)q2;
const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[il]);
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7 * il) & 127];
for (int j = 0; j < 8; ++j)
y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
template <typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_iq2_xs* x = (const block_iq2_xs*)vx;
const auto tid = threadIdx.x;
const int il = tid / 8; // 0...3
const int ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
const uint16_t* q2 = x[i].qs + 4 * ib;
const uint8_t* grid = (const uint8_t*)(iq2xs_grid + (q2[il] & 511));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4 * (il / 2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
for (int j = 0; j < 8; ++j)
y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
template <typename dst_t>
static __global__ void dequantize_block_iq2_s(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_iq2_s* x = (const block_iq2_s*)vx;
const auto tid = threadIdx.x;
const int il = tid / 8; // 0...3
const int ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
const uint8_t* grid = (const uint8_t*)(iq2s_grid + (x[i].qs[4 * ib + il] | ((x[i].qh[ib] << (8 - 2 * il)) & 0x300)));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4 * (il / 2)) & 0xf)) * 0.25f;
const uint8_t signs = x[i].qs[QK_K / 8 + 4 * ib + il];
for (int j = 0; j < 8; ++j)
y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
template <typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_iq3_xxs* x = (const block_iq3_xxs*)vx;
const auto tid = threadIdx.x;
const int il = tid / 8; // 0...3
const int ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
const uint8_t* q3 = x[i].qs + 8 * ib;
const uint16_t* gas = (const uint16_t*)(x[i].qs + QK_K / 4) + 2 * ib;
const uint8_t* grid1 = (const uint8_t*)(iq3xxs_grid + q3[2 * il + 0]);
const uint8_t* grid2 = (const uint8_t*)(iq3xxs_grid + q3[2 * il + 1]);
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7 * il) & 127];
for (int j = 0; j < 4; ++j) {
y[j + 0] = d * grid1[j] * (signs & kmask_iq2xs[j + 0] ? -1.f : 1.f);
y[j + 4] = d * grid2[j] * (signs & kmask_iq2xs[j + 4] ? -1.f : 1.f);
}
}
template <typename dst_t>
static __global__ void dequantize_block_iq3_s(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_iq3_s* x = (const block_iq3_s*)vx;
const auto tid = threadIdx.x;
const int il = tid / 8; // 0...3
const int ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
const uint8_t* qs = x[i].qs + 8 * ib;
const uint8_t* grid1 = (const uint8_t*)(iq3xs_grid + (qs[2 * il + 0] | ((x[i].qh[ib] << (8 - 2 * il)) & 256)));
const uint8_t* grid2 = (const uint8_t*)(iq3xs_grid + (qs[2 * il + 1] | ((x[i].qh[ib] << (7 - 2 * il)) & 256)));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib / 2] >> 4 * (ib % 2)) & 0xf)) * 0.5f;
const uint8_t signs = x[i].signs[4 * ib + il];
for (int j = 0; j < 4; ++j) {
y[j + 0] = d * grid1[j] * (signs & kmask_iq2xs[j + 0] ? -1.f : 1.f);
y[j + 4] = d * grid2[j] * (signs & kmask_iq2xs[j + 4] ? -1.f : 1.f);
}
}
template <typename dst_t>
static __global__ void dequantize_block_iq1_s(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq1_s* x = (const block_iq1_s*)vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid / 8; // 0...3
const int64_t ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = __half2float(x[i].d) * (2 * ((x[i].qh[ib] >> 12) & 7) + 1);
uint32_t grid32[2];
const int8_t* q = (const int8_t*)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4 * ib + il] | (((x[i].qh[ib] >> 3 * il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = d * (q[j] + delta);
}
}
template <typename dst_t>
static __global__ void dequantize_block_iq1_m(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const int64_t i = blockIdx.x;
const block_iq1_m* x = (const block_iq1_m*)vx;
const int64_t tid = threadIdx.x;
const int64_t il = tid / 8; // 0...3
const int64_t ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 8 * il;
const uint16_t* sc = (const uint16_t*)x[i].scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const int64_t ib16 = 2 * ib + il / 2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
const float d = __half2float(scale.f16) * (2 * ((sc[ib16 / 4] >> 3 * (ib16 % 4)) & 0x7) + 1);
const float delta = x[i].qh[2 * ib + il / 2] & (0x08 << 4 * (il % 2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
uint32_t grid32[2];
const int8_t* q = (const int8_t*)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4 * ib + il] | (((x[i].qh[2 * ib + il / 2] >> 4 * (il % 2)) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = d * (q[j] + delta);
}
}
template <typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_iq4_nl* x = (const block_iq4_nl*)vx + i * (QK_K / QK4_NL);
const auto tid = threadIdx.x;
const int il = tid / 8; // 0...3
const int ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 4 * il;
const uint8_t* q4 = x[ib].qs + 4 * il;
const float d = __half2float(x[ib].d);
for (int j = 0; j < 4; ++j) {
y[j + 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j + 16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}
template <typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void* __restrict__ vx, dst_t* __restrict__ yy) {
const auto i = blockIdx.x;
const block_iq4_xs* x = (const block_iq4_xs*)vx;
const auto tid = threadIdx.x;
const int il = tid / 8; // 0...3
const int ib = tid % 8; // 0...7
dst_t* y = yy + i * QK_K + 32 * ib + 4 * il;
const uint8_t* q4 = x[i].qs + 16 * ib + 4 * il;
const float d = __half2float(x[i].d) *
((((x[i].scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((x[i].scales_h >> 2 * ib) & 3) << 4)) - 32);
for (int j = 0; j < 4; ++j) {
y[j + 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j + 16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void
dequantize_block_cuda(const void* __restrict__ vx, dst_t* __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
template <typename dst_t>
static void dequantize_row_q2_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_q3_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_q4_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_q5_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_q6_K_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq2_s_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq3_s_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq1_s_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq1_m_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void* vx, dst_t* y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static to_cuda_ggml_t<dst_t> ggml_get_to_cuda(int64_t type) {
switch (type) {
case 2:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case 3:
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
case 6:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case 7:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case 8:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case 10:
return dequantize_row_q2_K_cuda;
case 11:
return dequantize_row_q3_K_cuda;
case 12:
return dequantize_row_q4_K_cuda;
case 13:
return dequantize_row_q5_K_cuda;
case 14:
return dequantize_row_q6_K_cuda;
case 16:
return dequantize_row_iq2_xxs_cuda;
case 17:
return dequantize_row_iq2_xs_cuda;
case 18:
return dequantize_row_iq3_xxs_cuda;
case 19:
return dequantize_row_iq1_s_cuda;
case 20:
return dequantize_row_iq4_nl_cuda;
case 21:
return dequantize_row_iq3_s_cuda;
case 22:
return dequantize_row_iq2_s_cuda;
case 23:
return dequantize_row_iq4_xs_cuda;
case 29:
return dequantize_row_iq1_m_cuda;
default:
return nullptr;
}
}
// adapted from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/ggml-common.h
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE_GGUF 32
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define GGML_CUDA_DMMV_X 32
#define GGML_CUDA_MMV_Y 1
// Data Structures
// QK = number of values after dequantization
// QR = QK / number of values before dequantization
// QI = number of 32 bit integers before dequantization
#define QK4_0 32
#define QR4_0 2
#define QI4_0 (QK4_0 / (4 * QR4_0))
typedef struct {
half d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
#define QK4_1 32
#define QR4_1 2
#define QI4_1 (QK4_1 / (4 * QR4_1))
typedef struct {
half2 dm; // dm.x = delta, dm.y = min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
#define QK5_0 32
#define QR5_0 2
#define QI5_0 (QK5_0 / (4 * QR5_0))
typedef struct {
half d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
#define QK5_1 32
#define QR5_1 2
#define QI5_1 (QK5_1 / (4 * QR5_1))
typedef struct {
half2 dm; // dm.x = delta, dm.y = min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
#define QK8_0 32
#define QR8_0 1
#define QI8_0 (QK8_0 / (4 * QR8_0))
typedef struct {
half d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
#define QK8_1 32
#define QR8_1 1
#define QI8_1 (QK8_1 / (4 * QR8_1))
typedef struct {
half2 ds; // ds.x = delta, ds.y = sum
int8_t qs[QK8_0]; // quants
} block_q8_1;
#define QR2_K 4
#define QI2_K (QK_K / (4 * QR2_K))
typedef struct {
uint8_t scales[QK_K / 16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K / 4]; // quants
half2 dm; // super-block scale for quantized scales/mins
} block_q2_K;
#define QR3_K 4
#define QI3_K (QK_K / (4 * QR3_K))
typedef struct {
uint8_t hmask[QK_K / 8]; // quants - high bit
uint8_t qs[QK_K / 4]; // quants - low 2 bits
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
half d; // super-block scale
} block_q3_K;
#define QR4_K 2
#define QI4_K (QK_K / (4 * QR4_K))
typedef struct {
half2 dm; // super-block scale for quantized scales/mins
uint8_t scales[3 * QK_K / 64]; // scales, quantized with 6 bits
uint8_t qs[QK_K / 2]; // 4--bit quants
} block_q4_K;
#define QR5_K 2
#define QI5_K (QK_K / (4 * QR5_K))
typedef struct {
half2 dm; // super-block scale for quantized scales/mins
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K / 8]; // quants, high bit
uint8_t qs[QK_K / 2]; // quants, low 4 bits
} block_q5_K;
#define QR6_K 2
#define QI6_K (QK_K / (4 * QR6_K))
typedef struct {
uint8_t ql[QK_K / 2]; // quants, lower 4 bits
uint8_t qh[QK_K / 4]; // quants, upper 2 bits
int8_t scales[QK_K / 16]; // scales
half d; // delta
} block_q6_K;
#define QR2_XXS 8
#define QI2_XXS (QK_K / (4 * QR2_XXS))
typedef struct {
half d;
uint16_t qs[QK_K / 8];
} block_iq2_xxs;
#define QR2_XS 8
#define QI2_XS (QK_K / (4 * QR2_XS))
typedef struct {
half d;
uint16_t qs[QK_K / 8];
uint8_t scales[QK_K / 32];
} block_iq2_xs;
#define QR2_S 8
#define QI2_S (QK_K / (4 * QR2_S))
typedef struct {
half d;
uint8_t qs[QK_K / 4];
uint8_t qh[QK_K / 32];
uint8_t scales[QK_K / 32];
} block_iq2_s;
#define QR3_XXS 8
#define QI3_XXS (QK_K / (4 * QR3_XXS))
typedef struct {
half d;
uint8_t qs[3 * (QK_K / 8)];
} block_iq3_xxs;
#define QR3_XS 8
#define QI3_XS (QK_K / (4 * QR3_XS))
#define IQ3S_N_SCALE QK_K / 64
typedef struct {
half d;
uint8_t qs[QK_K / 4];
uint8_t qh[QK_K / 32];
uint8_t signs[QK_K / 8];
uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;
// 1.5625 bpw
#define QR1_S 8
#define QI1_S (QK_K / (4 * QR1_S))
typedef struct {
half d;
uint8_t qs[QK_K / 8];
uint16_t qh[QK_K / 32];
} block_iq1_s;
// 1.75 bpw
#define QR1_M 8
#define QI1_M (QK_K / (4 * QR1_M))
typedef struct {
uint8_t qs[QK_K / 8]; // grid index, low 8 bits
uint8_t qh[QK_K / 16]; // grid index, high 3 bits + grid shift bit (for two groups of 8)
uint8_t scales[QK_K / 32]; // 3-bit block scales (4-bit if QK_K == 64)
} block_iq1_m;
// Used by IQ1_M quants
typedef union {
half f16;
uint16_t u16;
} iq1m_scale_t;
#define QK4_NL 32
#define QR4_NL 2
#define QI4_NL (QK4_NL / (4 * QR4_NL))
typedef struct {
half d;
uint8_t qs[QK4_NL / 2];
} block_iq4_nl;
#define QR4_XS 8
#define QI4_XS (QK_K / (4 * QR4_XS))
typedef struct {
half d;
uint16_t scales_h;
uint8_t scales_l[QK_K / 64];
uint8_t qs[QK_K / 2];
} block_iq4_xs;
static const __device__ uint64_t iq2xxs_grid[256] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, 0x0808080808082b2b,
0x0808080808190819, 0x0808080808191908, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b2b08,
0x08080808082b2b2b, 0x0808080819080819, 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08,
0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, 0x0808081908191919,
0x0808081919080808, 0x080808192b081908, 0x080808192b192b08, 0x0808082b08080808, 0x0808082b0808082b,
0x0808082b082b082b, 0x0808082b2b08082b, 0x0808190808080819, 0x0808190808081908, 0x0808190808190808,
0x08081908082b0819, 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x080819082b2b1908,
0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, 0x08081919082b0808, 0x080819191908192b,
0x08081919192b2b19, 0x080819192b080808, 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808,
0x0808192b19080808, 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819, 0x08082b0819081908,
0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08, 0x08082b1908081908, 0x08082b1919080808,
0x08082b2b0808082b, 0x08082b2b08191908, 0x0819080808080819, 0x0819080808081908, 0x0819080808190808,
0x08190808082b0819, 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, 0x0819081919190808,
0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908, 0x0819082b19081919, 0x0819190808080808,
0x0819190808082b08, 0x08191908082b0808, 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808,
0x0819191908192b08, 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819, 0x08192b1908080808,
0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, 0x082b080808080808, 0x082b08080808082b,
0x082b080808082b2b, 0x082b080819081908, 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b,
0x082b0819082b2b19, 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b, 0x082b191908080808,
0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808, 0x082b2b0808082b08, 0x082b2b08082b0808,
0x082b2b082b191908, 0x082b2b2b19081908, 0x1908080808080819, 0x1908080808081908, 0x1908080808190808,
0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819, 0x190808192b080808, 0x190808192b081919,
0x1908082b08080819, 0x1908082b08190808, 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08,
0x1908190808080808, 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819, 0x19082b0808081908,
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b1908080808, 0x19082b1919192b08,
0x19082b19192b0819, 0x19082b192b08082b, 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808,
0x1919080808082b08, 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908, 0x1919082b2b190819,
0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b, 0x1919192b08080819, 0x1919192b19191908,
0x19192b0808080808, 0x19192b0808190819, 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808,
0x19192b2b08082b08, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, 0x192b190808080808,
0x192b190808081919, 0x192b191908190808, 0x192b19190819082b, 0x192b19192b081908, 0x192b2b081908082b,
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b,
0x2b08081908081908, 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808,
0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, 0x2b08192b08082b19, 0x2b08192b19080808,
0x2b08192b192b0808, 0x2b082b080808082b, 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908,
0x2b19080808190808, 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b, 0x2b19190819081908,
0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808, 0x2b2b08080808082b, 0x2b2b080819190808,
0x2b2b08082b081919, 0x2b2b081908082b19, 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808,
0x2b2b2b1908081908,
};
static const __device__ uint64_t iq2xs_grid[512] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, 0x0808080808082b2b,
0x0808080808190819, 0x0808080808191908, 0x080808080819192b, 0x0808080808192b19, 0x08080808082b0808,
0x08080808082b082b, 0x08080808082b1919, 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908,
0x080808081908192b, 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b,
0x080808082b081919, 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b192b19,
0x080808082b2b0808, 0x0808081908080819, 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19,
0x0808081908190808, 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b, 0x0808081919081919,
0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, 0x08080819192b0808, 0x08080819192b2b08,
0x080808192b080819, 0x080808192b081908, 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b,
0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919, 0x0808082b2b080808,
0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, 0x0808190808082b19,
0x0808190808190808, 0x080819080819082b, 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819,
0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, 0x080819082b080819,
0x080819082b081908, 0x080819082b190808, 0x0808191908080808, 0x080819190808082b, 0x0808191908081919,
0x0808191908082b08, 0x0808191908190819, 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819,
0x0808191919081908, 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808, 0x0808192b1908082b,
0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, 0x08082b0808081919, 0x08082b0808082b08,
0x08082b0808082b2b, 0x08082b0808190819, 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919,
0x08082b0819080819, 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908, 0x08082b1908190808,
0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19, 0x08082b2b08080808, 0x08082b2b082b0808,
0x08082b2b082b2b08, 0x08082b2b2b19192b, 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908,
0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, 0x081908081908082b,
0x0819080819081919, 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x08190808192b0808,
0x08190808192b2b2b, 0x081908082b080819, 0x081908082b081908, 0x081908082b190808, 0x0819081908080808,
0x081908190808082b, 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808, 0x081908192b080808,
0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, 0x0819082b08081908, 0x0819082b0808192b,
0x0819082b08190808, 0x0819082b19080808, 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b,
0x0819190808081919, 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808, 0x08191908192b1908,
0x081919082b080808, 0x0819191908080819, 0x0819191908081908, 0x0819191908190808, 0x0819191919080808,
0x0819192b08080808, 0x0819192b08191908, 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908,
0x08192b0808190808, 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, 0x08192b2b2b2b2b19,
0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08, 0x082b080808082b2b,
0x082b080808190819, 0x082b080808191908, 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908,
0x082b080819190808, 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919, 0x082b082b08080808,
0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
0x082b190808190808, 0x082b1908082b2b19, 0x082b190819080808, 0x082b191908080808, 0x082b191919080819,
0x082b19191919082b, 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808, 0x082b2b0819191919,
0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08, 0x082b2b192b190808, 0x082b2b2b08082b08,
0x082b2b2b082b0808, 0x082b2b2b2b08082b, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808,
0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, 0x1908080819190819,
0x1908080819191908, 0x19080808192b0808, 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908,
0x190808082b190808, 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819, 0x1908081919081908,
0x1908081919190808, 0x190808192b080808, 0x190808192b081919, 0x190808192b2b082b, 0x1908082b08080819,
0x1908082b08081908, 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808,
0x1908190808080808, 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819, 0x1908190819081908,
0x1908190819190808, 0x190819082b080808, 0x190819082b191908, 0x1908191908080819, 0x1908191908081908,
0x1908191908190808, 0x19081919082b1908, 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808,
0x1908192b08082b2b, 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, 0x19082b08192b082b,
0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908, 0x19082b1919190808, 0x19082b19192b2b19,
0x19082b2b08081908, 0x1919080808080808, 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08,
0x1919080808190819, 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819, 0x1919081908081908,
0x1919081908190808, 0x1919081908191919, 0x1919081919080808, 0x191908191908082b, 0x1919082b08080808,
0x1919082b19081908, 0x1919082b2b2b2b2b, 0x1919190808080819, 0x1919190808081908, 0x1919190808190808,
0x19191908082b0819, 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08, 0x1919192b082b0819,
0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808, 0x19192b0808191908, 0x19192b0819080819,
0x19192b0819190808, 0x19192b082b192b19, 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b,
0x19192b2b2b081919, 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, 0x192b081908080808,
0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b, 0x192b190808080808, 0x192b19080819192b,
0x192b191908190808, 0x192b191919080808, 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819,
0x192b2b08192b2b2b, 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, 0x2b08080819080819, 0x2b08080819081908,
0x2b08080819190808, 0x2b0808082b080808, 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b,
0x2b08081908080819, 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808, 0x2b08082b2b080808,
0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, 0x2b08190808080819, 0x2b08190808081908,
0x2b08190808190808, 0x2b0819080819082b, 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808,
0x2b0819082b082b19, 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, 0x2b082b0819192b2b,
0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08, 0x2b082b190808192b, 0x2b082b2b082b082b,
0x2b082b2b2b080808, 0x2b082b2b2b082b08, 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819,
0x2b19080808081908, 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b, 0x2b19082b2b082b19,
0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, 0x2b19190819190808, 0x2b19190819192b08,
0x2b191919082b2b19, 0x2b1919192b190808, 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819,
0x2b192b082b2b192b, 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808, 0x2b2b0808082b2b2b,
0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19, 0x2b2b08192b2b192b, 0x2b2b082b08080808,
0x2b2b082b0808082b, 0x2b2b082b08082b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808,
0x2b2b190819080808, 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, 0x2b2b2b082b2b2b08,
0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b,
0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
static const __device__ uint64_t iq2s_grid[1024] = {
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08, 0x0808080808082b2b,
0x0808080808190819, 0x0808080808191908, 0x080808080819192b, 0x0808080808192b19, 0x08080808082b0808,
0x08080808082b082b, 0x08080808082b1919, 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908,
0x080808081908192b, 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b, 0x08080808192b2b19,
0x080808082b080808, 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
0x080808082b191908, 0x080808082b2b0808, 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819,
0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808,
0x080808191908082b, 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
0x080808191919192b, 0x0808081919192b19, 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08,
0x080808192b080819, 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, 0x0808082b082b2b2b,
0x0808082b19080819, 0x0808082b19081908, 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808,
0x0808082b19191919, 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, 0x0808190808082b19,
0x0808190808190808, 0x080819080819082b, 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819,
0x08081908082b1908, 0x08081908082b192b, 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b,
0x0808190819081919, 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b, 0x08081908192b1919,
0x080819082b080819, 0x080819082b081908, 0x080819082b08192b, 0x080819082b082b19, 0x080819082b190808,
0x080819082b191919, 0x080819082b192b08, 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808,
0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808, 0x08081919082b1919,
0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908, 0x080819191908192b, 0x0808191919082b19,
0x0808191919190808, 0x080819191919082b, 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819,
0x08081919192b1908, 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819, 0x0808192b08081908,
0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b08191919, 0x0808192b19080808,
0x0808192b19081919, 0x0808192b19082b08, 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808,
0x0808192b2b080819, 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908, 0x08082b080819192b,
0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b08082b2b2b, 0x08082b0819080819,
0x08082b0819081908, 0x08082b081908192b, 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b,
0x08082b0819191919, 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919, 0x08082b1908192b08, 0x08082b19082b0819,
0x08082b1919080808, 0x08082b1919081919, 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908,
0x08082b19192b0808, 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b, 0x08082b2b19190808,
0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b, 0x0819080808082b19,
0x0819080808190808, 0x081908080819082b, 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819,
0x08190808082b1908, 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b, 0x0819080819192b19,
0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919, 0x08190808192b2b08, 0x081908082b080819,
0x081908082b081908, 0x081908082b08192b, 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08,
0x081908082b2b0819, 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908, 0x081908190819192b,
0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b, 0x08190819082b1919, 0x08190819082b2b08,
0x0819081919080819, 0x0819081919081908, 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808,
0x081908191919082b, 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08, 0x081908192b190819,
0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908, 0x0819082b08082b19, 0x0819082b08190808,
0x0819082b08191919, 0x0819082b082b0819, 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919,
0x0819082b19190819, 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, 0x0819190808190819,
0x0819190808191908, 0x081919080819192b, 0x0819190808192b19, 0x08191908082b0808, 0x08191908082b1919,
0x08191908082b2b08, 0x0819190819080819, 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19,
0x0819190819190808, 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919, 0x081919082b082b08,
0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808, 0x0819191908080819, 0x0819191908081908,
0x081919190808192b, 0x0819191908082b19, 0x0819191908190808, 0x081919190819082b, 0x0819191908191919,
0x0819191908192b08, 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908, 0x08191919192b0808,
0x081919192b080819, 0x081919192b081908, 0x081919192b190808, 0x0819192b08080808, 0x0819192b08081919,
0x0819192b08082b08, 0x0819192b08190819, 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819,
0x0819192b19081908, 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808, 0x08192b0808191919,
0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808, 0x08192b081908082b, 0x08192b0819081919,
0x08192b0819082b08, 0x08192b0819190819, 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819,
0x08192b082b081908, 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819, 0x08192b1919081908,
0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b, 0x08192b2b08081908, 0x08192b2b08190808,
0x08192b2b19080808, 0x08192b2b1919192b, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
0x082b080808082b08, 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819, 0x082b080819081908,
0x082b080819190808, 0x082b08081919082b, 0x082b080819191919, 0x082b0808192b1908, 0x082b08082b080808,
0x082b08082b082b2b, 0x082b08082b191908, 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908,
0x082b081908190808, 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908, 0x082b0819192b0808,
0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808, 0x082b082b08080808, 0x082b082b08082b2b,
0x082b082b082b082b, 0x082b082b082b2b08, 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808,
0x082b082b2b082b08, 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919, 0x082b190808192b08,
0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808, 0x082b19081908082b, 0x082b190819081919,
0x082b190819082b08, 0x082b190819190819, 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819,
0x082b19082b081908, 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819, 0x082b191919081908,
0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808, 0x082b192b08080819, 0x082b192b08081908,
0x082b192b08190808, 0x082b192b19080808, 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919,
0x082b2b0808190819, 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908, 0x082b2b1908190808,
0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b, 0x082b2b2b192b1908, 0x082b2b2b2b082b08,
0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19,
0x1908080808190808, 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808, 0x190808081908082b,
0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908,
0x190808081919192b, 0x1908080819192b19, 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919,
0x190808082b080819, 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b, 0x1908081908081919,
0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, 0x190808190819192b, 0x1908081908192b19,
0x19080819082b0808, 0x19080819082b082b, 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908,
0x190808191908192b, 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808, 0x190808192b08082b,
0x190808192b081919, 0x190808192b082b08, 0x190808192b190819, 0x190808192b191908, 0x190808192b2b0808,
0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919,
0x1908082b08192b08, 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819, 0x1908082b2b081908,
0x1908190808080808, 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808082b2b,
0x1908190808190819, 0x1908190808191908, 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808,
0x19081908082b082b, 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b, 0x1908190819191919,
0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908, 0x190819082b080808, 0x190819082b08082b,
0x190819082b081919, 0x190819082b082b08, 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808,
0x1908191908080819, 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819, 0x19081919082b1908,
0x1908191919080808, 0x190819191908082b, 0x1908191919081919, 0x1908191919082b08, 0x1908191919190819,
0x1908191919191908, 0x19081919192b0808, 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908,
0x190819192b190808, 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819, 0x1908192b19081908,
0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919, 0x19082b0808080819, 0x19082b0808081908,
0x19082b0808082b19, 0x19082b0808190808, 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08,
0x19082b08082b0819, 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808, 0x19082b082b081908,
0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b, 0x19082b1908081919, 0x19082b1908082b08,
0x19082b1908190819, 0x19082b1908191908, 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908,
0x19082b1919190808, 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b, 0x1919080808081919,
0x1919080808082b08, 0x1919080808190819, 0x1919080808191908, 0x191908080819192b, 0x1919080808192b19,
0x19190808082b0808, 0x19190808082b082b, 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819,
0x1919080819081908, 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908, 0x191908082b080808,
0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08, 0x191908082b190819, 0x191908082b191908,
0x1919081908080819, 0x1919081908081908, 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808,
0x191908190819082b, 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08, 0x1919081919190819,
0x1919081919191908, 0x19190819192b0808, 0x191908192b080819, 0x191908192b081908, 0x191908192b190808,
0x1919082b08080808, 0x1919082b08081919, 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908,
0x1919082b082b0808, 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b, 0x1919190808082b19,
0x1919190808190808, 0x191919080819082b, 0x1919190808191919, 0x1919190808192b08, 0x19191908082b0819,
0x19191908082b1908, 0x1919190819080808, 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08,
0x1919190819190819, 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919, 0x1919191908082b08,
0x1919191908190819, 0x1919191908191908, 0x19191919082b0808, 0x1919191919080819, 0x1919191919081908,
0x1919191919190808, 0x191919192b080808, 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808,
0x1919192b082b192b, 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808, 0x19192b0819080819,
0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b, 0x19192b082b080808, 0x19192b1908080819,
0x19192b1908081908, 0x19192b1908190808, 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19,
0x19192b2b2b081919, 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08, 0x192b0808082b0819,
0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919, 0x192b080819082b08, 0x192b080819190819,
0x192b080819191908, 0x192b0808192b0808, 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808,
0x192b08190808082b, 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808, 0x192b08192b080808,
0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808, 0x192b082b19080808, 0x192b082b1919192b,
0x192b082b2b2b0819, 0x192b190808080808, 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819,
0x192b190808191908, 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808, 0x192b191919080808,
0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b, 0x192b192b08080808, 0x192b192b2b191908,
0x192b2b0808080819, 0x192b2b0808081908, 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08,
0x192b2b1908080808, 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, 0x2b08080808191908,
0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919, 0x2b08080819080819, 0x2b08080819081908,
0x2b08080819190808, 0x2b0808081919082b, 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819,
0x2b0808082b080808, 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b, 0x2b08081908191919,
0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908, 0x2b08081919080808, 0x2b0808191908082b,
0x2b08081919081919, 0x2b08081919082b08, 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819,
0x2b0808192b081908, 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819, 0x2b08082b19081908,
0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908, 0x2b0819080808192b, 0x2b08190808082b19,
0x2b08190808190808, 0x2b0819080819082b, 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819,
0x2b08190819080808, 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908, 0x2b0819082b190808,
0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919, 0x2b08191908082b08, 0x2b08191908190819,
0x2b08191908191908, 0x2b081919082b0808, 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808,
0x2b0819192b080808, 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919, 0x2b082b0808190819,
0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908, 0x2b082b0819190808, 0x2b082b082b2b082b,
0x2b082b1908080819, 0x2b082b1908081908, 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b,
0x2b082b2b19192b08, 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b, 0x2b19080808191919,
0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908081908082b, 0x2b19080819081919,
0x2b19080819082b08, 0x2b19080819190819, 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819,
0x2b1908082b081908, 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808, 0x2b19081919192b2b,
0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808, 0x2b19082b19080808, 0x2b19082b2b2b192b,
0x2b19190808080808, 0x2b1919080808082b, 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819,
0x2b19190808191908, 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908, 0x2b19191908190808,
0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819, 0x2b19192b08080808, 0x2b19192b1908192b,
0x2b19192b192b1908, 0x2b192b0808080819, 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b,
0x2b192b0819080808, 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b, 0x2b2b080808191908,
0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819, 0x2b2b080819081908, 0x2b2b080819190808,
0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b, 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b,
0x2b2b082b08082b2b, 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819, 0x2b2b190808081908,
0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19, 0x2b2b19082b2b1908, 0x2b2b191908080808,
0x2b2b191908192b19, 0x2b2b192b19190819, 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b,
0x2b2b2b1919191908, 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
};
static const __device__ uint32_t iq3xxs_grid[256] = {
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c,
0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, 0x040c140c, 0x040c142c,
0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, 0x04140414, 0x04140424, 0x04140c0c,
0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c,
0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e,
0x04243e1c, 0x04243e2c, 0x042c040c, 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04,
0x043e0c24, 0x043e0c34, 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c,
0x0c04141c, 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, 0x0c143e14,
0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, 0x0c24042c, 0x0c242c04,
0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, 0x0c3e2404, 0x14040404, 0x14040414,
0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c,
0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404,
0x14140414, 0x14140c0c, 0x14140c3e, 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c,
0x141c0c04, 0x141c0c24, 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c,
0x142c3e24, 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, 0x1c0c2424,
0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, 0x1c1c0c0c, 0x1c1c1c1c,
0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, 0x1c2c2c2c, 0x1c340c24, 0x1c341c34,
0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e,
0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424,
0x24242c0c, 0x24243424, 0x242c142c, 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04,
0x2c040c14, 0x2c04240c, 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14,
0x2c143e14, 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, 0x340c340c,
0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, 0x34341c1c, 0x343e041c,
0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, 0x3e042c14, 0x3e0c1434, 0x3e0c2404,
0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c,
0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
static const __device__ uint32_t iq3xs_grid[512] = {
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14, 0x04040c24,
0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414, 0x0404242c, 0x0404243e,
0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24, 0x04043e3e, 0x040c0404, 0x040c040c,
0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c, 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c,
0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c, 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e,
0x04140c04, 0x04140c1c, 0x04140c34, 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c,
0x0414243e, 0x04142c0c, 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404,
0x041c1414, 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404, 0x0424241c,
0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434, 0x042c1c1c, 0x042c240c,
0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c, 0x04340c1c, 0x04341c0c, 0x04342c14,
0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404, 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04,
0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414, 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c,
0x0c040c3e, 0x0c041404, 0x0c041414, 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c,
0x0c043e14, 0x0c0c0404, 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04,
0x0c0c1c1c, 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404, 0x0c143e14,
0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e, 0x0c1c1c04, 0x0c1c1c24,
0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14, 0x0c240c24, 0x0c241c0c, 0x0c241c1c,
0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c, 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04,
0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424, 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c,
0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c, 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c,
0x1404041c, 0x1404042c, 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c,
0x1404143e, 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e, 0x140c1414,
0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424, 0x1414043e, 0x1414140c,
0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e, 0x14143e0c, 0x14143e24, 0x141c0404,
0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424, 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04,
0x141c3434, 0x1424040c, 0x1424043e, 0x14241404, 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14,
0x14243e2c, 0x142c0424, 0x142c0c0c, 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404,
0x14340414, 0x1434043e, 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04,
0x143e241c, 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c, 0x1c0c040c,
0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404, 0x1c0c3e14, 0x1c0c3e34,
0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04, 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24,
0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c, 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c,
0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414, 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c,
0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c, 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c,
0x1c3e040c, 0x1c3e041c, 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404,
0x24041424, 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c, 0x2414041c,
0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414, 0x24143e04, 0x241c0424,
0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c, 0x24240404, 0x24240414, 0x24241424,
0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e, 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24,
0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04, 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24,
0x2c041414, 0x2c042404, 0x2c042424, 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c,
0x2c0c042c, 0x2c0c0c14, 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04,
0x2c141c34, 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434, 0x2c2c2c0c,
0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c, 0x34040c2c, 0x34041c0c,
0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424, 0x34140c14, 0x34141c24, 0x34142414,
0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24, 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c,
0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24, 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c,
0x3e040404, 0x3e040424, 0x3e04043e, 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414,
0x3e0c0414, 0x3e0c0c0c, 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34,
0x3e14140c, 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
};
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
static const __device__ uint64_t iq1s_grid_gpu[2048] = {
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002,
0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, 0x02000000, 0x02000002,
0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, 0x02020202, 0x00000110, 0x00000111,
0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, 0x00020111, 0x01000011, 0x01000112, 0x01000211,
0x01010012, 0x01010111, 0x01010212, 0x01020011, 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011,
0x02010110, 0x02010112, 0x02020111, 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020,
0x00020022, 0x00020220, 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020,
0x02000022, 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, 0x01011202,
0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, 0x00001111, 0x00001112,
0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, 0x01001212, 0x01011010, 0x01011011,
0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, 0x01021012, 0x01021111, 0x01021210, 0x01021212,
0x02001011, 0x02011011, 0x02011111, 0x02011210, 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112,
0x02021211, 0x00011120, 0x00011221, 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220,
0x01021020, 0x01021021, 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000,
0x00002002, 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, 0x02022000,
0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, 0x00022110, 0x00022111,
0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, 0x01022211, 0x02012011, 0x02012110,
0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, 0x00002220, 0x00002222, 0x00012121, 0x00022020,
0x00022022, 0x00022220, 0x00022222, 0x01002121, 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020,
0x02002022, 0x02002121, 0x02002220, 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222,
0x00110000, 0x00110001, 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000,
0x01110101, 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, 0x00110111,
0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, 0x01110011, 0x01110012,
0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, 0x02100110, 0x02110012, 0x02110111,
0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, 0x00120121, 0x01100020, 0x01100122, 0x01100221,
0x01110022, 0x01110121, 0x01110220, 0x01110222, 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120,
0x02110122, 0x02120121, 0x00101001, 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201,
0x00121001, 0x00121102, 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100,
0x01111101, 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, 0x02121201,
0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, 0x00111211, 0x00121010,
0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, 0x01101111, 0x01101112, 0x01111011,
0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, 0x01111212, 0x01121011, 0x01121110, 0x01121111,
0x01121112, 0x01121211, 0x02101010, 0x02101012, 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010,
0x02111011, 0x02111110, 0x02111111, 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111,
0x00101021, 0x00101120, 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021,
0x00121122, 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, 0x01121222,
0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, 0x00112102, 0x00122101,
0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, 0x01112200, 0x01112202, 0x01122000,
0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, 0x02112001, 0x02112100, 0x02122101, 0x00112010,
0x00112012, 0x00112111, 0x00112212, 0x00122011, 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210,
0x01112011, 0x01112110, 0x01112111, 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212,
0x02102211, 0x02112011, 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221,
0x00112122, 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, 0x00200000,
0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, 0x00220200, 0x00220202,
0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, 0x02200002, 0x02200200, 0x02200202,
0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, 0x02220202, 0x00200111, 0x00210011, 0x00210110,
0x00210211, 0x00220111, 0x01200012, 0x01200110, 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011,
0x01220110, 0x01220111, 0x01220112, 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021,
0x00200220, 0x00200222, 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121,
0x01210021, 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, 0x00221101,
0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, 0x01211202, 0x01221102,
0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, 0x00201211, 0x00211111, 0x00221011,
0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, 0x01211110, 0x01211111, 0x01211211, 0x01221012,
0x01221111, 0x01221210, 0x02201211, 0x02211010, 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011,
0x02221110, 0x02221112, 0x02221211, 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021,
0x01201221, 0x01211121, 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222,
0x00202000, 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, 0x02222000,
0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, 0x00222111, 0x01202112,
0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, 0x01222211, 0x02202111, 0x02212010,
0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, 0x00202022, 0x00202220, 0x00202222, 0x00222020,
0x00222022, 0x00222220, 0x00222222, 0x01202121, 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020,
0x02202022, 0x02202220, 0x02202222, 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101,
0x10010001, 0x10010102, 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001,
0x11020100, 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, 0x10020112,
0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, 0x11010112, 0x11010211,
0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, 0x12000112, 0x12010010, 0x12010012,
0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, 0x10010021, 0x10010120, 0x10010122, 0x10020121,
0x11000021, 0x11010022, 0x11010121, 0x11010222, 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121,
0x10001001, 0x10011101, 0x10011201, 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100,
0x11011101, 0x11011102, 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102,
0x12001201, 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, 0x10021111,
0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, 0x11011011, 0x11011110,
0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, 0x11021111, 0x11021112, 0x11021211,
0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, 0x12011110, 0x12011111, 0x12011112, 0x12011211,
0x12011212, 0x12021111, 0x12021210, 0x12021212, 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121,
0x10011220, 0x10011222, 0x10021021, 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220,
0x11011020, 0x11011021, 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220,
0x12001021, 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, 0x11012200,
0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, 0x12012102, 0x12012201,
0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, 0x10012110, 0x10012111, 0x10012210,
0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, 0x11002212, 0x11012011, 0x11012012, 0x11012110,
0x11012111, 0x11012112, 0x11012211, 0x11022010, 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112,
0x12002211, 0x12012012, 0x12012111, 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211,
0x10012122, 0x11002120, 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221,
0x12012120, 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, 0x11110100,
0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, 0x12110101, 0x12110200,
0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, 0x10100211, 0x10100212, 0x10110011,
0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, 0x10120010, 0x10120111, 0x10120112, 0x10120210,
0x10120212, 0x11100011, 0x11100110, 0x11100111, 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012,
0x11110110, 0x11110111, 0x11110112, 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111,
0x11120112, 0x11120211, 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211,
0x12120010, 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, 0x11110221,
0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, 0x12110222, 0x12120120,
0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, 0x10111200, 0x10111201, 0x10121001,
0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, 0x11101101, 0x11101102, 0x11101201, 0x11101202,
0x11111000, 0x11111001, 0x11111100, 0x11111101, 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001,
0x11121002, 0x11121100, 0x11121101, 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001,
0x12111100, 0x12111101, 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011,
0x10101012, 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, 0x10121211,
0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, 0x11101211, 0x11111010,
0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, 0x11111211, 0x11111212, 0x11121010,
0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, 0x11121211, 0x11121212, 0x12101011, 0x12101110,
0x12101111, 0x12101211, 0x12101212, 0x12111010, 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210,
0x12111211, 0x12121011, 0x12121110, 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022,
0x10101120, 0x10101122, 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221,
0x10121020, 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, 0x11111120,
0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, 0x11121121, 0x11121221,
0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, 0x12111021, 0x12111121, 0x12111222,
0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, 0x10102100, 0x10102101, 0x10102102, 0x10102201,
0x10112000, 0x10112101, 0x10112200, 0x10122001, 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001,
0x11112100, 0x11112101, 0x11112102, 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101,
0x12102002, 0x12102201, 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011,
0x10102012, 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, 0x11112110,
0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, 0x11122111, 0x11122112,
0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, 0x12112111, 0x12112112, 0x12112210,
0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, 0x10112222, 0x10122020, 0x10122121, 0x10122122,
0x10122221, 0x11102121, 0x11102220, 0x11102221, 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221,
0x11122022, 0x11122121, 0x11122220, 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122,
0x12112220, 0x12112222, 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100,
0x11210000, 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, 0x10210111,
0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, 0x11210111, 0x11210112,
0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, 0x12210012, 0x12210111, 0x12220011,
0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, 0x11200020, 0x11200021, 0x11200122, 0x11210121,
0x11210122, 0x11210220, 0x11220020, 0x12200121, 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002,
0x10211101, 0x10211102, 0x10211202, 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101,
0x11201200, 0x11201202, 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000,
0x11221002, 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, 0x10201212,
0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, 0x11201211, 0x11211010,
0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, 0x11221110, 0x11221111, 0x11221112,
0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, 0x12211111, 0x12211112, 0x12211211, 0x12211212,
0x12221012, 0x12221111, 0x12221112, 0x12221210, 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122,
0x10221220, 0x10221221, 0x11201020, 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121,
0x11211122, 0x11211220, 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121,
0x12201222, 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, 0x11222201,
0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, 0x10212111, 0x10222011,
0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, 0x11202112, 0x11202210, 0x11212011,
0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, 0x11222111, 0x11222212, 0x12202012, 0x12202110,
0x12202212, 0x12212111, 0x12222011, 0x12222110, 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220,
0x11202021, 0x11202120, 0x11202221, 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121,
0x11222221, 0x12202122, 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200,
0x20000202, 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, 0x22020000,
0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, 0x20010211, 0x20020111,
0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, 0x21010112, 0x21010210, 0x21010211,
0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, 0x22010110, 0x22010112, 0x22010211, 0x22020111,
0x20000020, 0x20000022, 0x20000220, 0x20000222, 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222,
0x21010021, 0x21010120, 0x21010221, 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121,
0x22020020, 0x22020022, 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001,
0x21011101, 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, 0x21001210,
0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, 0x21021112, 0x21021210,
0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, 0x22011012, 0x22011111, 0x22011210,
0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, 0x21001021, 0x21001120, 0x21001221, 0x21001222,
0x21011020, 0x21011121, 0x21011221, 0x21011222, 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021,
0x22011222, 0x22021120, 0x20002000, 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002,
0x20022200, 0x20022202, 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201,
0x22002000, 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, 0x21002112,
0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, 0x22002111, 0x22012112,
0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, 0x20012121, 0x20022020, 0x20022022,
0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, 0x21012122, 0x22002020, 0x22002022, 0x22002220,
0x22002222, 0x22012121, 0x22022020, 0x22022022, 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102,
0x20110200, 0x20110201, 0x20120101, 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202,
0x21120201, 0x21120202, 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011,
0x20100110, 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, 0x21110112,
0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, 0x22110210, 0x22120011,
0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, 0x20110221, 0x20120121, 0x21100120,
0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, 0x21110220, 0x21120122, 0x21120221, 0x22100121,
0x22110120, 0x22110122, 0x22120221, 0x20101001, 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200,
0x20121102, 0x21101000, 0x21101202, 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201,
0x21121000, 0x21121001, 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101,
0x22111200, 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, 0x21101011,
0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, 0x21111110, 0x21111111,
0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, 0x21121111, 0x21121112, 0x21121211,
0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, 0x22111110, 0x22111111, 0x22111112, 0x22111211,
0x22111212, 0x22121010, 0x22121012, 0x22121111, 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020,
0x20111121, 0x20111221, 0x20121020, 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021,
0x21111022, 0x21111121, 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221,
0x22101222, 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, 0x21112202,
0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, 0x20102110, 0x20102112,
0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, 0x20122010, 0x20122011, 0x20122110,
0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, 0x21102212, 0x21112011, 0x21112110, 0x21112111,
0x21112112, 0x21112211, 0x21122012, 0x21122111, 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010,
0x22112012, 0x22112111, 0x22112212, 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120,
0x21102122, 0x21102221, 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120,
0x22112121, 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, 0x22200002,
0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, 0x20200111, 0x20200211,
0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, 0x21200211, 0x21210011, 0x21210111,
0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, 0x22210010, 0x22210012, 0x22210112, 0x22210211,
0x20200022, 0x20200220, 0x20200222, 0x20210020, 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121,
0x21210021, 0x21210122, 0x21210221, 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121,
0x22220020, 0x22220022, 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000,
0x21211100, 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, 0x20221211,
0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, 0x21221111, 0x21221212,
0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, 0x22211111, 0x22211210, 0x20201121,
0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, 0x21201120, 0x21201122, 0x21201222, 0x21211022,
0x21211121, 0x21211122, 0x21211220, 0x21221020, 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122,
0x22211221, 0x22221021, 0x22221120, 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000,
0x20222002, 0x20222200, 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002,
0x22202200, 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, 0x21222112,
0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, 0x20222020, 0x20222022,
0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, 0x22202022, 0x22202220, 0x22202222,
0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
};
static const __device__ uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15, 144, 17, 18, 147, 20, 149,
150, 23, 24, 153, 154, 27, 156, 29, 30, 159, 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43,
172, 45, 46, 175, 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, 192, 65,
66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207, 80, 209, 210, 83, 212, 85, 86, 215,
216, 89, 90, 219, 92, 221, 222, 95, 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237,
238, 111, 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
static const __device__ uint64_t ksigns64[128] = {
0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, 0xff00000000ff0000,
0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, 0xff000000ff000000, 0x00000000ff0000ff,
0x00000000ff00ff00, 0xff000000ff00ffff, 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00,
0x00000000ffffffff, 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff, 0x000000ffff000000,
0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff, 0xff0000ffffff0000, 0x000000ffffff00ff,
0x000000ffffffff00, 0xff0000ffffffffff, 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00,
0xff00ff000000ffff, 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff, 0xff00ff00ffff0000,
0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff, 0x0000ffff00000000, 0xff00ffff000000ff,
0xff00ffff0000ff00, 0x0000ffff0000ffff, 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00,
0xff00ffff00ffffff, 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff, 0xffff000000000000,
0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff, 0x00ff000000ff0000, 0xffff000000ff00ff,
0xffff000000ffff00, 0x00ff000000ffffff, 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00,
0x00ff0000ff00ffff, 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff, 0xffff00ff00ff0000,
0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff, 0xffff00ffff000000, 0x00ff00ffff0000ff,
0x00ff00ffff00ff00, 0xffff00ffff00ffff, 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00,
0x00ff00ffffffffff, 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff, 0xffffff00ff000000,
0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff, 0x00ffff00ffff0000, 0xffffff00ffff00ff,
0xffffff00ffffff00, 0x00ffff00ffffffff, 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00,
0xffffffff0000ffff, 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff, 0xffffffffffff0000,
0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
};
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
static const __device__ int8_t kvalues_iq4nl[16] = {
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
typedef void (*dequantize_kernel_t)(const void* vx, const int ib, const int iqs, dfloat2& v);
template <typename dst_t>
using to_cuda_ggml_t = void (*)(const void* __restrict__ x, dst_t* __restrict__ y, int k, cudaStream_t stream);
typedef float (*vec_dot_q_cuda_t)(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs);
typedef void (*allocate_tiles_cuda_t)(int** x_ql, half2** x_dm, int** x_qh, int** x_sc);
typedef void (*load_tiles_cuda_t)(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row);
typedef float (*vec_dot_q_mul_mat_cuda_t)(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ms,
const int& i,
const int& j,
const int& k);
// Utility function
template <typename dst_t>
static __device__ __forceinline__ dst_t convert_from_half(half val) {
return val;
}
template <>
__device__ __forceinline__ c10::BFloat16 convert_from_half<c10::BFloat16>(half val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __float2bfloat16(__half2float(val));
#else
return __half2float(val);
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
template <>
__device__ __forceinline__ float convert_from_half<float>(half val) {
return __half2float(val);
}
#if defined(USE_ROCM)
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
return reinterpret_cast<const int&>(c);
#else
int8x4_t c;
int16_t tmp;
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp = va[i] - vb[i];
if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
c[i] = tmp;
}
return reinterpret_cast<int&>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if __has_builtin(__builtin_amdgcn_sdot4)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#else
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
return c;
}
static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) {
uint32_t neq = a ^ b;
return !(neq & 0xff000000) * 0xff000000 | !(neq & 0x00ff0000) * 0x00ff0000 | !(neq & 0x0000ff00) * 0x0000ff00 |
!(neq & 0x000000ff) * 0x000000ff;
}
static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) {
return (static_cast<uint8_t>(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) +
(static_cast<uint8_t>(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) +
(static_cast<uint8_t>(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) +
(static_cast<uint8_t>(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0);
}
#endif // defined(USE_ROCM)
// Adatped from
// https://github.com/vllm-project/vllm/blob/755ed7b05be4743237d3339c4ff8c22bcaae04f4/csrc/quantization/gguf/gguf_kernel.cu
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
// dont use clang-format here, it breaks the include order
// clang-format off
#include "utils.h"
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
#include "moe.cuh"
#include "moe_vec.cuh"
// clang-format off
// Q8 gemv
template <typename scalar_t>
static __global__ void
quantize_q8_1(const scalar_t* __restrict__ x, void* __restrict__ vy, const int kx, const int kx_padded) {
const auto ix = blockDim.x * blockIdx.x + threadIdx.x;
if (ix >= kx_padded) {
return;
}
const auto iy = blockDim.y * blockIdx.y + threadIdx.y;
const int i_padded = iy * kx_padded + ix;
block_q8_1* y = (block_q8_1*)vy;
const int ib = i_padded / QK8_1; // block index
const int iqs = i_padded % QK8_1; // quant index
const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
float amax = fabsf(xi);
float sum = xi;
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, SGLANG_SHFL_XOR_SYNC_WIDTH(uint32_t(-1), amax, mask, 32));
sum += SGLANG_SHFL_XOR_SYNC_WIDTH(uint32_t(-1), sum, mask, 32);
}
const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q;
if (iqs > 0) {
return;
}
y[ib].ds.x = __float2half(d);
y[ib].ds.y = __float2half(sum);
}
template <typename scalar_t>
static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx, const int ky, cudaStream_t stream) {
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
constexpr int MAX_BLOCK_SIZE = 65535;
for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) {
const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off;
const dim3 num_blocks(block_num_x, num_blocks_y, 1);
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
&x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded);
}
}
torch::Tensor ggml_dequantize(
torch::Tensor W, // quant weight
int64_t type,
int64_t m,
int64_t n,
std::optional<at::ScalarType> const& dtype) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto dtype_ = dtype.value_or(torch::kFloat16);
auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
at::Tensor DW = torch::empty({m, n}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_FLOAT_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
});
return DW;
}
torch::Tensor ggml_mul_mat_vec_a8(
torch::Tensor W, // quant weight
torch::Tensor X, // input
int64_t type,
int64_t row) {
int col = X.sizes()[1];
int vecs = X.sizes()[0];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({vecs, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options);
DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream);
switch (type) {
case 2:
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 3:
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 6:
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 7:
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 8:
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 10:
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 11:
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 12:
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 13:
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 14:
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 16:
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 17:
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 18:
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 19:
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 20:
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 21:
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 22:
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 23:
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
case 29:
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (scalar_t*)Y.data_ptr(), col, row, vecs, stream);
break;
}
});
return Y;
}
torch::Tensor ggml_mul_mat_a8(
torch::Tensor W, // quant weight
torch::Tensor X, // input
int64_t type,
int64_t row) {
int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512;
int batch = X.sizes()[0];
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({batch, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, batch, stream);
switch (type) {
case 2:
ggml_mul_mat_q4_0_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 3:
ggml_mul_mat_q4_1_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 6:
ggml_mul_mat_q5_0_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 7:
ggml_mul_mat_q5_1_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 8:
ggml_mul_mat_q8_0_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 10:
ggml_mul_mat_q2_K_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 11:
ggml_mul_mat_q3_K_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 12:
ggml_mul_mat_q4_K_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 13:
ggml_mul_mat_q5_K_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
case 14:
ggml_mul_mat_q6_K_q8_1_cuda(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
col,
row,
batch,
padded,
row,
stream);
break;
}
});
return Y;
}
torch::Tensor ggml_moe_a8(
torch::Tensor X, // input
torch::Tensor W, // expert weights
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded,
int64_t type,
int64_t row,
int64_t top_k,
int64_t tokens) {
int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({tokens * top_k, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream);
switch (type) {
case 2:
ggml_moe_q4_0_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 3:
ggml_moe_q4_1_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 6:
ggml_moe_q5_0_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 7:
ggml_moe_q5_1_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 8:
ggml_moe_q8_0_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 10:
ggml_moe_q2_K_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 11:
ggml_moe_q3_K_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 12:
ggml_moe_q4_K_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 13:
ggml_moe_q5_K_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
case 14:
ggml_moe_q6_K_q8_1_cuda(
(void*)quant_X.data_ptr(),
(void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(),
W.stride(0),
col,
row,
tokens,
padded,
row,
top_k,
sorted_token_ids.sizes()[0],
stream);
break;
}
});
return Y;
}
torch::Tensor ggml_moe_a8_vec(
torch::Tensor X, // input
torch::Tensor W, // expert weights
torch::Tensor topk_ids,
int64_t top_k,
int64_t type,
int64_t row,
int64_t tokens) {
int col = X.sizes()[1];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
DISPATCH_FLOAT_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, tokens, stream);
switch (type) {
case 2:
moe_vec_q4_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 3:
moe_vec_q4_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 6:
moe_vec_q5_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 7:
moe_vec_q5_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 8:
moe_vec_q8_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 10:
moe_vec_q2_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 11:
moe_vec_q3_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 12:
moe_vec_q4_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 13:
moe_vec_q5_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 14:
moe_vec_q6_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 16:
moe_vec_iq2_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 17:
moe_vec_iq2_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 18:
moe_vec_iq3_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 19:
moe_vec_iq1_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 20:
moe_vec_iq4_nl_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 21:
moe_vec_iq3_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 22:
moe_vec_iq2_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 23:
moe_vec_iq4_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
case 29:
moe_vec_iq1_m_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(),
(void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(),
(int*)topk_ids.data_ptr(),
top_k,
tokens,
col,
row,
quant_X.stride(0),
stream);
break;
}
});
return Y;
}
int64_t ggml_moe_get_block_size(int64_t type) {
switch (type) {
case 2:
return MOE_X_Q4_0;
case 3:
return MOE_X_Q4_1;
case 6:
return MOE_X_Q5_0;
case 7:
return MOE_X_Q5_1;
case 8:
return MOE_X_Q8_0;
case 10:
return MOE_X_Q2_K;
case 11:
return MOE_X_Q3_K;
case 12:
return MOE_X_Q4_K;
case 13:
return MOE_X_Q5_K;
case 14:
return MOE_X_Q6_K;
}
return 0;
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/mmq.cuh
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
template <
typename scalar_t,
int qk,
int qr,
int qi,
bool need_sum,
typename block_q_t,
int mmq_x,
int mmq_y,
int nwarps,
allocate_tiles_cuda_t allocate_tiles,
load_tiles_cuda_t load_tiles,
int vdr,
vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const block_q_t* x = (const block_q_t*)vx;
const block_q8_1* y = (const block_q8_1*)vy;
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
const int blocks_per_warp = WARP_SIZE_GGUF / qi;
const int& ncols_dst = ncols_y;
const auto row_dst_0 = blockIdx.x * mmq_y;
const int& row_x_0 = row_dst_0;
const auto col_dst_0 = blockIdx.y * mmq_x;
const int& col_y_0 = col_dst_0;
int* tile_x_ql = nullptr;
half2* tile_x_dm = nullptr;
int* tile_x_qh = nullptr;
int* tile_x_sc = nullptr;
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
__shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF / QI8_1];
float sum[mmq_y / WARP_SIZE_GGUF][mmq_x / nwarps] = {{0.0f}};
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
load_tiles(
x + row_x_0 * blocks_per_row_x + ib0,
tile_x_ql,
tile_x_dm,
tile_x_qh,
tile_x_sc,
threadIdx.y,
nrows_x - row_x_0 - 1,
threadIdx.x,
blocks_per_row_x);
#pragma unroll
for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp / qr < blocks_per_row_x; ++ir) {
const auto kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;
#pragma unroll
for (int i = 0; i < mmq_x; i += nwarps) {
const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y - 1); // to prevent out-of-bounds memory accesses
const block_q8_1* by0 = &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + kbxd];
const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
}
#pragma unroll
for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF / QI8_1)) % mmq_x;
const auto kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
const int col_y_eff = min(col_y_0 + ids, ncols_y - 1);
// if the sum is not needed it's faster to transform the scale to f32 ahead of time
const half2* dsi_src =
&y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby].ds;
half2* dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF / QI8_1) + kby];
if (need_sum) {
*dsi_dst = *dsi_src;
} else {
float* dfi_dst = (float*)dsi_dst;
*dfi_dst = __low2float(*dsi_src);
}
}
__syncthreads();
// #pragma unroll // unrolling this loop causes too much register pressure
for (int k = ir * WARP_SIZE_GGUF / qr; k < (ir + 1) * WARP_SIZE_GGUF / qr; k += vdr) {
#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
sum[i / WARP_SIZE_GGUF][j / nwarps] += vec_dot(
tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, threadIdx.x + i, threadIdx.y + j, k);
}
}
}
__syncthreads();
}
}
#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
const auto col_dst = col_dst_0 + j + threadIdx.y;
if (col_dst >= ncols_dst) {
return;
}
#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
const auto row_dst = row_dst_0 + threadIdx.x + i;
if (row_dst >= nrows_dst) {
continue;
}
dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE_GGUF][j / nwarps];
}
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_0 64
#define MMQ_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MMQ_X_Q4_0 4
#define MMQ_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q4_0;
const int mmq_y = MMQ_Y_Q4_0;
const int nwarps = NWARPS_Q4_0;
mul_mat_q<
scalar_t,
QK4_0,
QR4_0,
QI4_0,
true,
block_q4_0,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q4_0<mmq_y>,
load_tiles_q4_0<mmq_y, nwarps, need_check>,
VDR_Q4_0_Q8_1_MMQ,
vec_dot_q4_0_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q4_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_0;
int mmq_y = MMQ_Y_Q4_0;
int nwarps = NWARPS_Q4_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q4_0<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q4_0<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_1 64
#define MMQ_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MMQ_X_Q4_1 4
#define MMQ_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q4_1;
const int mmq_y = MMQ_Y_Q4_1;
const int nwarps = NWARPS_Q4_1;
mul_mat_q<
scalar_t,
QK4_1,
QR4_1,
QI4_1,
true,
block_q4_1,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q4_1<mmq_y>,
load_tiles_q4_1<mmq_y, nwarps, need_check>,
VDR_Q4_1_Q8_1_MMQ,
vec_dot_q4_1_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q4_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_1;
int mmq_y = MMQ_Y_Q4_1;
int nwarps = NWARPS_Q4_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q4_1<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q4_1<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_0 64
#define MMQ_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MMQ_X_Q5_0 4
#define MMQ_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;
mul_mat_q<
scalar_t,
QK5_0,
QR5_0,
QI5_0,
false,
block_q5_0,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q5_0<mmq_y>,
load_tiles_q5_0<mmq_y, nwarps, need_check>,
VDR_Q5_0_Q8_1_MMQ,
vec_dot_q5_0_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q5_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q5_0<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q5_0<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_1 64
#define MMQ_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MMQ_X_Q5_1 4
#define MMQ_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;
mul_mat_q<
scalar_t,
QK5_1,
QR5_1,
QI5_1,
true,
block_q5_1,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q5_1<mmq_y>,
load_tiles_q5_1<mmq_y, nwarps, need_check>,
VDR_Q5_1_Q8_1_MMQ,
vec_dot_q5_1_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q5_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q5_1<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q5_1<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q8_0 64
#define MMQ_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MMQ_X_Q8_0 4
#define MMQ_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;
mul_mat_q<
scalar_t,
QK8_0,
QR8_0,
QI8_0,
false,
block_q8_0,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q8_0<mmq_y>,
load_tiles_q8_0<mmq_y, nwarps, need_check>,
VDR_Q8_0_Q8_1_MMQ,
vec_dot_q8_0_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q8_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q8_0<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q8_0<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q2_K 64
#define MMQ_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MMQ_X_Q2_K 4
#define MMQ_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;
mul_mat_q<
scalar_t,
QK_K,
QR2_K,
QI2_K,
false,
block_q2_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q2_K<mmq_y>,
load_tiles_q2_K<mmq_y, nwarps, need_check>,
VDR_Q2_K_Q8_1_MMQ,
vec_dot_q2_K_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q2_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q2_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q2_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q3_K 64
#define MMQ_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MMQ_X_Q3_K 4
#define MMQ_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;
mul_mat_q<
scalar_t,
QK_K,
QR3_K,
QI3_K,
false,
block_q3_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q3_K<mmq_y>,
load_tiles_q3_K<mmq_y, nwarps, need_check>,
VDR_Q3_K_Q8_1_MMQ,
vec_dot_q3_K_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q3_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q3_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q3_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_K 64
#define MMQ_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MMQ_X_Q4_K 4
#define MMQ_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;
mul_mat_q<
scalar_t,
QK_K,
QR4_K,
QI4_K,
true,
block_q4_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q4_K<mmq_y>,
load_tiles_q4_K<mmq_y, nwarps, need_check>,
VDR_Q4_K_Q8_1_MMQ,
vec_dot_q4_K_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q4_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q4_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q4_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_K 64
#define MMQ_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MMQ_X_Q5_K 4
#define MMQ_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;
mul_mat_q<
scalar_t,
QK_K,
QR5_K,
QI5_K,
true,
block_q5_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q5_K<mmq_y>,
load_tiles_q5_K<mmq_y, nwarps, need_check>,
VDR_Q5_K_Q8_1_MMQ,
vec_dot_q5_K_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q5_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q5_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q5_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q6_K 64
#define MMQ_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MMQ_X_Q6_K 4
#define MMQ_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst) {
const int mmq_x = MMQ_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;
mul_mat_q<
scalar_t,
QK_K,
QR6_K,
QI6_K,
false,
block_q6_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q6_K<mmq_y>,
load_tiles_q6_K<mmq_y, nwarps, need_check>,
VDR_Q6_K_Q8_1_MMQ,
vec_dot_q6_K_q8_1_mul_mat>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
template <typename scalar_t>
static void ggml_mul_mat_q6_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
cudaStream_t stream) {
const int mmq_x = MMQ_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
const bool need_check = false;
mul_mat_q6_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
} else {
const bool need_check = true;
mul_mat_q6_K<scalar_t, need_check>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/mmvq.cuh
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols,
const int nrows,
const int nvecs) {
const auto row = blockIdx.x * blockDim.y + threadIdx.y;
const auto vec = blockIdx.y;
if (row >= nrows || vec >= nvecs) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
const int nrows_y = (ncols + 512 - 1) / 512 * 512;
// partial sum for each thread
float tmp = 0.0f;
const block_q_t* x = (const block_q_t*)vx;
const block_q8_1* y = (const block_q8_1*)vy;
for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row * blocks_per_row + i; // x block index
const int iby = vec * (nrows_y / QK8_1) + i * (qk / QK8_1); // y block index that aligns with ibx
const int iqs = vdr * (threadIdx.x % (qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += SGLANG_SHFL_XOR_SYNC(uint32_t(-1), tmp, mask);
}
if (threadIdx.x == 0) {
dst[vec * nrows + row] = tmp;
}
}
template <typename scalar_t>
static void mul_mat_vec_q4_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q4_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q5_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q5_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q8_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q2_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q3_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q4_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q5_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q6_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq2_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq2_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq1_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq1_m_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq4_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq3_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/moe.cuh
#include <cstdint>
/* Adapted from ./csrc/quantization/gguf/mmq.cuh
*/
template <
typename scalar_t,
int qk,
int qr,
int qi,
bool need_sum,
typename block_q_t,
int mmq_x,
int mmq_y,
int nwarps,
allocate_tiles_cuda_t allocate_tiles,
load_tiles_cuda_t load_tiles,
int vdr,
vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void moe_q(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* __restrict__ sorted_token_ids,
const int* __restrict__ expert_ids,
const int* __restrict__ num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
const int blocks_per_warp = WARP_SIZE_GGUF / qi;
const int ncols_dst = ncols_y * top_k;
const auto row_dst_0 = blockIdx.x * mmq_y;
const int& row_x_0 = row_dst_0;
const auto col_dst_0 = blockIdx.y * mmq_x;
int token_offs[mmq_x / nwarps];
for (int i = 0; i < mmq_x; i += nwarps) {
token_offs[i / nwarps] = sorted_token_ids[col_dst_0 + threadIdx.y + i];
}
const int exp_idx = expert_ids[blockIdx.y];
if (exp_idx > 255 || exp_idx < 0) return;
if (blockIdx.y * mmq_x > num_tokens_post_padded[0]) return;
const block_q_t* x = (const block_q_t*)((char*)vx + exp_idx * exp_stride);
const block_q8_1* y = (const block_q8_1*)(vy);
int* tile_x_ql = nullptr;
half2* tile_x_dm = nullptr;
int* tile_x_qh = nullptr;
int* tile_x_sc = nullptr;
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
__shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF / QI8_1];
float sum[mmq_y / WARP_SIZE_GGUF][mmq_x / nwarps] = {{0.0f}};
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
load_tiles(
x + row_x_0 * blocks_per_row_x + ib0,
tile_x_ql,
tile_x_dm,
tile_x_qh,
tile_x_sc,
threadIdx.y,
nrows_x - row_x_0 - 1,
threadIdx.x,
blocks_per_row_x);
const int n_per_r = ((qk * blocks_per_warp) / qr);
#pragma unroll
for (int ir = 0; ir < qr && ib0 * qk + ir * n_per_r < ncols_x; ++ir) {
const auto kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;
#pragma unroll
for (int i = 0; i < mmq_x; i += nwarps) {
const int col_y_eff = token_offs[i / nwarps] / top_k;
const int block_x = ib0 * (qk / QK8_1) + kbxd;
if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
const block_q8_1* by0 = &y[col_y_eff * blocks_per_col_y + block_x];
const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
}
}
if (threadIdx.x < n_per_r / QK8_1) {
const auto kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
const int col_y_eff = token_offs[threadIdx.y] / top_k;
const int block_x = ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby;
if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
const half2* dsi_src = &y[col_y_eff * blocks_per_col_y + block_x].ds;
half2* dsi_dst = &tile_y_ds[threadIdx.y * (WARP_SIZE_GGUF / QI8_1) + kby];
if (need_sum) {
*dsi_dst = *dsi_src;
} else {
float* dfi_dst = (float*)dsi_dst;
*dfi_dst = __low2float(*dsi_src);
}
}
}
__syncthreads();
// #pragma unroll // unrolling this loop causes too much register pressure
for (int k = ir * WARP_SIZE_GGUF / qr; k < (ir + 1) * WARP_SIZE_GGUF / qr; k += vdr) {
#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
sum[i / WARP_SIZE_GGUF][j / nwarps] += vec_dot(
tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, threadIdx.x + i, threadIdx.y + j, k);
}
}
}
__syncthreads();
}
}
#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
const int col_dst = token_offs[j / nwarps];
if (col_dst >= ncols_dst) {
return;
}
#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
const auto row_dst = row_dst_0 + threadIdx.x + i;
if (row_dst >= nrows_dst) {
continue;
}
dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE_GGUF][j / nwarps];
}
}
}
#if defined(USE_ROCM)
#define MOE_X_Q4_0 8
#define MOE_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MOE_X_Q4_0 4
#define MOE_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
#endif
moe_q4_0(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q4_0;
const int mmq_y = MOE_Y_Q4_0;
const int nwarps = NWARPS_Q4_0;
moe_q<
scalar_t,
QK4_0,
QR4_0,
QI4_0,
true,
block_q4_0,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q4_0<mmq_y>,
load_tiles_q4_0<mmq_y, nwarps, need_check>,
VDR_Q4_0_Q8_1_MMQ,
vec_dot_q4_0_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q4_0_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
int mmq_x = MOE_X_Q4_0;
int mmq_y = MOE_Y_Q4_0;
int nwarps = NWARPS_Q4_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q4_1 8
#define MOE_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MOE_X_Q4_1 4
#define MOE_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
#endif
moe_q4_1(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q4_1;
const int mmq_y = MOE_Y_Q4_1;
const int nwarps = NWARPS_Q4_1;
moe_q<
scalar_t,
QK4_1,
QR4_1,
QI4_1,
true,
block_q4_1,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q4_1<mmq_y>,
load_tiles_q4_1<mmq_y, nwarps, need_check>,
VDR_Q4_1_Q8_1_MMQ,
vec_dot_q4_1_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q4_1_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
int mmq_x = MOE_X_Q4_1;
int mmq_y = MOE_Y_Q4_1;
int nwarps = NWARPS_Q4_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q5_0 8
#define MOE_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MOE_X_Q5_0 4
#define MOE_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
#endif
moe_q5_0(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;
moe_q<
scalar_t,
QK5_0,
QR5_0,
QI5_0,
false,
block_q5_0,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q5_0<mmq_y>,
load_tiles_q5_0<mmq_y, nwarps, need_check>,
VDR_Q5_0_Q8_1_MMQ,
vec_dot_q5_0_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q5_0_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q5_1 8
#define MOE_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MOE_X_Q5_1 4
#define MOE_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
#endif
moe_q5_1(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;
moe_q<
scalar_t,
QK5_1,
QR5_1,
QI5_1,
true,
block_q5_1,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q5_1<mmq_y>,
load_tiles_q5_1<mmq_y, nwarps, need_check>,
VDR_Q5_1_Q8_1_MMQ,
vec_dot_q5_1_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q5_1_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q8_0 8
#define MOE_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MOE_X_Q8_0 4
#define MOE_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
#endif
moe_q8_0(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;
moe_q<
scalar_t,
QK8_0,
QR8_0,
QI8_0,
false,
block_q8_0,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q8_0<mmq_y>,
load_tiles_q8_0<mmq_y, nwarps, need_check>,
VDR_Q8_0_Q8_1_MMQ,
vec_dot_q8_0_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q8_0_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q2_K 8
#define MOE_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MOE_X_Q2_K 4
#define MOE_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
#endif
moe_q2_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;
moe_q<
scalar_t,
QK_K,
QR2_K,
QI2_K,
false,
block_q2_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q2_K<mmq_y>,
load_tiles_q2_K<mmq_y, nwarps, need_check>,
VDR_Q2_K_Q8_1_MMQ,
vec_dot_q2_K_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q2_K_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q3_K 8
#define MOE_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MOE_X_Q3_K 4
#define MOE_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
#endif
moe_q3_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;
moe_q<
scalar_t,
QK_K,
QR3_K,
QI3_K,
false,
block_q3_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q3_K<mmq_y>,
load_tiles_q3_K<mmq_y, nwarps, need_check>,
VDR_Q3_K_Q8_1_MMQ,
vec_dot_q3_K_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q3_K_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q4_K 8
#define MOE_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MOE_X_Q4_K 4
#define MOE_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
#endif
moe_q4_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;
moe_q<
scalar_t,
QK_K,
QR4_K,
QI4_K,
true,
block_q4_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q4_K<mmq_y>,
load_tiles_q4_K<mmq_y, nwarps, need_check>,
VDR_Q4_K_Q8_1_MMQ,
vec_dot_q4_K_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q4_K_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q5_K 8
#define MOE_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MOE_X_Q5_K 4
#define MOE_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
#endif
moe_q5_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;
moe_q<
scalar_t,
QK_K,
QR5_K,
QI5_K,
true,
block_q5_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q5_K<mmq_y>,
load_tiles_q5_K<mmq_y, nwarps, need_check>,
VDR_Q5_K_Q8_1_MMQ,
vec_dot_q5_K_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q5_K_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
#if defined(USE_ROCM)
#define MOE_X_Q6_K 8
#define MOE_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MOE_X_Q6_K 4
#define MOE_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
#endif
moe_q6_K(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k) {
const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;
moe_q<
scalar_t,
QK_K,
QR6_K,
QI6_K,
false,
block_q6_K,
mmq_x,
mmq_y,
nwarps,
allocate_tiles_q6_K<mmq_y>,
load_tiles_q6_K<mmq_y, nwarps, need_check>,
VDR_Q6_K_Q8_1_MMQ,
vec_dot_q6_K_q8_1_mul_mat>(
vx,
vy,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
template <typename scalar_t>
static void ggml_moe_q6_K_q8_1_cuda(
const void* inp,
const void* w,
scalar_t* dst,
const int* sorted_token_ids,
const int* expert_ids,
const int* num_tokens_post_padded,
const int exp_stride,
const int ncols_x,
const int nrows_x,
const int ncols_y,
const int nrows_y,
const int nrows_dst,
const int top_k,
const int tokens_post_padded,
cudaStream_t stream) {
const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (tokens_post_padded) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
if (nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
} else {
constexpr bool need_check = true;
moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
w,
inp,
dst,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
exp_stride,
ncols_x,
nrows_x,
ncols_y,
nrows_y,
nrows_dst,
top_k);
}
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/moe_vec.cuh
// copied and adapted from
// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void moe_vec_q(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* topk_ids,
const int topk,
const int ncols,
const int nrows,
const int token_stride) {
const auto row = blockIdx.x * blockDim.y + threadIdx.y;
const auto token = blockIdx.z / topk;
const auto expert = (topk_ids)[blockIdx.z];
if (row >= nrows) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
// partial sum for each thread
float tmp = 0.0f;
const block_q_t* x = ((const block_q_t*)vx) + expert * nrows * blocks_per_row;
const block_q8_1* y = (const block_q8_1*)(((const int*)vy) + token * token_stride);
for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row * blocks_per_row + i; // x block index
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
const int iqs = vdr * (threadIdx.x % (qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += SGLANG_SHFL_XOR_SYNC(uint32_t(-1), tmp, mask);
}
if (threadIdx.x == 0) {
dst[blockIdx.z * nrows + row] = tmp;
}
}
template <typename scalar_t>
static void moe_vec_q4_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q4_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q8_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q2_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q3_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q4_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q6_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq3_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq1_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq1_m_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq4_nl_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq4_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq3_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/vecdotq.cuh
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh
// and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
static __device__ __forceinline__ int get_int_b2(const void* x, const int& i32) {
const uint16_t* x16 = (const uint16_t*)x; // assume at least 2 byte alignment
int x32 = x16[2 * i32 + 0] << 0;
x32 |= x16[2 * i32 + 1] << 16;
return x32;
}
static __device__ __forceinline__ int get_int_b4(const void* x, const int& i32) {
return ((const int*)x)[i32]; // assume at least 4 byte alignment
}
static __device__ __forceinline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
const uint16_t* x16 = (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte alignment
int x32 = 0;
x32 |= x16[0] << 0;
x32 |= x16[1] << 16;
return x32;
}
static __device__ __forceinline__ int get_int_from_uint8(const uint8_t* x8, const int& i32) {
const uint16_t* x16 = (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte alignment
int x32 = 0;
x32 |= x16[0] << 0;
x32 |= x16[1] << 16;
return x32;
}
static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t* x8, const int& i32) {
return *((const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
}
static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t* x8, const int& i32) {
return *((const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q4_0_q8_1_impl(const int* v, const int* u, const float& d4, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
for (int i = 0; i < vdr; ++i) {
const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2 * i + 0], sumi);
sumi = __dp4a(vi1, u[2 * i + 1], sumi);
}
const float2 ds8f = __half22float2(ds8);
// second part effectively subtracts 8 from each quant value
return d4 * (sumi * ds8f.x - (8 * vdr / QI4_0) * ds8f.y);
#endif
}
#define VDR_Q4_1_Q8_1_MMVQ 2
#define VDR_Q4_1_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q4_1_q8_1_impl(const int* v, const int* u, const half2& dm4, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
for (int i = 0; i < vdr; ++i) {
const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2 * i + 0], sumi);
sumi = __dp4a(vi1, u[2 * i + 1], sumi);
}
const float2 tmp = __half22float2(__hmul2(dm4, ds8));
const float d4d8 = tmp.x;
const float m4s8 = tmp.y;
// scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
#endif
}
#define VDR_Q5_0_Q8_1_MMVQ 2
#define VDR_Q5_0_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q5_0_q8_1_impl(const int* vl, const int* vh, const int* u, const float& d5, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
for (int i = 0; i < vdr; ++i) {
int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = __dp4a(vi0, u[2 * i + 0], sumi); // SIMD dot product of quantized values
int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = __dp4a(vi1, u[2 * i + 1], sumi); // SIMD dot product of quantized values
}
const float2 ds8f = __half22float2(ds8);
// second part effectively subtracts 16 from each quant value
return d5 * (sumi * ds8f.x - (16 * vdr / QI5_0) * ds8f.y);
#endif
}
#define VDR_Q5_1_Q8_1_MMVQ 2
#define VDR_Q5_1_Q8_1_MMQ 4
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q5_1_q8_1_impl(const int* vl, const int* vh, const int* u, const half2& dm5, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
for (int i = 0; i < vdr; ++i) {
int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = __dp4a(vi0, u[2 * i + 0], sumi); // SIMD dot product of quantized values
int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = __dp4a(vi1, u[2 * i + 1], sumi); // SIMD dot product of quantized values
}
const float2 tmp = __half22float2(__hmul2(dm5, ds8));
const float d5d8 = tmp.x;
const float m5s8 = tmp.y;
// scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
return sumi * d5d8 + m5s8 / (QI5_1 / vdr);
#endif
}
#define VDR_Q8_0_Q8_1_MMVQ 2
#define VDR_Q8_0_Q8_1_MMQ 8
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q8_0_q8_1_impl(const int* v, const int* u, const float& d8_0, const float& d8_1) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = __dp4a(v[i], u[i], sumi);
}
return d8_0 * d8_1 * sumi;
#endif
}
template <int vdr>
static __device__ __forceinline__ float
vec_dot_q8_1_q8_1_impl(const int* v, const int* u, const half2& dm8, const half2& ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = __dp4a(v[i], u[i], sumi);
}
const float2 tmp = __half22float2(__hmul2(dm8, ds8));
const float d8d8 = tmp.x;
const float m8s8 = tmp.y;
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
return sumi * d8d8 + m8s8 / (QI8_1 / vdr);
#endif
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 2
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
const int& v,
const int* __restrict__ u,
const uint8_t* __restrict__ scales,
const half2& dm2,
const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
#pragma unroll
for (int i = 0; i < QR2_K; ++i) {
const int sc = scales[2 * i];
const int vi = (v >> (2 * i)) & 0x03030303;
sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
// fill int with 4x m
int m = sc >> 4;
m |= m << 8;
m |= m << 16;
sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
}
const float2 dm2f = __half22float2(dm2);
return dm2f.x * sumf_d - dm2f.y * sumf_m;
#endif
}
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
const int* __restrict__ v,
const int* __restrict__ u,
const uint8_t* __restrict__ scales,
const half2& dm2,
const float& d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi_d = 0;
int sumi_m = 0;
#pragma unroll
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1 / 2) {
int sumi_d_sc = 0;
const int sc = scales[i0 / (QI8_1 / 2)];
// fill int with 4x m
int m = sc >> 4;
m |= m << 8;
m |= m << 16;
#pragma unroll
for (int i = i0; i < i0 + QI8_1 / 2; ++i) {
sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
}
sumi_d += sumi_d_sc * (sc & 0xF);
}
const float2 dm2f = __half22float2(dm2);
return d8 * (dm2f.x * sumi_d - dm2f.y * sumi_m);
#endif
}
#define VDR_Q3_K_Q8_1_MMVQ 1
#define VDR_Q3_K_Q8_1_MMQ 2
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
const int& vl,
const int& vh,
const int* __restrict__ u,
const uint8_t* __restrict__ scales,
const int& scale_offset,
const float& d3,
const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf = 0.0f;
#pragma unroll
for (int i = 0; i < QR3_K; ++i) {
const int isc = scale_offset + 2 * i;
const int isc_low = isc % (QK_K / 32);
const int sc_shift_low = 4 * (isc / (QK_K / 32));
const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;
const int isc_high = isc % (QK_K / 64);
const int sc_shift_high = 2 * (isc / (QK_K / 64));
const int sc_high = ((scales[(QK_K / 32) + isc_high] >> sc_shift_high) & 3) << 4;
const int sc = (sc_low | sc_high) - 32;
const int vil = (vl >> (2 * i)) & 0x03030303;
const int vih = ((vh >> i) << 2) & 0x04040404;
const int vi = __vsubss4(vil, vih);
sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d3 * sumf;
#endif
}
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
const int* __restrict__ v,
const int* __restrict__ u,
const int8_t* __restrict__ scales,
const float& d3,
const float& d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
int sumi = 0;
#pragma unroll
for (int i0 = 0; i0 < QR3_K * VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1 / 2) {
int sumi_sc = 0;
for (int i = i0; i < i0 + QI8_1 / 2; ++i) {
sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
}
sumi += sumi_sc * scales[i0 / (QI8_1 / 2)];
}
return d3 * d8 * sumi;
#endif
}
#define VDR_Q4_K_Q8_1_MMVQ 2
#define VDR_Q4_K_Q8_1_MMQ 8
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
const int* __restrict__ v,
const int* __restrict__ u,
const uint8_t* __restrict__ sc,
const uint8_t* __restrict__ m,
const half2& dm4,
const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
#pragma unroll
for (int i = 0; i < QR4_K; ++i) {
const int v0i = (v[0] >> (4 * i)) & 0x0F0F0F0F;
const int v1i = (v[1] >> (4 * i)) & 0x0F0F0F0F;
const int dot1 = __dp4a(v1i, u[2 * i + 1], __dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2 * i + 1], __dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u
sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
}
const float2 dm4f = __half22float2(dm4);
return dm4f.x * sumf_d - dm4f.y * sumf_m;
#endif
}
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
const int* __restrict__ v,
const int* __restrict__ u,
const uint8_t* __restrict__ sc,
const uint8_t* __restrict__ m,
const half2& dm4,
const half2* __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
#pragma unroll
for (int i = 0; i < QR4_K * VDR_Q4_K_Q8_1_MMQ / QI8_1; ++i) {
int sumi_d = 0;
#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = __dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F, u[i * QI8_1 + j], sumi_d); // SIMD dot product
}
const float2 ds8f = __half22float2(ds8[i]);
sumf_d += ds8f.x * (sc[i] * sumi_d);
sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
}
const float2 dm4f = __half22float2(dm4);
return dm4f.x * sumf_d - dm4f.y * sumf_m;
#endif
}
#define VDR_Q5_K_Q8_1_MMVQ 2
#define VDR_Q5_K_Q8_1_MMQ 8
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
const int* __restrict__ vl,
const int* __restrict__ vh,
const int* __restrict__ u,
const uint8_t* __restrict__ sc,
const uint8_t* __restrict__ m,
const half2& dm5,
const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
#pragma unroll
for (int i = 0; i < QR5_K; ++i) {
const int vl0i = (vl[0] >> (4 * i)) & 0x0F0F0F0F;
const int vl1i = (vl[1] >> (4 * i)) & 0x0F0F0F0F;
const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
const int v0i = vl0i | vh0i;
const int v1i = vl1i | vh1i;
const int dot1 = __dp4a(v0i, u[2 * i + 0], __dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2 * i + 0], __dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u
sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]);
}
const float2 dm5f = __half22float2(dm5);
return dm5f.x * sumf_d - dm5f.y * sumf_m;
#endif
}
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
const int* __restrict__ v,
const int* __restrict__ u,
const uint8_t* __restrict__ sc,
const uint8_t* __restrict__ m,
const half2& dm4,
const half2* __restrict__ ds8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
float sumf_m = 0.0f;
#pragma unroll
for (int i = 0; i < QR5_K * VDR_Q5_K_Q8_1_MMQ / QI8_1; ++i) {
int sumi_d = 0;
#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = __dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j], sumi_d); // SIMD dot product
}
const float2 ds8f = __half22float2(ds8[i]);
sumf_d += ds8f.x * (sc[i] * sumi_d);
sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
}
const float2 dm4f = __half22float2(dm4);
return dm4f.x * sumf_d - dm4f.y * sumf_m;
#endif
}
#define VDR_Q6_K_Q8_1_MMVQ 1
#define VDR_Q6_K_Q8_1_MMQ 8
// contiguous v/x values
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
const int& vl,
const int& vh,
const int* __restrict__ u,
const int8_t* __restrict__ scales,
const float& d,
const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf = 0.0f;
#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
const int sc = scales[4 * i];
const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d * sumf;
#endif
}
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
const int* __restrict__ v,
const int* __restrict__ u,
const int8_t* __restrict__ sc,
const float& d6,
const float* __restrict__ d8) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
float sumf_d = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
#pragma unroll
for (int i = i0; i < i0 + 2; ++i) {
sumi_d.x = __dp4a(v[2 * i + 0], u[2 * i + 0], sumi_d.x); // SIMD dot product
sumi_d.x = __dp4a(v[2 * i + 1], u[2 * i + 1], sumi_d.x); // SIMD dot product
sumi_d.y = __dp4a(v[2 * i + 4], u[2 * i + 4], sumi_d.y); // SIMD dot product
sumi_d.y = __dp4a(v[2 * i + 5], u[2 * i + 5], sumi_d.y); // SIMD dot product
}
sumf_d += d8[i0 / 4] * (sc[i0 / 2 + 0] * sumi_d.x + sc[i0 / 2 + 1] * sumi_d.y);
}
return d6 * sumf_d;
#endif
}
static __device__ __forceinline__ float
vec_dot_q4_0_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q4_0* bq4_0 = (const block_q4_0*)vbq;
int v[VDR_Q4_0_Q8_1_MMVQ];
int u[2 * VDR_Q4_0_Q8_1_MMVQ];
#pragma unroll
for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
}
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, __half2float(bq4_0->d), bq8_1->ds);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q4_0(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF / QI4_0) + mmq_y / QI4_0];
*x_ql = tile_x_qs;
*x_dm = (half2*)tile_x_d;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_0(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI4_0;
const int kqsx = k % QI4_0;
const block_q4_0* bx0 = (const block_q4_0*)vx;
float* x_dmf = (float*)x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q4_0* bxi = bx0 + i * blocks_per_row + kbx;
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
// x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbx] = bxi->d;
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_0;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
}
const block_q4_0* bxi = bx0 + i * blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE_GGUF / QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d);
}
}
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
(void)x_qh;
(void)x_sc;
const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
const float* x_dmf = (const float*)x_dm;
int u[2 * VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_0) % WARP_SIZE_GGUF];
}
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>(
&x_ql[i * (WARP_SIZE_GGUF + 1) + k],
u,
x_dmf[i * (WARP_SIZE_GGUF / QI4_0) + i / QI4_0 + k / QI4_0],
y_ds[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}
static __device__ __forceinline__ float
vec_dot_q4_1_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q4_1* bq4_1 = (const block_q4_1*)vbq;
int v[VDR_Q4_1_Q8_1_MMVQ];
int u[2 * VDR_Q4_1_Q8_1_MMVQ];
#pragma unroll
for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
}
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q4_1(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + +mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI4_1) + mmq_y / QI4_1];
*x_ql = tile_x_qs;
*x_dm = tile_x_dm;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_1(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI4_1;
const int kqsx = k % QI4_1;
const block_q4_1* bx0 = (const block_q4_1*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q4_1* bxi = bx0 + i * blocks_per_row + kbx;
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_1;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
}
const block_q4_1* bxi = bx0 + i * blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE_GGUF / QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
}
}
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
int u[2 * VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_1) % WARP_SIZE_GGUF];
}
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>(
&x_ql[i * (WARP_SIZE_GGUF + 1) + k],
u,
x_dm[i * (WARP_SIZE_GGUF / QI4_1) + i / QI4_1 + k / QI4_1],
y_ds[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}
static __device__ __forceinline__ float
vec_dot_q5_0_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q5_0* bq5_0 = (const block_q5_0*)vbq;
int vl[VDR_Q5_0_Q8_1_MMVQ];
int vh[VDR_Q5_0_Q8_1_MMVQ];
int u[2 * VDR_Q5_0_Q8_1_MMVQ];
#pragma unroll
for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i);
vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
}
return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, __half2float(bq5_0->d), bq8_1->ds);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q5_0(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF / QI5_0) + mmq_y / QI5_0];
*x_ql = tile_x_ql;
*x_dm = (half2*)tile_x_d;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_0(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI5_0;
const int kqsx = k % QI5_0;
const block_q5_0* bx0 = (const block_q5_0*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q5_0* bxi = bx0 + i * blocks_per_row + kbx;
const int ql = get_int_from_uint8(bxi->qs, kqsx);
const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 0] = qs0;
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 1] = qs1;
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_0;
const int kbxd = k % blocks_per_tile_x_row;
float* x_dmf = (float*)x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
}
const block_q5_0* bxi = bx0 + i * blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE_GGUF / QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d);
}
}
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
const int index_bx = i * (WARP_SIZE_GGUF / QI5_0) + i / QI5_0 + k / QI5_0;
const float* x_dmf = (const float*)x_dm;
const float* y_df = (const float*)y_ds;
int u[2 * VDR_Q5_0_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_0) % WARP_SIZE_GGUF];
}
return vec_dot_q8_0_q8_1_impl<QR5_0 * VDR_Q5_0_Q8_1_MMQ>(
&x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k],
u,
x_dmf[index_bx],
y_df[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}
static __device__ __forceinline__ float
vec_dot_q5_1_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q5_1* bq5_1 = (const block_q5_1*)vbq;
int vl[VDR_Q5_1_Q8_1_MMVQ];
int vh[VDR_Q5_1_Q8_1_MMVQ];
int u[2 * VDR_Q5_1_Q8_1_MMVQ];
#pragma unroll
for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
}
return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q5_1(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI5_1) + mmq_y / QI5_1];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_1(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI5_1;
const int kqsx = k % QI5_1;
const block_q5_1* bx0 = (const block_q5_1*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q5_1* bxi = bx0 + i * blocks_per_row + kbx;
const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 0] = qs0;
int qs1 = (ql >> 4) & 0x0F0F0F0F;
qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k + 1] = qs1;
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_1;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
}
const block_q5_1* bxi = bx0 + i * blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE_GGUF / QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
}
}
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2));
const int index_bx = i * (WARP_SIZE_GGUF / QI5_1) + +i / QI5_1 + k / QI5_1;
int u[2 * VDR_Q5_1_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
u[2 * l + 0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF];
u[2 * l + 1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_1) % WARP_SIZE_GGUF];
}
return vec_dot_q8_1_q8_1_impl<QR5_1 * VDR_Q5_1_Q8_1_MMQ>(
&x_ql[i * (2 * WARP_SIZE_GGUF + 1) + 2 * k],
u,
x_dm[index_bx],
y_ds[j * (WARP_SIZE_GGUF / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE_GGUF / QI8_1)]);
}
static __device__ __forceinline__ float
vec_dot_q8_0_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q8_0* bq8_0 = (const block_q8_0*)vbq;
int v[VDR_Q8_0_Q8_1_MMVQ];
int u[VDR_Q8_0_Q8_1_MMVQ];
#pragma unroll
for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
}
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, __half2float(bq8_0->d), __low2float(bq8_1->ds));
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q8_0(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF / QI8_0) + mmq_y / QI8_0];
*x_ql = tile_x_qs;
*x_dm = (half2*)tile_x_d;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q8_0(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI8_0;
const int kqsx = k % QI8_0;
float* x_dmf = (float*)x_dm;
const block_q8_0* bx0 = (const block_q8_0*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q8_0* bxi = bx0 + i * blocks_per_row + kbx;
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI8_0;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
}
const block_q8_0* bxi = bx0 + i * blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE_GGUF / QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d);
}
}
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const float* x_dmf = (const float*)x_dm;
const float* y_df = (const float*)y_ds;
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>(
&x_ql[i * (WARP_SIZE_GGUF + 1) + k],
&y_qs[j * WARP_SIZE_GGUF + k],
x_dmf[i * (WARP_SIZE_GGUF / QI8_0) + i / QI8_0 + k / QI8_0],
y_df[j * (WARP_SIZE_GGUF / QI8_1) + k / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q2_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q2_K* bq2_K = (const block_q2_K*)vbq;
const int bq8_offset = QR2_K * (iqs / QI8_1);
const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2);
const uint8_t* scales = bq2_K->scales + scale_offset;
const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
int u[QR2_K];
float d8[QR2_K];
#pragma unroll
for (int i = 0; i < QR2_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
}
return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q2_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI2_K) + mmq_y / QI2_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 4) + mmq_y / 4];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q2_K(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI2_K;
const int kqsx = k % QI2_K;
const block_q2_K* bx0 = (const block_q2_K*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q2_K* bxi = bx0 + i * blocks_per_row + kbx;
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI2_K;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q2_K* bxi = bx0 + i * blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE_GGUF / QI2_K) + i / QI2_K + kbxd] = bxi->dm;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF / 4);
if (need_check) {
i = min(i, i_max);
}
const block_q2_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 4)) / (QI2_K / 4);
x_sc[i * (WARP_SIZE_GGUF / 4) + i / 4 + k % (WARP_SIZE_GGUF / 4)] =
get_int_from_uint8_aligned(bxi->scales, k % (QI2_K / 4));
}
}
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const int kbx = k / QI2_K;
const int ky = (k % QI2_K) * QR2_K;
const float* y_df = (const float*)y_ds;
int v[QR2_K * VDR_Q2_K_Q8_1_MMQ];
const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx * QI2_K + (QI2_K / 2) * (ky / (2 * QI2_K)) + ky % (QI2_K / 2);
const int shift = 2 * ((ky % (2 * QI2_K)) / (QI2_K / 2));
#pragma unroll
for (int l = 0; l < QR2_K * VDR_Q2_K_Q8_1_MMQ; ++l) {
v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
}
const uint8_t* scales = ((const uint8_t*)&x_sc[i * (WARP_SIZE_GGUF / 4) + i / 4 + kbx * 4]) + ky / 4;
const int index_y = j * WARP_SIZE_GGUF + (QR2_K * k) % WARP_SIZE_GGUF;
return vec_dot_q2_K_q8_1_impl_mmq(
v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE_GGUF / QI2_K) + i / QI2_K + kbx], y_df[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q3_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q3_K* bq3_K = (const block_q3_K*)vbq;
const int bq8_offset = QR3_K * (iqs / (QI3_K / 2));
const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2);
const float d = __half2float(bq3_K->d);
const int vl = get_int_from_uint8(bq3_K->qs, iqs);
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K / 2)) >> bq8_offset;
int u[QR3_K];
float d8[QR3_K];
#pragma unroll
for (int i = 0; i < QR3_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
}
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q3_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI3_K) + mmq_y / QI3_K];
__shared__ int tile_x_qh[mmq_y * (WARP_SIZE_GGUF / 2) + mmq_y / 2];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 4) + mmq_y / 4];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_qh = tile_x_qh;
*x_sc = tile_x_sc;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q3_K(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI3_K;
const int kqsx = k % QI3_K;
const block_q3_K* bx0 = (const block_q3_K*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q3_K* bxi = bx0 + i * blocks_per_row + kbx;
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI3_K;
const int kbxd = k % blocks_per_tile_x_row;
float* x_dmf = (float*)x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q3_K* bxi = bx0 + i * blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE_GGUF / QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d);
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
int i = i0 + i_offset * 2 + k / (WARP_SIZE_GGUF / 2);
if (need_check) {
i = min(i, i_max);
}
const block_q3_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 2)) / (QI3_K / 2);
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
x_qh[i * (WARP_SIZE_GGUF / 2) + i / 2 + k % (WARP_SIZE_GGUF / 2)] =
~get_int_from_uint8(bxi->hmask, k % (QI3_K / 2));
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF / 4);
if (need_check) {
i = min(i, i_max);
}
const block_q3_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 4)) / (QI3_K / 4);
const int ksc = k % (QI3_K / 4);
const int ksc_low = ksc % (QI3_K / 8);
const int shift_low = 4 * (ksc / (QI3_K / 8));
const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
const int ksc_high = QI3_K / 8;
const int shift_high = 2 * ksc;
const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
x_sc[i * (WARP_SIZE_GGUF / 4) + i / 4 + k % (WARP_SIZE_GGUF / 4)] = sc;
}
}
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const int kbx = k / QI3_K;
const int ky = (k % QI3_K) * QR3_K;
const float* x_dmf = (const float*)x_dm;
const float* y_df = (const float*)y_ds;
const int8_t* scales = ((const int8_t*)(x_sc + i * (WARP_SIZE_GGUF / 4) + i / 4 + kbx * 4)) + ky / 4;
int v[QR3_K * VDR_Q3_K_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < QR3_K * VDR_Q3_K_Q8_1_MMQ; ++l) {
const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx * QI3_K + (QI3_K / 2) * (ky / (2 * QI3_K)) + ky % (QI3_K / 2);
const int shift = 2 * ((ky % 32) / 8);
const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
const int vh = x_qh[i * (WARP_SIZE_GGUF / 2) + i / 2 + kbx * (QI3_K / 2) + (ky + l) % 8] >> ((ky + l) / 8);
const int vlh = (vh << 2) & 0x04040404;
v[l] = __vsubss4(vll, vlh);
}
const int index_y = j * WARP_SIZE_GGUF + (k * QR3_K) % WARP_SIZE_GGUF;
return vec_dot_q3_K_q8_1_impl_mmq(
v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE_GGUF / QI3_K) + i / QI3_K + kbx], y_df[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q4_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q4_K* bq4_K = (const block_q4_K*)vbq;
int v[2];
int u[2 * QR4_K];
float d8[QR4_K];
// iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
const int* q4 = (const int*)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
v[0] = q4[0];
v[1] = q4[4];
const uint16_t* scales = (const uint16_t*)bq4_K->scales;
uint16_t aux[2];
const int j = bq8_offset / 2;
if (j < 2) {
aux[0] = scales[j + 0] & 0x3f3f;
aux[1] = scales[j + 2] & 0x3f3f;
} else {
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
}
const uint8_t* sc = (const uint8_t*)aux;
const uint8_t* m = sc + 2;
for (int i = 0; i < QR4_K; ++i) {
const block_q8_1* bq8i = bq8_1 + bq8_offset + i;
d8[i] = __low2float(bq8i->ds);
const int* q8 = (const int*)bq8i->qs + ((iqs / 2) % 4);
u[2 * i + 0] = q8[0];
u[2 * i + 1] = q8[4];
}
return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q4_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI4_K) + mmq_y / QI4_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 8) + mmq_y / 8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q4_K(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI4_K; // == 0 if QK_K == 256
const int kqsx = k % QI4_K; // == k if QK_K == 256
const block_q4_K* bx0 = (const block_q4_K*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q4_K* bxi = bx0 + i * blocks_per_row + kbx;
x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_K; // == 1 if QK_K == 256
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q4_K* bxi = bx0 + i * blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE_GGUF / QI4_K) + i / QI4_K + kbxd] = bxi->dm;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF / 8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q4_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 8)) / (QI4_K / 8);
const int* scales = (const int*)bxi->scales;
const int ksc = k % (WARP_SIZE_GGUF / 8);
// scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
int scales8 = (scales[(ksc % 2) + (ksc != 0)] >> (4 * (ksc & (ksc / 2)))) & 0x0F0F0F0F; // lower 4 bits
scales8 |= (scales[ksc / 2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + ksc] = scales8;
}
}
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
(void)x_qh;
const uint8_t* sc = ((const uint8_t*)&x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k / 16]) + 2 * ((k % 16) / 8);
const int index_y = j * WARP_SIZE_GGUF + (QR4_K * k) % WARP_SIZE_GGUF;
return vec_dot_q4_K_q8_1_impl_mmq(
&x_ql[i * (WARP_SIZE_GGUF + 1) + k],
&y_qs[index_y],
sc,
sc + 8,
x_dm[i * (WARP_SIZE_GGUF / QI4_K) + i / QI4_K],
&y_ds[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q5_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q5_K* bq5_K = (const block_q5_K*)vbq;
int vl[2];
int vh[2];
int u[2 * QR5_K];
float d8[QR5_K];
const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2));
const int* ql = (const int*)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
const int* qh = (const int*)(bq5_K->qh + 4 * ((iqs / 2) % 4));
vl[0] = ql[0];
vl[1] = ql[4];
vh[0] = qh[0] >> bq8_offset;
vh[1] = qh[4] >> bq8_offset;
const uint16_t* scales = (const uint16_t*)bq5_K->scales;
uint16_t aux[2];
const int j = bq8_offset / 2;
if (j < 2) {
aux[0] = scales[j + 0] & 0x3f3f;
aux[1] = scales[j + 2] & 0x3f3f;
} else {
aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
}
const uint8_t* sc = (const uint8_t*)aux;
const uint8_t* m = sc + 2;
#pragma unroll
for (int i = 0; i < QR5_K; ++i) {
const block_q8_1* bq8i = bq8_1 + bq8_offset + i;
d8[i] = __low2float(bq8i->ds);
const int* q8 = (const int*)bq8i->qs + ((iqs / 2) % 4);
u[2 * i + 0] = q8[0];
u[2 * i + 1] = q8[4];
}
return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q5_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI5_K) + mmq_y / QI5_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 8) + mmq_y / 8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q5_K(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI5_K; // == 0 if QK_K == 256
const int kqsx = k % QI5_K; // == k if QK_K == 256
const block_q5_K* bx0 = (const block_q5_K*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q5_K* bxi = bx0 + i * blocks_per_row + kbx;
const int ky = QR5_K * kqsx;
const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K / 4));
const int qh0 = ((qh >> (2 * (kqsx / (QI5_K / 4)) + 0)) << 4) & 0x10101010;
const int qh1 = ((qh >> (2 * (kqsx / (QI5_K / 4)) + 1)) << 4) & 0x10101010;
const int kq0 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + 0;
const int kq1 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + (QI5_K / 4);
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq0] = ql0 | qh0;
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq1] = ql1 | qh1;
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_K; // == 1 if QK_K == 256
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q5_K* bxi = bx0 + i * blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE_GGUF / QI5_K) + i / QI5_K + kbxd] = bxi->dm;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF / 8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q5_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 8)) / (QI5_K / 8);
const int* scales = (const int*)bxi->scales;
const int ksc = k % (WARP_SIZE_GGUF / 8);
// scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
int scales8 = (scales[(ksc % 2) + (ksc != 0)] >> (4 * (ksc & (ksc / 2)))) & 0x0F0F0F0F; // lower 4 bits
scales8 |= (scales[ksc / 2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + ksc] = scales8;
}
}
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const uint8_t* sc = ((const uint8_t*)&x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k / 16]) + 2 * ((k % 16) / 8);
const int index_x = i * (QR5_K * WARP_SIZE_GGUF + 1) + QR5_K * k;
const int index_y = j * WARP_SIZE_GGUF + (QR5_K * k) % WARP_SIZE_GGUF;
return vec_dot_q5_K_q8_1_impl_mmq(
&x_ql[index_x],
&y_qs[index_y],
sc,
sc + 8,
x_dm[i * (WARP_SIZE_GGUF / QI5_K) + i / QI5_K],
&y_ds[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_q6_K_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_q6_K* bq6_K = (const block_q6_K*)vbq;
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
const int vl = get_int_from_uint8(bq6_K->ql, iqs);
const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
const int8_t* scales = bq6_K->scales + scale_offset;
int u[QR6_K];
float d8[QR6_K];
#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2 * i].qs, iqs % QI8_1);
d8[i] = __low2float(bq8_1[bq8_offset + 2 * i].ds);
}
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, __half2float(bq6_K->d), d8);
}
template <int mmq_y>
static __device__ __forceinline__ void allocate_tiles_q6_K(int** x_ql, half2** x_dm, int** x_qh, int** x_sc) {
__shared__ int tile_x_ql[mmq_y * (2 * WARP_SIZE_GGUF) + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF / QI6_K) + mmq_y / QI6_K];
__shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF / 8) + mmq_y / 8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
template <int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void load_tiles_q6_K(
const void* __restrict__ vx,
int* __restrict__ x_ql,
half2* __restrict__ x_dm,
int* __restrict__ x_qh,
int* __restrict__ x_sc,
const int& i_offset,
const int& i_max,
const int& k,
const int& blocks_per_row) {
const int kbx = k / QI6_K; // == 0 if QK_K == 256
const int kqsx = k % QI6_K; // == k if QK_K == 256
const block_q6_K* bx0 = (const block_q6_K*)vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q6_K* bxi = bx0 + i * blocks_per_row + kbx;
const int ky = QR6_K * kqsx;
const int ql = get_int_from_uint8(bxi->ql, kqsx);
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
const int qh = get_int_from_uint8(bxi->qh, (QI6_K / 4) * (kqsx / (QI6_K / 2)) + kqsx % (QI6_K / 4));
const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K / 2)) / (QI6_K / 4)))) << 4) & 0x30303030;
const int qh1 = (qh >> (2 * ((kqsx % (QI6_K / 2)) / (QI6_K / 4)))) & 0x30303030;
const int kq0 = ky - ky % QI6_K + k % (QI6_K / 2) + 0;
const int kq1 = ky - ky % QI6_K + k % (QI6_K / 2) + (QI6_K / 2);
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
x_ql[i * (2 * WARP_SIZE_GGUF + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
}
const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI6_K; // == 1 if QK_K == 256
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
float* x_dmf = (float*)x_dm;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q6_K* bxi = bx0 + i * blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE_GGUF / QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d);
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF / 8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
}
const block_q6_K* bxi = bx0 + i * blocks_per_row + (k % (WARP_SIZE_GGUF / 8)) / 4;
x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k % (WARP_SIZE_GGUF / 8)] = get_int_from_int8(bxi->scales, k % (QI6_K / 8));
}
}
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
const int* __restrict__ x_ql,
const half2* __restrict__ x_dm,
const int* __restrict__ x_qh,
const int* __restrict__ x_sc,
const int* __restrict__ y_qs,
const half2* __restrict__ y_ds,
const int& i,
const int& j,
const int& k) {
const float* x_dmf = (const float*)x_dm;
const float* y_df = (const float*)y_ds;
const int8_t* sc = ((const int8_t*)&x_sc[i * (WARP_SIZE_GGUF / 8) + i / 8 + k / 8]);
const int index_x = i * (QR6_K * WARP_SIZE_GGUF + 1) + QR6_K * k;
const int index_y = j * WARP_SIZE_GGUF + (QR6_K * k) % WARP_SIZE_GGUF;
return vec_dot_q6_K_q8_1_impl_mmq(
&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE_GGUF / QI6_K) + i / QI6_K], &y_df[index_y / QI8_1]);
}
static __device__ __forceinline__ float
vec_dot_iq2_xxs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_iq2_xxs* bq2 = (const block_iq2_xxs*)vbq;
const int ib32 = iqs;
const uint16_t* q2 = bq2->qs + 4 * ib32;
const uint8_t* aux8 = (const uint8_t*)q2;
const int8_t* q8 = bq8_1[ib32].qs;
uint32_t aux32 = q2[2] | (q2[3] << 16);
int sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[l]);
const uint8_t signs = ksigns_iq2xs[aux32 & 127];
for (int j = 0; j < 8; ++j) {
sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
aux32 >>= 7;
}
const float d = __half2float(bq2->d) * (0.5f + aux32) * __half2float(bq8_1[ib32].ds.x) * 0.25f;
return d * sumi;
}
static __device__ __forceinline__ float
vec_dot_iq2_xs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
const block_iq2_xs* bq2 = (const block_iq2_xs*)vbq;
const int ib32 = iqs;
const uint16_t* q2 = bq2->qs + 4 * ib32;
const int8_t* q8 = bq8_1[ib32].qs;
const uint8_t ls1 = bq2->scales[ib32] & 0xf;
const uint8_t ls2 = bq2->scales[ib32] >> 4;
int sumi1 = 0;
for (int l = 0; l < 2; ++l) {
const uint8_t* grid = (const uint8_t*)(iq2xs_grid + (q2[l] & 511));
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
for (int j = 0; j < 8; ++j) {
sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
int sumi2 = 0;
for (int l = 2; l < 4; ++l) {
const uint8_t* grid = (const uint8_t*)(iq2xs_grid + (q2[l] & 511));
const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
for (int j = 0; j < 8; ++j) {
sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
}
q8 += 8;
}
const float d = __half2float(bq2->d) * __half2float(bq8_1[ib32].ds.x) * 0.25f;
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
}
static __device__ __forceinline__ float
vec_dot_iq2_s_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq2_s* bq2 = (const block_iq2_s*)vbq;
const int ib32 = iqs;
const int8_t* q8 = bq8_1[ib32].qs;
const uint8_t* signs = bq2->qs + QK_K / 8 + 4 * ib32;
const uint8_t ls1 = bq2->scales[ib32] & 0xf;
const uint8_t ls2 = bq2->scales[ib32] >> 4;
int sumi1 = 0;
for (int l = 0; l < 2; ++l) {
const uint32_t* grid =
(const uint32_t*)(iq2s_grid + (bq2->qs[4 * ib32 + l] | ((bq2->qh[ib32] << (8 - 2 * l)) & 0x300)));
const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
sumi1 = __dp4a(grid_l, *((const int*)q8 + 0), sumi1);
sumi1 = __dp4a(grid_h, *((const int*)q8 + 1), sumi1);
q8 += 8;
}
int sumi2 = 0;
for (int l = 2; l < 4; ++l) {
const uint32_t* grid =
(const uint32_t*)(iq2s_grid + (bq2->qs[4 * ib32 + l] | ((bq2->qh[ib32] << (8 - 2 * l)) & 0x300)));
const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(grid[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid[1] ^ signs1, signs1);
sumi2 = __dp4a(grid_l, *((const int*)q8 + 0), sumi2);
sumi2 = __dp4a(grid_h, *((const int*)q8 + 1), sumi2);
q8 += 8;
}
const float d = __half2float(bq2->d) * __low2float(bq8_1[ib32].ds) * 0.25f;
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#endif
}
static __device__ __forceinline__ float
vec_dot_iq3_xxs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq3_xxs* bq2 = (const block_iq3_xxs*)vbq;
const int ib32 = iqs;
const uint8_t* q3 = bq2->qs + 8 * ib32;
const uint16_t* gas = (const uint16_t*)(bq2->qs + QK_K / 4) + 2 * ib32;
const int8_t* q8 = bq8_1[ib32].qs;
uint32_t aux32 = gas[0] | (gas[1] << 16);
int sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint32_t* grid1 = iq3xxs_grid + q3[2 * l + 0];
const uint32_t* grid2 = iq3xxs_grid + q3[2 * l + 1];
const uint32_t* signs = (const uint32_t*)(ksigns64 + (aux32 & 127));
const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
sumi = __dp4a(grid_l, *((int*)q8 + 0), sumi);
sumi = __dp4a(grid_h, *((int*)q8 + 1), sumi);
q8 += 8;
aux32 >>= 7;
}
const float d = __half2float(bq2->d) * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f;
return d * sumi;
#endif
}
static __device__ __forceinline__ float
vec_dot_iq3_s_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq3_s* bq2 = (const block_iq3_s*)vbq;
const int ib32 = iqs;
const uint8_t* qs = bq2->qs + 8 * ib32;
const int8_t* q8 = bq8_1[ib32].qs;
int sumi = 0;
for (int l = 0; l < 4; ++l) {
const uint32_t* grid1 = iq3xs_grid + (qs[2 * l + 0] | ((bq2->qh[ib32] << (8 - 2 * l)) & 256));
const uint32_t* grid2 = iq3xs_grid + (qs[2 * l + 1] | ((bq2->qh[ib32] << (7 - 2 * l)) & 256));
uint32_t signs0 = __vcmpeq4(((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
sumi = __dp4a(grid_l, *((int*)q8 + 0), sumi);
sumi = __dp4a(grid_h, *((int*)q8 + 1), sumi);
q8 += 8;
}
const float d = __half2float(bq2->d) * (0.5f + ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
__low2float(bq8_1[ib32].ds) * 0.5f;
return d * sumi;
#endif
}
static __device__ __forceinline__ float
vec_dot_iq1_s_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq1_s* bq1 = (const block_iq1_s*)vbq;
const int qs_packed = get_int_b2(bq1->qs, iqs);
const uint8_t* qs = (const uint8_t*)&qs_packed;
const int qh = bq1->qh[iqs];
int sumi = 0;
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int grid = iq1s_grid_gpu[qs[l0 / 2] | (((qh >> 3 * (l0 / 2)) & 0x07) << 8)];
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
sumi = __dp4a(grid0, u0, sumi);
sumi = __dp4a(grid1, u1, sumi);
}
const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f * IQ1S_DELTA / 0x8000);
const float2 ds = __half22float2(bq8_1[iqs].ds);
return d1q * (ds.x * sumi + ds.y * delta);
#endif
}
static __device__ __forceinline__ float
vec_dot_iq1_m_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq1_m* bq1 = (const block_iq1_m*)vbq;
const int qs_packed = get_int_b4(bq1->qs, iqs);
const uint8_t* qs = (const uint8_t*)&qs_packed;
int sumi[2] = {0};
float sumf[2] = {0.0f};
#pragma unroll
for (int l0 = 0; l0 < 8; l0 += 2) {
const int qhl = bq1->qh[2 * iqs + l0 / 4] >> (4 * ((l0 / 2) % 2));
const int grid = iq1s_grid_gpu[qs[l0 / 2] | ((qhl & 0x07) << 8)];
const int grid0 = (grid >> 0) & 0x0F0F0F0F;
const int grid1 = (grid >> 4) & 0x0F0F0F0F;
const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
sumi[l0 / 4] = __dp4a(grid0, u0, sumi[l0 / 4]);
sumi[l0 / 4] = __dp4a(grid1, u1, sumi[l0 / 4]);
const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f * IQ1M_DELTA / 0x08);
int sumy = 0;
sumy = __dp4a(u0, 0x01010101, sumy);
sumy = __dp4a(u1, 0x01010101, sumy);
sumf[l0 / 4] += delta * sumy;
}
const uint16_t* sc = (const uint16_t*)bq1->scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);
const int tmp = sc[iqs / 2] >> (6 * (iqs % 2));
const int sc0 = 2 * ((tmp >> 0) & 0x07) + 1;
const int sc1 = 2 * ((tmp >> 3) & 0x07) + 1;
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
#endif
}
static __device__ __forceinline__ void
get_int_from_table_16(const uint32_t& q4, const uint8_t* values, int& val1, int& val2) {
uint32_t aux32;
const uint8_t* q8 = (const uint8_t*)&aux32;
aux32 = q4 & 0x0f0f0f0f;
uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
val1 = v1 | (v2 << 16);
aux32 = (q4 >> 4) & 0x0f0f0f0f;
v1 = values[q8[0]] | (values[q8[1]] << 8);
v2 = values[q8[2]] | (values[q8[3]] << 8);
val2 = v1 | (v2 << 16);
}
static __device__ __forceinline__ float
vec_dot_iq4_nl_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq4_nl* bq = (const block_iq4_nl*)vbq;
const uint16_t* q4 = (const uint16_t*)bq->qs + 2 * iqs;
const int32_t* q8 = (const int32_t*)bq8_1->qs + iqs;
const uint8_t* values = (const uint8_t*)kvalues_iq4nl;
int v1, v2;
int sumi1 = 0, sumi2 = 0;
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const uint32_t aux = q4[2 * l] | (q4[2 * l + 1] << 16);
get_int_from_table_16(aux, values, v1, v2);
sumi1 = __dp4a(v1, q8[l + 0], sumi1);
sumi2 = __dp4a(v2, q8[l + 4], sumi2);
}
const float d = __half2float(bq->d) * __low2float(bq8_1->ds);
return d * (sumi1 + sumi2);
#endif
}
static __device__ __forceinline__ float
vec_dot_iq4_xs_q8_1(const void* __restrict__ vbq, const block_q8_1* __restrict__ bq8_1, const int& iqs) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM
const block_iq4_xs* bq4 = (const block_iq4_xs*)vbq;
const uint8_t* values = (const uint8_t*)kvalues_iq4nl;
// iqs is 0...7
const int ib32 = iqs;
const int32_t* q8 = (const int*)bq8_1[ib32].qs;
const uint32_t* q4 = (const uint32_t*)bq4->qs + 4 * ib32;
const int8_t ls = ((bq4->scales_l[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf) | (((bq4->scales_h >> 2 * ib32) & 3) << 4);
const float d = __half2float(bq4->d) * (ls - 32) * __low2float(bq8_1[ib32].ds);
int v1, v2;
int sumi1 = 0, sumi2 = 0;
for (int j = 0; j < 4; ++j) {
get_int_from_table_16(q4[j], values, v1, v2);
sumi1 = __dp4a(v1, q8[j + 0], sumi1);
sumi2 = __dp4a(v2, q8[j + 4], sumi2);
}
return d * (sumi1 + sumi2);
#endif
}
......@@ -186,6 +186,32 @@ void fast_topk_transform_interface(
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From gguf quantization
*/
torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_moe_a8(
torch::Tensor X,
torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded,
int64_t type,
int64_t row,
int64_t top_k,
int64_t tokens);
torch::Tensor ggml_moe_a8_vec(
torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
/*
* From csrc/gemm
*/
......@@ -306,6 +332,8 @@ void topk_softmax(
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
......
......@@ -19,6 +19,10 @@ limitations under the License.
#include <cuda_runtime.h>
#include <torch/all.h>
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#ifdef USE_ROCM
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
#define _DISPATCH_CASE_F16(c_type, ...) \
......@@ -326,6 +330,13 @@ inline bool getEnvEnablePDL() {
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define DISPATCH_CASE_FLOAT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOAT_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#ifndef USE_ROCM
......@@ -447,3 +458,12 @@ inline uint32_t next_pow2(uint32_t x) noexcept {
if (x <= 1) return 1;
return 1u << (32 - __builtin_clz(x - 1));
}
/*
* LDG Support
*/
#ifndef USE_ROCM
#define SGLANG_LDG(arg) __ldg(arg)
#else
#define SGLANG_LDG(arg) *(arg)
#endif
......@@ -288,10 +288,19 @@ from sgl_kernel.moe import (
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,
moe_fused_gate,
moe_sum,
moe_sum_reduce,
prepare_moe_input,
topk_softmax,
)
from sgl_kernel.quantization import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
from sgl_kernel.sampling import (
min_p_sampling_from_probs,
top_k_mask_logits,
......
......@@ -48,6 +48,16 @@ def moe_sum_reduce(
)
def moe_sum(
input_tensor: torch.Tensor,
output_tensor: torch.Tensor,
):
torch.ops.sgl_kernel.moe_sum.default(
input_tensor,
output_tensor,
)
def moe_fused_gate(
input_tensor,
bias,
......
from .gguf import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
import torch
def ggml_dequantize(
weight: torch.Tensor, quant_type: int, M: int, N: int, dtype: torch.dtype
):
assert M > 0 and N > 0, "GGUF weight Input shape must be of positive dimensions"
return torch.ops.sgl_kernel.ggml_dequantize.default(weight, quant_type, M, N, dtype)
def ggml_mul_mat_vec_a8(
weight: torch.Tensor, x: torch.Tensor, quant_type: int, row: int
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_mul_mat_vec_a8.default(weight, x, quant_type, row)
def ggml_mul_mat_a8(
weight: torch.Tensor, x: torch.Tensor, quant_type: int, row: int
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_mul_mat_a8.default(weight, x, quant_type, row)
def ggml_moe_a8(
input: torch.Tensor,
weight: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_token_post_padded: torch.Tensor,
type: int,
row: int,
topk: int,
tokens: int,
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_moe_a8.default(
input,
weight,
sorted_token_ids,
expert_ids,
num_token_post_padded,
type,
row,
topk,
tokens,
)
def ggml_moe_a8_vec(
input: torch.Tensor,
weight: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
type: int,
row: int,
tokens: int,
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_moe_a8_vec.default(
input, weight, topk_ids, top_k, type, row, tokens
)
def ggml_moe_get_block_size(type: int) -> int:
return torch.ops.sgl_kernel.ggml_moe_get_block_size.default(type)
# SPDX-License-Identifier: Apache-2.0
import random
from pathlib import Path
import numpy as np
import pytest
import torch
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download
from sgl_kernel import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
def get_gguf_sample_tensors(
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors
def get_gguf_MoE_tensors(
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE_MOE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently.
HIDDEN_SIZES = [256, 1024]
NUM_TOKENS = [7, 2050] # Arbitrary values for testing
SEEDS = [0]
QUANT_TYPES = [
# i-matrix
GGMLQuantizationType.IQ1_M,
GGMLQuantizationType.IQ1_S,
GGMLQuantizationType.IQ2_S,
GGMLQuantizationType.IQ2_XS,
GGMLQuantizationType.IQ3_S,
GGMLQuantizationType.IQ3_XXS,
GGMLQuantizationType.IQ4_NL,
GGMLQuantizationType.IQ4_XS,
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quantization
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
]
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType
):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
for tensor in tensors:
shape_str = tensor.name.split("_")[-1]
shape = map(int, shape_str.split("x"))
ref_output = torch.tensor(
dequantize(tensor.data, quant_type), device="cuda"
).to(dtype)
output = ggml_dequantize(
torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype
)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
output = ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(dtype)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"quant_type",
[
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quants
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
],
)
@torch.inference_mode()
def test_mmq(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_type: GGMLQuantizationType,
):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
output = ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
# test matrix has inputs centered around 0 and lower precision from
# bfloat16 tends to accumulate and can greatly inflate rtol
# since outputs are also very close to 0
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
torch.testing.assert_close(
output, ref_output, atol=atols[dtype], rtol=rtols[dtype]
)
if __name__ == "__main__":
pytest.main([__file__])
......@@ -4,7 +4,14 @@ import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size
from sgl_kernel import moe_align_block_size, moe_sum
def is_hip() -> bool:
return torch.version.hip is not None
_is_hip = is_hip()
def ceil_div(a, b):
......@@ -246,5 +253,20 @@ def test_moe_align_block_size_compare_implementations(
)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(_is_hip, reason="Skip for AMD GPU")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
actual = torch.empty((m, k), device="cuda", dtype=dtype)
expected = input.sum(dim=1)
moe_sum(input, actual)
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment