Unverified Commit a6081600 authored by Szymon Ożóg's avatar Szymon Ożóg Committed by GitHub
Browse files

[Kernel] Fix conflicting macro names for gguf kernels (#15456)


Signed-off-by: default avatarSzymonOzog <szymon.ozog@gmail.com>
parent 3f04a7fb
...@@ -375,25 +375,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input ...@@ -375,25 +375,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
int64_t ggml_moe_get_block_size(int64_t type) { int64_t ggml_moe_get_block_size(int64_t type) {
switch (type) { switch (type) {
case 2: case 2:
return MMQ_X_Q4_0; return MOE_X_Q4_0;
case 3: case 3:
return MMQ_X_Q4_1; return MOE_X_Q4_1;
case 6: case 6:
return MMQ_X_Q5_0; return MOE_X_Q5_0;
case 7: case 7:
return MMQ_X_Q5_1; return MOE_X_Q5_1;
case 8: case 8:
return MMQ_X_Q8_0; return MOE_X_Q8_0;
case 10: case 10:
return MMQ_X_Q2_K; return MOE_X_Q2_K;
case 11: case 11:
return MMQ_X_Q3_K; return MOE_X_Q3_K;
case 12: case 12:
return MMQ_X_Q4_K; return MOE_X_Q4_K;
case 13: case 13:
return MMQ_X_Q5_K; return MOE_X_Q5_K;
case 14: case 14:
return MMQ_X_Q6_K; return MOE_X_Q6_K;
} }
return 0; return 0;
} }
...@@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q( ...@@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q4_0 64 #define MOE_X_Q4_0 64
#define MMQ_Y_Q4_0 128 #define MOE_Y_Q4_0 128
#define NWARPS_Q4_0 8 #define NWARPS_Q4_0 8
#else #else
#define MMQ_X_Q4_0 4 #define MOE_X_Q4_0 4
#define MMQ_Y_Q4_0 32 #define MOE_Y_Q4_0 32
#define NWARPS_Q4_0 4 #define NWARPS_Q4_0 4
#endif #endif
...@@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2) ...@@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q4_0; const int mmq_x = MOE_X_Q4_0;
const int mmq_y = MMQ_Y_Q4_0; const int mmq_y = MOE_Y_Q4_0;
const int nwarps = NWARPS_Q4_0; const int nwarps = NWARPS_Q4_0;
moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
...@@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda( ...@@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_0; int mmq_x = MOE_X_Q4_0;
int mmq_y = MMQ_Y_Q4_0; int mmq_y = MOE_Y_Q4_0;
int nwarps = NWARPS_Q4_0; int nwarps = NWARPS_Q4_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda( ...@@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q4_1 64 #define MOE_X_Q4_1 64
#define MMQ_Y_Q4_1 128 #define MOE_Y_Q4_1 128
#define NWARPS_Q4_1 8 #define NWARPS_Q4_1 8
#else #else
#define MMQ_X_Q4_1 4 #define MOE_X_Q4_1 4
#define MMQ_Y_Q4_1 32 #define MOE_Y_Q4_1 32
#define NWARPS_Q4_1 4 #define NWARPS_Q4_1 4
#endif #endif
...@@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2) ...@@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q4_1; const int mmq_x = MOE_X_Q4_1;
const int mmq_y = MMQ_Y_Q4_1; const int mmq_y = MOE_Y_Q4_1;
const int nwarps = NWARPS_Q4_1; const int nwarps = NWARPS_Q4_1;
moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
...@@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda( ...@@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_1; int mmq_x = MOE_X_Q4_1;
int mmq_y = MMQ_Y_Q4_1; int mmq_y = MOE_Y_Q4_1;
int nwarps = NWARPS_Q4_1; int nwarps = NWARPS_Q4_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda( ...@@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q5_0 64 #define MOE_X_Q5_0 64
#define MMQ_Y_Q5_0 128 #define MOE_Y_Q5_0 128
#define NWARPS_Q5_0 8 #define NWARPS_Q5_0 8
#else #else
#define MMQ_X_Q5_0 4 #define MOE_X_Q5_0 4
#define MMQ_Y_Q5_0 32 #define MOE_Y_Q5_0 32
#define NWARPS_Q5_0 4 #define NWARPS_Q5_0 4
#endif #endif
...@@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2) ...@@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q5_0; const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0; const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0; const int nwarps = NWARPS_Q5_0;
moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
...@@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda( ...@@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_0; const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0; const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0; const int nwarps = NWARPS_Q5_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda( ...@@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q5_1 64 #define MOE_X_Q5_1 64
#define MMQ_Y_Q5_1 128 #define MOE_Y_Q5_1 128
#define NWARPS_Q5_1 8 #define NWARPS_Q5_1 8
#else #else
#define MMQ_X_Q5_1 4 #define MOE_X_Q5_1 4
#define MMQ_Y_Q5_1 32 #define MOE_Y_Q5_1 32
#define NWARPS_Q5_1 4 #define NWARPS_Q5_1 4
#endif #endif
...@@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2) ...@@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q5_1; const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1; const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1; const int nwarps = NWARPS_Q5_1;
moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
...@@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda( ...@@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_1; const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1; const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1; const int nwarps = NWARPS_Q5_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda( ...@@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q8_0 64 #define MOE_X_Q8_0 64
#define MMQ_Y_Q8_0 128 #define MOE_Y_Q8_0 128
#define NWARPS_Q8_0 8 #define NWARPS_Q8_0 8
#else #else
#define MMQ_X_Q8_0 4 #define MOE_X_Q8_0 4
#define MMQ_Y_Q8_0 32 #define MOE_Y_Q8_0 32
#define NWARPS_Q8_0 4 #define NWARPS_Q8_0 4
#endif #endif
...@@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2) ...@@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q8_0; const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0; const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0; const int nwarps = NWARPS_Q8_0;
moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
...@@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda( ...@@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q8_0; const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0; const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0; const int nwarps = NWARPS_Q8_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda( ...@@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q2_K 64 #define MOE_X_Q2_K 64
#define MMQ_Y_Q2_K 128 #define MOE_Y_Q2_K 128
#define NWARPS_Q2_K 8 #define NWARPS_Q2_K 8
#else #else
#define MMQ_X_Q2_K 4 #define MOE_X_Q2_K 4
#define MMQ_Y_Q2_K 32 #define MOE_Y_Q2_K 32
#define NWARPS_Q2_K 4 #define NWARPS_Q2_K 4
#endif #endif
...@@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2) ...@@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q2_K; const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K; const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K; const int nwarps = NWARPS_Q2_K;
moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
...@@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda( ...@@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q2_K; const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K; const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K; const int nwarps = NWARPS_Q2_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda( ...@@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q3_K 64 #define MOE_X_Q3_K 64
#define MMQ_Y_Q3_K 128 #define MOE_Y_Q3_K 128
#define NWARPS_Q3_K 8 #define NWARPS_Q3_K 8
#else #else
#define MMQ_X_Q3_K 4 #define MOE_X_Q3_K 4
#define MMQ_Y_Q3_K 32 #define MOE_Y_Q3_K 32
#define NWARPS_Q3_K 4 #define NWARPS_Q3_K 4
#endif #endif
...@@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2) ...@@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q3_K; const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K; const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K; const int nwarps = NWARPS_Q3_K;
moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
...@@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda( ...@@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q3_K; const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K; const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K; const int nwarps = NWARPS_Q3_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda( ...@@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q4_K 64 #define MOE_X_Q4_K 64
#define MMQ_Y_Q4_K 128 #define MOE_Y_Q4_K 128
#define NWARPS_Q4_K 8 #define NWARPS_Q4_K 8
#else #else
#define MMQ_X_Q4_K 4 #define MOE_X_Q4_K 4
#define MMQ_Y_Q4_K 32 #define MOE_Y_Q4_K 32
#define NWARPS_Q4_K 4 #define NWARPS_Q4_K 4
#endif #endif
...@@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2) ...@@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q4_K; const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K; const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K; const int nwarps = NWARPS_Q4_K;
moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
...@@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda( ...@@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q4_K; const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K; const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K; const int nwarps = NWARPS_Q4_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda( ...@@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q5_K 64 #define MOE_X_Q5_K 64
#define MMQ_Y_Q5_K 128 #define MOE_Y_Q5_K 128
#define NWARPS_Q5_K 8 #define NWARPS_Q5_K 8
#else #else
#define MMQ_X_Q5_K 4 #define MOE_X_Q5_K 4
#define MMQ_Y_Q5_K 32 #define MOE_Y_Q5_K 32
#define NWARPS_Q5_K 4 #define NWARPS_Q5_K 4
#endif #endif
...@@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2) ...@@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q5_K; const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K; const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K; const int nwarps = NWARPS_Q5_K;
moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
...@@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda( ...@@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_K; const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K; const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K; const int nwarps = NWARPS_Q5_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
...@@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda( ...@@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda(
} }
#if defined(USE_ROCM) #if defined(USE_ROCM)
#define MMQ_X_Q6_K 64 #define MOE_X_Q6_K 64
#define MMQ_Y_Q6_K 128 #define MOE_Y_Q6_K 128
#define NWARPS_Q6_K 8 #define NWARPS_Q6_K 8
#else #else
#define MMQ_X_Q6_K 4 #define MOE_X_Q6_K 4
#define MMQ_Y_Q6_K 32 #define MOE_Y_Q6_K 32
#define NWARPS_Q6_K 4 #define NWARPS_Q6_K 4
#endif #endif
...@@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2) ...@@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) { const int top_k) {
const int mmq_x = MMQ_X_Q6_K; const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K; const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K; const int nwarps = NWARPS_Q6_K;
moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
...@@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda( ...@@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x, const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k, const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) { const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q6_K; const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K; const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K; const int nwarps = NWARPS_Q6_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment