Unverified commit 8fdcd98e authored by PGFLMG, committed by GitHub

[7/n] decouple quantization impl from vllm dependency - gguf kernel (#11019)

parent b5dcfd41
@@ -271,6 +271,8 @@ set(SOURCES
"csrc/elementwise/topk.cu"
"csrc/common_extension.cc"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/gemm/awq_kernel.cu" "csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu" "csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu" "csrc/gemm/dsv3_fused_a_gemm.cu"
...@@ -306,6 +308,7 @@ set(SOURCES ...@@ -306,6 +308,7 @@ set(SOURCES
"csrc/moe/marlin_moe_wna16/ops.cu" "csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_sum.cu"
"csrc/moe/moe_sum_reduce.cu" "csrc/moe/moe_sum_reduce.cu"
"csrc/moe/moe_topk_softmax_kernels.cu" "csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/moe/nvfp4_blockwise_moe.cu" "csrc/moe/nvfp4_blockwise_moe.cu"
......
@@ -114,6 +114,37 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"cu_seqlens_q) -> ()");
m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
/*
* From gguf quantization
*/
m.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
m.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
m.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) "
"-> Tensor");
m.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
m.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
m.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
m.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
m.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
m.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
m.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
/*
* From csrc/gemm
*/
@@ -226,17 +257,23 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()");
m.impl("moe_sum_reduce", torch::kCUDA, &moe_sum_reduce);
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
m.def(
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
"stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
"expert_offsets, Tensor workspace) -> ()");
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
m.def(
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
" Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
......
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <ATen/cuda/Atomic.cuh>
#include <cub/cub.cuh>
#include "utils.h"
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., topk, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
scalar_t x = 0.0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += SGLANG_LDG(&input[token_idx * TOPK * d + k * d + idx]);
}
out[token_idx * d + idx] = x;
}
}
void moe_sum(
torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
const int hidden_size = input.size(-1);
const auto num_tokens = output.numel() / hidden_size;
const int topk = input.size(1);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (topk) {
case 2:
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
moe_sum_kernel<scalar_t, 2>
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
});
break;
case 3:
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
moe_sum_kernel<scalar_t, 3>
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
});
break;
case 4:
DISPATCH_FLOAT_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
moe_sum_kernel<scalar_t, 4>
<<<grid, block, 0, stream>>>(output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), hidden_size);
});
break;
default:
at::sum_out(output, input, 1);
break;
}
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/mmvq.cuh
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int ncols,
const int nrows,
const int nvecs) {
const auto row = blockIdx.x * blockDim.y + threadIdx.y;
const auto vec = blockIdx.y;
if (row >= nrows || vec >= nvecs) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
const int nrows_y = (ncols + 512 - 1) / 512 * 512;
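// the q8_1 activation is laid out with its column count padded up to a multiple of 512,
// so nrows_y / QK8_1 below is the per-vector stride in q8_1 blocks (presumably matching
// the padding applied when the activation was quantized)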
// partial sum for each thread
float tmp = 0.0f;
const block_q_t* x = (const block_q_t*)vx;
const block_q8_1* y = (const block_q8_1*)vy;
for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row * blocks_per_row + i; // x block index
const int iby = vec * (nrows_y / QK8_1) + i * (qk / QK8_1); // y block index that aligns with ibx
const int iqs = vdr * (threadIdx.x % (qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += SGLANG_SHFL_XOR_SYNC(uint32_t(-1), tmp, mask);
}
if (threadIdx.x == 0) {
dst[vec * nrows + row] = tmp;
}
}
template <typename scalar_t>
static void mul_mat_vec_q4_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q4_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q5_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q5_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q8_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q2_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q3_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q4_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q5_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_q6_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq2_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq2_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq1_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq1_m_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq4_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
template <typename scalar_t>
static void mul_mat_vec_iq3_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int ncols,
const int nrows,
const int nvecs,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, nvecs, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs);
}
// copied from
// https://github.com/vllm-project/vllm/blob/4492e3a55428e161ca8db381edc28263e5da4c8d/csrc/quantization/gguf/moe_vec.cuh
// copied and adapted from
// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void moe_vec_q(
const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* topk_ids,
const int topk,
const int ncols,
const int nrows,
const int token_stride) {
const auto row = blockIdx.x * blockDim.y + threadIdx.y;
const auto token = blockIdx.z / topk;
const auto expert = (topk_ids)[blockIdx.z];
if (row >= nrows) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
// partial sum for each thread
float tmp = 0.0f;
const block_q_t* x = ((const block_q_t*)vx) + expert * nrows * blocks_per_row;
const block_q8_1* y = (const block_q8_1*)(((const int*)vy) + token * token_stride);
for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row * blocks_per_row + i; // x block index
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
const int iqs = vdr * (threadIdx.x % (qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += SGLANG_SHFL_XOR_SYNC(uint32_t(-1), tmp, mask);
}
if (threadIdx.x == 0) {
dst[blockIdx.z * nrows + row] = tmp;
}
}
template <typename scalar_t>
static void moe_vec_q4_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q4_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_1_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q8_0_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q2_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q3_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q4_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q6_K_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq3_xxs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq1_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq1_m_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq4_nl_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq4_xs_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq3_s_q8_1_cuda(
const void* vx,
const void* vy,
scalar_t* dst,
const int* topk_ids,
const int top_k,
const int tokens,
const int ncols,
const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
@@ -186,6 +186,32 @@ void fast_topk_transform_interface(
void gelu_quick(at::Tensor& out, const at::Tensor& input);
#endif
/*
* From gguf quantization
*/
torch::Tensor
ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n, std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row);
torch::Tensor ggml_moe_a8(
torch::Tensor X,
torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded,
int64_t type,
int64_t row,
int64_t top_k,
int64_t tokens);
torch::Tensor ggml_moe_a8_vec(
torch::Tensor X, torch::Tensor W, torch::Tensor topk_ids, int64_t top_k, int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
/*
* From csrc/gemm
*/
@@ -306,6 +332,8 @@ void topk_softmax(
void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
......
@@ -19,6 +19,10 @@ limitations under the License.
#include <cuda_runtime.h>
#include <torch/all.h>
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#ifdef USE_ROCM
// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491)
#define _DISPATCH_CASE_F16(c_type, ...) \
@@ -326,6 +330,13 @@ inline bool getEnvEnablePDL() {
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define DISPATCH_CASE_FLOAT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOAT_TYPES(__VA_ARGS__))
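// DISPATCH_FLOAT_TYPES expands to an AT_DISPATCH_SWITCH over float32/float16/bfloat16;
// inside the lambda, scalar_t is bound to the matching C++ type (used by moe_sum.cu earlier in this diff)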
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#ifndef USE_ROCM
@@ -447,3 +458,12 @@ inline uint32_t next_pow2(uint32_t x) noexcept {
if (x <= 1) return 1;
return 1u << (32 - __builtin_clz(x - 1));
}
/*
* LDG Support
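* __ldg reads through the read-only data cache on NVIDIA GPUs; the ROCm path below
* falls back to a plain dereference.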
*/
#ifndef USE_ROCM
#define SGLANG_LDG(arg) __ldg(arg)
#else
#define SGLANG_LDG(arg) *(arg)
#endif
@@ -288,10 +288,19 @@ from sgl_kernel.moe import (
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,
moe_fused_gate,
moe_sum,
moe_sum_reduce,
prepare_moe_input,
topk_softmax,
)
from sgl_kernel.quantization import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
from sgl_kernel.sampling import (
min_p_sampling_from_probs,
top_k_mask_logits,
......
@@ -48,6 +48,16 @@ def moe_sum_reduce(
)
def moe_sum(
input_tensor: torch.Tensor,
output_tensor: torch.Tensor,
):
torch.ops.sgl_kernel.moe_sum.default(
input_tensor,
output_tensor,
)
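# A minimal usage sketch (shapes follow the kernel's [num_tokens, topk, hidden_size]
# layout; the concrete values below are illustrative):
#
#   x = torch.randn(4, 2, 128, device="cuda", dtype=torch.float16)
#   out = torch.empty(4, 128, device="cuda", dtype=torch.float16)
#   moe_sum(x, out)  # same result as x.sum(dim=1); topk in {2, 3, 4} hits the fused kernel
#   torch.testing.assert_close(out, x.sum(dim=1), atol=2e-2, rtol=0)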
def moe_fused_gate(
input_tensor,
bias,
......
from .gguf import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
import torch
def ggml_dequantize(
weight: torch.Tensor, quant_type: int, M: int, N: int, dtype: torch.dtype
):
    assert M > 0 and N > 0, "GGUF weight input shape must have positive dimensions"
return torch.ops.sgl_kernel.ggml_dequantize.default(weight, quant_type, M, N, dtype)
def ggml_mul_mat_vec_a8(
weight: torch.Tensor, x: torch.Tensor, quant_type: int, row: int
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_mul_mat_vec_a8.default(weight, x, quant_type, row)
def ggml_mul_mat_a8(
weight: torch.Tensor, x: torch.Tensor, quant_type: int, row: int
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_mul_mat_a8.default(weight, x, quant_type, row)
def ggml_moe_a8(
input: torch.Tensor,
weight: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_token_post_padded: torch.Tensor,
type: int,
row: int,
topk: int,
tokens: int,
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_moe_a8.default(
input,
weight,
sorted_token_ids,
expert_ids,
num_token_post_padded,
type,
row,
topk,
tokens,
)
def ggml_moe_a8_vec(
input: torch.Tensor,
weight: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
type: int,
row: int,
tokens: int,
) -> torch.Tensor:
return torch.ops.sgl_kernel.ggml_moe_a8_vec.default(
input, weight, topk_ids, top_k, type, row, tokens
)
def ggml_moe_get_block_size(type: int) -> int:
return torch.ops.sgl_kernel.ggml_moe_get_block_size.default(type)
# SPDX-License-Identifier: Apache-2.0
import random
from pathlib import Path
import numpy as np
import pytest
import torch
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download
from sgl_kernel import (
ggml_dequantize,
ggml_moe_a8,
ggml_moe_a8_vec,
ggml_moe_get_block_size,
ggml_mul_mat_a8,
ggml_mul_mat_vec_a8,
)
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
def get_gguf_sample_tensors(
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors
def get_gguf_MoE_tensors(
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE_MOE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# Hidden sizes for testing; these must match the sample files in the HF repo,
# which currently provides hidden_size = 256 and 1024.
HIDDEN_SIZES = [256, 1024]
NUM_TOKENS = [7, 2050] # Arbitrary values for testing
SEEDS = [0]
QUANT_TYPES = [
# i-matrix
GGMLQuantizationType.IQ1_M,
GGMLQuantizationType.IQ1_S,
GGMLQuantizationType.IQ2_S,
GGMLQuantizationType.IQ2_XS,
GGMLQuantizationType.IQ3_S,
GGMLQuantizationType.IQ3_XXS,
GGMLQuantizationType.IQ4_NL,
GGMLQuantizationType.IQ4_XS,
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quantization
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
]
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType
):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
for tensor in tensors:
shape_str = tensor.name.split("_")[-1]
shape = map(int, shape_str.split("x"))
ref_output = torch.tensor(
dequantize(tensor.data, quant_type), device="cuda"
).to(dtype)
output = ggml_dequantize(
torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype
)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
output = ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(dtype)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"quant_type",
[
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quants
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
],
)
@torch.inference_mode()
def test_mmq(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_type: GGMLQuantizationType,
):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
output = ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
    # the test matrices have inputs centered around 0; error from bfloat16's lower
    # precision accumulates and can greatly inflate rtol because the outputs are
    # also very close to 0
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
torch.testing.assert_close(
output, ref_output, atol=atols[dtype], rtol=rtols[dtype]
)
if __name__ == "__main__":
pytest.main([__file__])
@@ -4,7 +4,14 @@ import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size, moe_sum
def is_hip() -> bool:
return torch.version.hip is not None
_is_hip = is_hip()
def ceil_div(a, b):
@@ -246,5 +253,20 @@ def test_moe_align_block_size_compare_implementations(
)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(_is_hip, reason="Skip for AMD GPU")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
actual = torch.empty((m, k), device="cuda", dtype=dtype)
expected = input.sum(dim=1)
moe_sum(input, actual)
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)
if __name__ == "__main__":
pytest.main([__file__])