Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a6081600
Unverified
Commit
a6081600
authored
Mar 25, 2025
by
Szymon Ożóg
Committed by
GitHub
Mar 25, 2025
Browse files
[Kernel] Fix conflicting macro names for gguf kernels (#15456)
Signed-off-by:
SzymonOzog
<
szymon.ozog@gmail.com
>
parent
3f04a7fb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
90 additions
and
90 deletions
+90
-90
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+10
-10
csrc/quantization/gguf/moe.cuh
csrc/quantization/gguf/moe.cuh
+80
-80
No files found.
csrc/quantization/gguf/gguf_kernel.cu
View file @
a6081600
...
...
@@ -375,25 +375,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
int64_t
ggml_moe_get_block_size
(
int64_t
type
)
{
switch
(
type
)
{
case
2
:
return
M
MQ
_X_Q4_0
;
return
M
OE
_X_Q4_0
;
case
3
:
return
M
MQ
_X_Q4_1
;
return
M
OE
_X_Q4_1
;
case
6
:
return
M
MQ
_X_Q5_0
;
return
M
OE
_X_Q5_0
;
case
7
:
return
M
MQ
_X_Q5_1
;
return
M
OE
_X_Q5_1
;
case
8
:
return
M
MQ
_X_Q8_0
;
return
M
OE
_X_Q8_0
;
case
10
:
return
M
MQ
_X_Q2_K
;
return
M
OE
_X_Q2_K
;
case
11
:
return
M
MQ
_X_Q3_K
;
return
M
OE
_X_Q3_K
;
case
12
:
return
M
MQ
_X_Q4_K
;
return
M
OE
_X_Q4_K
;
case
13
:
return
M
MQ
_X_Q5_K
;
return
M
OE
_X_Q5_K
;
case
14
:
return
M
MQ
_X_Q6_K
;
return
M
OE
_X_Q6_K
;
}
return
0
;
}
csrc/quantization/gguf/moe.cuh
View file @
a6081600
...
...
@@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q4_0 64
#define M
MQ
_Y_Q4_0 128
#define M
OE
_X_Q4_0 64
#define M
OE
_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define M
MQ
_X_Q4_0 4
#define M
MQ
_Y_Q4_0 32
#define M
OE
_X_Q4_0 4
#define M
OE
_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
...
...
@@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_0
;
const
int
mmq_y
=
M
MQ
_Y_Q4_0
;
const
int
mmq_x
=
M
OE
_X_Q4_0
;
const
int
mmq_y
=
M
OE
_Y_Q4_0
;
const
int
nwarps
=
NWARPS_Q4_0
;
moe_q
<
scalar_t
,
QK4_0
,
QR4_0
,
QI4_0
,
true
,
block_q4_0
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
int
mmq_x
=
M
MQ
_X_Q4_0
;
int
mmq_y
=
M
MQ
_Y_Q4_0
;
int
mmq_x
=
M
OE
_X_Q4_0
;
int
mmq_y
=
M
OE
_Y_Q4_0
;
int
nwarps
=
NWARPS_Q4_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q4_1 64
#define M
MQ
_Y_Q4_1 128
#define M
OE
_X_Q4_1 64
#define M
OE
_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define M
MQ
_X_Q4_1 4
#define M
MQ
_Y_Q4_1 32
#define M
OE
_X_Q4_1 4
#define M
OE
_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif
...
...
@@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_1
;
const
int
mmq_y
=
M
MQ
_Y_Q4_1
;
const
int
mmq_x
=
M
OE
_X_Q4_1
;
const
int
mmq_y
=
M
OE
_Y_Q4_1
;
const
int
nwarps
=
NWARPS_Q4_1
;
moe_q
<
scalar_t
,
QK4_1
,
QR4_1
,
QI4_1
,
true
,
block_q4_1
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
int
mmq_x
=
M
MQ
_X_Q4_1
;
int
mmq_y
=
M
MQ
_Y_Q4_1
;
int
mmq_x
=
M
OE
_X_Q4_1
;
int
mmq_y
=
M
OE
_Y_Q4_1
;
int
nwarps
=
NWARPS_Q4_1
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q5_0 64
#define M
MQ
_Y_Q5_0 128
#define M
OE
_X_Q5_0 64
#define M
OE
_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define M
MQ
_X_Q5_0 4
#define M
MQ
_Y_Q5_0 32
#define M
OE
_X_Q5_0 4
#define M
OE
_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
...
...
@@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_0
;
const
int
mmq_y
=
M
MQ
_Y_Q5_0
;
const
int
mmq_x
=
M
OE
_X_Q5_0
;
const
int
mmq_y
=
M
OE
_Y_Q5_0
;
const
int
nwarps
=
NWARPS_Q5_0
;
moe_q
<
scalar_t
,
QK5_0
,
QR5_0
,
QI5_0
,
false
,
block_q5_0
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_0
;
const
int
mmq_y
=
M
MQ
_Y_Q5_0
;
const
int
mmq_x
=
M
OE
_X_Q5_0
;
const
int
mmq_y
=
M
OE
_Y_Q5_0
;
const
int
nwarps
=
NWARPS_Q5_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q5_1 64
#define M
MQ
_Y_Q5_1 128
#define M
OE
_X_Q5_1 64
#define M
OE
_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define M
MQ
_X_Q5_1 4
#define M
MQ
_Y_Q5_1 32
#define M
OE
_X_Q5_1 4
#define M
OE
_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
...
...
@@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_1
;
const
int
mmq_y
=
M
MQ
_Y_Q5_1
;
const
int
mmq_x
=
M
OE
_X_Q5_1
;
const
int
mmq_y
=
M
OE
_Y_Q5_1
;
const
int
nwarps
=
NWARPS_Q5_1
;
moe_q
<
scalar_t
,
QK5_1
,
QR5_1
,
QI5_1
,
true
,
block_q5_1
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_1
;
const
int
mmq_y
=
M
MQ
_Y_Q5_1
;
const
int
mmq_x
=
M
OE
_X_Q5_1
;
const
int
mmq_y
=
M
OE
_Y_Q5_1
;
const
int
nwarps
=
NWARPS_Q5_1
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q8_0 64
#define M
MQ
_Y_Q8_0 128
#define M
OE
_X_Q8_0 64
#define M
OE
_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define M
MQ
_X_Q8_0 4
#define M
MQ
_Y_Q8_0 32
#define M
OE
_X_Q8_0 4
#define M
OE
_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
...
...
@@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q8_0
;
const
int
mmq_y
=
M
MQ
_Y_Q8_0
;
const
int
mmq_x
=
M
OE
_X_Q8_0
;
const
int
mmq_y
=
M
OE
_Y_Q8_0
;
const
int
nwarps
=
NWARPS_Q8_0
;
moe_q
<
scalar_t
,
QK8_0
,
QR8_0
,
QI8_0
,
false
,
block_q8_0
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q8_0
;
const
int
mmq_y
=
M
MQ
_Y_Q8_0
;
const
int
mmq_x
=
M
OE
_X_Q8_0
;
const
int
mmq_y
=
M
OE
_Y_Q8_0
;
const
int
nwarps
=
NWARPS_Q8_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q2_K 64
#define M
MQ
_Y_Q2_K 128
#define M
OE
_X_Q2_K 64
#define M
OE
_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define M
MQ
_X_Q2_K 4
#define M
MQ
_Y_Q2_K 32
#define M
OE
_X_Q2_K 4
#define M
OE
_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
...
...
@@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q2_K
;
const
int
mmq_y
=
M
MQ
_Y_Q2_K
;
const
int
mmq_x
=
M
OE
_X_Q2_K
;
const
int
mmq_y
=
M
OE
_Y_Q2_K
;
const
int
nwarps
=
NWARPS_Q2_K
;
moe_q
<
scalar_t
,
QK_K
,
QR2_K
,
QI2_K
,
false
,
block_q2_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q2_K
;
const
int
mmq_y
=
M
MQ
_Y_Q2_K
;
const
int
mmq_x
=
M
OE
_X_Q2_K
;
const
int
mmq_y
=
M
OE
_Y_Q2_K
;
const
int
nwarps
=
NWARPS_Q2_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q3_K 64
#define M
MQ
_Y_Q3_K 128
#define M
OE
_X_Q3_K 64
#define M
OE
_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define M
MQ
_X_Q3_K 4
#define M
MQ
_Y_Q3_K 32
#define M
OE
_X_Q3_K 4
#define M
OE
_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
...
...
@@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q3_K
;
const
int
mmq_y
=
M
MQ
_Y_Q3_K
;
const
int
mmq_x
=
M
OE
_X_Q3_K
;
const
int
mmq_y
=
M
OE
_Y_Q3_K
;
const
int
nwarps
=
NWARPS_Q3_K
;
moe_q
<
scalar_t
,
QK_K
,
QR3_K
,
QI3_K
,
false
,
block_q3_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q3_K
;
const
int
mmq_y
=
M
MQ
_Y_Q3_K
;
const
int
mmq_x
=
M
OE
_X_Q3_K
;
const
int
mmq_y
=
M
OE
_Y_Q3_K
;
const
int
nwarps
=
NWARPS_Q3_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q4_K 64
#define M
MQ
_Y_Q4_K 128
#define M
OE
_X_Q4_K 64
#define M
OE
_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define M
MQ
_X_Q4_K 4
#define M
MQ
_Y_Q4_K 32
#define M
OE
_X_Q4_K 4
#define M
OE
_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
...
...
@@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_K
;
const
int
mmq_y
=
M
MQ
_Y_Q4_K
;
const
int
mmq_x
=
M
OE
_X_Q4_K
;
const
int
mmq_y
=
M
OE
_Y_Q4_K
;
const
int
nwarps
=
NWARPS_Q4_K
;
moe_q
<
scalar_t
,
QK_K
,
QR4_K
,
QI4_K
,
true
,
block_q4_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_K
;
const
int
mmq_y
=
M
MQ
_Y_Q4_K
;
const
int
mmq_x
=
M
OE
_X_Q4_K
;
const
int
mmq_y
=
M
OE
_Y_Q4_K
;
const
int
nwarps
=
NWARPS_Q4_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q5_K 64
#define M
MQ
_Y_Q5_K 128
#define M
OE
_X_Q5_K 64
#define M
OE
_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define M
MQ
_X_Q5_K 4
#define M
MQ
_Y_Q5_K 32
#define M
OE
_X_Q5_K 4
#define M
OE
_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
...
...
@@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_K
;
const
int
mmq_y
=
M
MQ
_Y_Q5_K
;
const
int
mmq_x
=
M
OE
_X_Q5_K
;
const
int
mmq_y
=
M
OE
_Y_Q5_K
;
const
int
nwarps
=
NWARPS_Q5_K
;
moe_q
<
scalar_t
,
QK_K
,
QR5_K
,
QI5_K
,
true
,
block_q5_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_K
;
const
int
mmq_y
=
M
MQ
_Y_Q5_K
;
const
int
mmq_x
=
M
OE
_X_Q5_K
;
const
int
mmq_y
=
M
OE
_Y_Q5_K
;
const
int
nwarps
=
NWARPS_Q5_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q6_K 64
#define M
MQ
_Y_Q6_K 128
#define M
OE
_X_Q6_K 64
#define M
OE
_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define M
MQ
_X_Q6_K 4
#define M
MQ
_Y_Q6_K 32
#define M
OE
_X_Q6_K 4
#define M
OE
_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
...
...
@@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q6_K
;
const
int
mmq_y
=
M
MQ
_Y_Q6_K
;
const
int
mmq_x
=
M
OE
_X_Q6_K
;
const
int
mmq_y
=
M
OE
_Y_Q6_K
;
const
int
nwarps
=
NWARPS_Q6_K
;
moe_q
<
scalar_t
,
QK_K
,
QR6_K
,
QI6_K
,
false
,
block_q6_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q6_K
;
const
int
mmq_y
=
M
MQ
_Y_Q6_K
;
const
int
mmq_x
=
M
OE
_X_Q6_K
;
const
int
mmq_y
=
M
OE
_Y_Q6_K
;
const
int
nwarps
=
NWARPS_Q6_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment