Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e22ee1e7
Unverified
Commit
e22ee1e7
authored
Mar 12, 2025
by
Szymon Ożóg
Committed by
GitHub
Mar 12, 2025
Browse files
[Kernel] GGUF MoE kernel (#14613)
Signed-off-by:
SzymonOzog
<
szymon.ozog@aleph-alpha.com
>
parent
e392d858
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1070 additions
and
25 deletions
+1070
-25
csrc/ops.h
csrc/ops.h
+8
-0
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+138
-4
csrc/quantization/gguf/moe.cuh
csrc/quantization/gguf/moe.cuh
+739
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+10
-0
tests/kernels/test_ggml.py
tests/kernels/test_ggml.py
+13
-0
tests/kernels/test_gguf.py
tests/kernels/test_gguf.py
+64
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+37
-0
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+61
-21
No files found.
csrc/ops.h
View file @
e22ee1e7
...
...
@@ -151,6 +151,14 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
ggml_moe_a8
(
torch
::
Tensor
X
,
torch
::
Tensor
W
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_padded
,
int64_t
type
,
int64_t
row
,
int64_t
top_k
,
int64_t
tokens
);
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
#ifndef USE_ROCM
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
...
...
csrc/quantization/gguf/gguf_kernel.cu
View file @
e22ee1e7
...
...
@@ -12,6 +12,7 @@
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
#include "moe.cuh"
// Q8 gemv
template
<
typename
scalar_t
>
...
...
@@ -59,10 +60,14 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
const
int64_t
kx_padded
=
(
kx
+
512
-
1
)
/
512
*
512
;
const
int
block_num_x
=
(
kx_padded
+
CUDA_QUANTIZE_BLOCK_SIZE
-
1
)
/
CUDA_QUANTIZE_BLOCK_SIZE
;
const
dim3
num_blocks
(
block_num_x
,
ky
,
1
);
const
dim3
block_size
(
CUDA_DEQUANTIZE_BLOCK_SIZE
,
1
,
1
);
quantize_q8_1
<
scalar_t
>
<<<
num_blocks
,
block_size
,
0
,
stream
>>>
(
x
,
vy
,
kx
,
kx_padded
);
constexpr
int
MAX_BLOCK_SIZE
=
65535
;
for
(
int
off
=
0
;
off
<
ky
;
off
+=
MAX_BLOCK_SIZE
)
{
const
int
num_blocks_y
=
std
::
min
(
ky
,
off
+
MAX_BLOCK_SIZE
)
-
off
;
const
dim3
num_blocks
(
block_num_x
,
num_blocks_y
,
1
);
const
dim3
block_size
(
CUDA_DEQUANTIZE_BLOCK_SIZE
,
1
,
1
);
quantize_q8_1
<<<
num_blocks
,
block_size
,
0
,
stream
>>>
(
&
x
[
off
*
kx
],
(
int32_t
*
)
vy
+
off
*
(
kx_padded
/
32
*
9
),
kx
,
kx_padded
);
}
}
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
// quant weight
...
...
@@ -263,3 +268,132 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
});
return
Y
;
}
torch
::
Tensor
ggml_moe_a8
(
torch
::
Tensor
X
,
// input
torch
::
Tensor
W
,
// expert weights
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_padded
,
int64_t
type
,
int64_t
row
,
int64_t
top_k
,
int64_t
tokens
)
{
int
col
=
X
.
sizes
()[
1
];
int
padded
=
(
col
+
512
-
1
)
/
512
*
512
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
X
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
X
.
dtype
()).
device
(
W
.
device
());
at
::
Tensor
Y
=
torch
::
empty
({
tokens
*
top_k
,
row
},
options
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
).
device
(
W
.
device
());
at
::
Tensor
quant_X
=
torch
::
empty
({
tokens
,
padded
/
32
*
9
},
options
);
VLLM_DISPATCH_FLOATING_TYPES
(
X
.
scalar_type
(),
"ggml_moe_a8"
,
[
&
]
{
quantize_row_q8_1_cuda
((
scalar_t
*
)
X
.
data_ptr
(),
(
void
*
)
quant_X
.
data_ptr
(),
col
,
tokens
,
stream
);
switch
(
type
)
{
case
2
:
ggml_moe_q4_0_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
3
:
ggml_moe_q4_1_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
6
:
ggml_moe_q5_0_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
7
:
ggml_moe_q5_1_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
8
:
ggml_moe_q8_0_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
10
:
ggml_moe_q2_K_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
11
:
ggml_moe_q3_K_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
12
:
ggml_moe_q4_K_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
13
:
ggml_moe_q5_K_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
case
14
:
ggml_moe_q6_K_q8_1_cuda
(
(
void
*
)
quant_X
.
data_ptr
(),
(
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
Y
.
data_ptr
(),
(
int
*
)
sorted_token_ids
.
data_ptr
(),
(
int
*
)
expert_ids
.
data_ptr
(),
(
int
*
)
num_tokens_post_padded
.
data_ptr
(),
W
.
stride
(
0
),
col
,
row
,
tokens
,
padded
,
row
,
top_k
,
sorted_token_ids
.
sizes
()[
0
],
stream
);
break
;
}
});
return
Y
;
}
int64_t
ggml_moe_get_block_size
(
int64_t
type
)
{
switch
(
type
)
{
case
2
:
return
MMQ_X_Q4_0
;
case
3
:
return
MMQ_X_Q4_1
;
case
6
:
return
MMQ_X_Q5_0
;
case
7
:
return
MMQ_X_Q5_1
;
case
8
:
return
MMQ_X_Q8_0
;
case
10
:
return
MMQ_X_Q2_K
;
case
11
:
return
MMQ_X_Q3_K
;
case
12
:
return
MMQ_X_Q4_K
;
case
13
:
return
MMQ_X_Q5_K
;
case
14
:
return
MMQ_X_Q6_K
;
}
return
0
;
}
csrc/quantization/gguf/moe.cuh
0 → 100644
View file @
e22ee1e7
#include <cstdint>
/* Adapted from ./csrc/quantization/gguf/mmq.cuh
based on ./vllm/model_executor/layers/fused_moe/fused_moe.py */
template
<
typename
scalar_t
,
int
qk
,
int
qr
,
int
qi
,
bool
need_sum
,
typename
block_q_t
,
int
mmq_x
,
int
mmq_y
,
int
nwarps
,
allocate_tiles_cuda_t
allocate_tiles
,
load_tiles_cuda_t
load_tiles
,
int
vdr
,
vec_dot_q_mul_mat_cuda_t
vec_dot
>
static
__device__
__forceinline__
void
moe_q
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
__restrict__
sorted_token_ids
,
const
int
*
__restrict__
expert_ids
,
const
int
*
__restrict__
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
blocks_per_row_x
=
ncols_x
/
qk
;
const
int
blocks_per_col_y
=
nrows_y
/
QK8_1
;
const
int
blocks_per_warp
=
WARP_SIZE_GGUF
/
qi
;
const
int
ncols_dst
=
ncols_y
*
top_k
;
const
int
row_dst_0
=
blockIdx
.
x
*
mmq_y
;
const
int
&
row_x_0
=
row_dst_0
;
const
int
col_dst_0
=
blockIdx
.
y
*
mmq_x
;
int
token_offs
[
mmq_x
/
nwarps
];
for
(
int
i
=
0
;
i
<
mmq_x
;
i
+=
nwarps
)
{
token_offs
[
i
/
nwarps
]
=
sorted_token_ids
[
col_dst_0
+
threadIdx
.
y
+
i
];
}
const
int
exp_idx
=
expert_ids
[
blockIdx
.
y
];
if
(
exp_idx
>
255
||
exp_idx
<
0
)
return
;
if
(
blockIdx
.
y
*
mmq_x
>
num_tokens_post_padded
[
0
])
return
;
const
block_q_t
*
x
=
(
const
block_q_t
*
)((
char
*
)
vx
+
exp_idx
*
exp_stride
);
const
block_q8_1
*
y
=
(
const
block_q8_1
*
)(
vy
);
int
*
tile_x_ql
=
nullptr
;
half2
*
tile_x_dm
=
nullptr
;
int
*
tile_x_qh
=
nullptr
;
int
*
tile_x_sc
=
nullptr
;
allocate_tiles
(
&
tile_x_ql
,
&
tile_x_dm
,
&
tile_x_qh
,
&
tile_x_sc
);
__shared__
int
tile_y_qs
[
mmq_x
*
WARP_SIZE_GGUF
];
__shared__
half2
tile_y_ds
[
mmq_x
*
WARP_SIZE_GGUF
/
QI8_1
];
float
sum
[
mmq_y
/
WARP_SIZE_GGUF
][
mmq_x
/
nwarps
]
=
{{
0.0
f
}};
for
(
int
ib0
=
0
;
ib0
<
blocks_per_row_x
;
ib0
+=
blocks_per_warp
)
{
load_tiles
(
x
+
row_x_0
*
blocks_per_row_x
+
ib0
,
tile_x_ql
,
tile_x_dm
,
tile_x_qh
,
tile_x_sc
,
threadIdx
.
y
,
nrows_x
-
row_x_0
-
1
,
threadIdx
.
x
,
blocks_per_row_x
);
const
int
n_per_r
=
((
qk
*
blocks_per_warp
)
/
qr
);
#pragma unroll
for
(
int
ir
=
0
;
ir
<
qr
&&
ib0
*
qk
+
ir
*
n_per_r
<
ncols_x
;
++
ir
)
{
const
int
kqs
=
ir
*
WARP_SIZE_GGUF
+
threadIdx
.
x
;
const
int
kbxd
=
kqs
/
QI8_1
;
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_x
;
i
+=
nwarps
)
{
const
int
col_y_eff
=
token_offs
[
i
/
nwarps
]
/
top_k
;
const
int
block_x
=
ib0
*
(
qk
/
QK8_1
)
+
kbxd
;
if
(
col_y_eff
<
ncols_y
&&
block_x
<
blocks_per_col_y
)
{
const
block_q8_1
*
by0
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
block_x
];
const
int
index_y
=
(
threadIdx
.
y
+
i
)
*
WARP_SIZE_GGUF
+
kqs
%
WARP_SIZE_GGUF
;
tile_y_qs
[
index_y
]
=
get_int_from_int8_aligned
(
by0
->
qs
,
threadIdx
.
x
%
QI8_1
);
}
}
if
(
threadIdx
.
x
<
n_per_r
/
QK8_1
)
{
const
int
kby
=
threadIdx
.
x
%
(
WARP_SIZE_GGUF
/
QI8_1
);
const
int
col_y_eff
=
token_offs
[
threadIdx
.
y
]
/
top_k
;
const
int
block_x
=
ib0
*
(
qk
/
QK8_1
)
+
ir
*
(
WARP_SIZE_GGUF
/
QI8_1
)
+
kby
;
if
(
col_y_eff
<
ncols_y
&&
block_x
<
blocks_per_col_y
)
{
const
half2
*
dsi_src
=
&
y
[
col_y_eff
*
blocks_per_col_y
+
block_x
].
ds
;
half2
*
dsi_dst
=
&
tile_y_ds
[
threadIdx
.
y
*
(
WARP_SIZE_GGUF
/
QI8_1
)
+
kby
];
if
(
need_sum
)
{
*
dsi_dst
=
*
dsi_src
;
}
else
{
float
*
dfi_dst
=
(
float
*
)
dsi_dst
;
*
dfi_dst
=
__low2float
(
*
dsi_src
);
}
}
}
__syncthreads
();
// #pragma unroll // unrolling this loop causes too much register pressure
for
(
int
k
=
ir
*
WARP_SIZE_GGUF
/
qr
;
k
<
(
ir
+
1
)
*
WARP_SIZE_GGUF
/
qr
;
k
+=
vdr
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
mmq_x
;
j
+=
nwarps
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE_GGUF
)
{
sum
[
i
/
WARP_SIZE_GGUF
][
j
/
nwarps
]
+=
vec_dot
(
tile_x_ql
,
tile_x_dm
,
tile_x_qh
,
tile_x_sc
,
tile_y_qs
,
tile_y_ds
,
threadIdx
.
x
+
i
,
threadIdx
.
y
+
j
,
k
);
}
}
}
__syncthreads
();
}
}
#pragma unroll
for
(
int
j
=
0
;
j
<
mmq_x
;
j
+=
nwarps
)
{
const
int
col_dst
=
token_offs
[
j
/
nwarps
];
if
(
col_dst
>=
ncols_dst
)
{
return
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
mmq_y
;
i
+=
WARP_SIZE_GGUF
)
{
const
int
row_dst
=
row_dst_0
+
threadIdx
.
x
+
i
;
if
(
row_dst
>=
nrows_dst
)
{
continue
;
}
dst
[
col_dst
*
nrows_dst
+
row_dst
]
=
sum
[
i
/
WARP_SIZE_GGUF
][
j
/
nwarps
];
}
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_0 64
#define MMQ_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MMQ_X_Q4_0 4
#define MMQ_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q4_0
,
2
)
#endif
moe_q4_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q4_0
;
const
int
mmq_y
=
MMQ_Y_Q4_0
;
const
int
nwarps
=
NWARPS_Q4_0
;
moe_q
<
scalar_t
,
QK4_0
,
QR4_0
,
QI4_0
,
true
,
block_q4_0
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q4_0
<
mmq_y
>
,
load_tiles_q4_0
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q4_0_Q8_1_MMQ
,
vec_dot_q4_0_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q4_0_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
int
mmq_x
=
MMQ_X_Q4_0
;
int
mmq_y
=
MMQ_Y_Q4_0
;
int
nwarps
=
NWARPS_Q4_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q4_0
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q4_0
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_1 64
#define MMQ_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MMQ_X_Q4_1 4
#define MMQ_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q4_1
,
2
)
#endif
moe_q4_1
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q4_1
;
const
int
mmq_y
=
MMQ_Y_Q4_1
;
const
int
nwarps
=
NWARPS_Q4_1
;
moe_q
<
scalar_t
,
QK4_1
,
QR4_1
,
QI4_1
,
true
,
block_q4_1
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q4_1
<
mmq_y
>
,
load_tiles_q4_1
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q4_1_Q8_1_MMQ
,
vec_dot_q4_1_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q4_1_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
int
mmq_x
=
MMQ_X_Q4_1
;
int
mmq_y
=
MMQ_Y_Q4_1
;
int
nwarps
=
NWARPS_Q4_1
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q4_1
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q4_1
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_0 64
#define MMQ_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MMQ_X_Q5_0 4
#define MMQ_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q5_0
,
2
)
#endif
moe_q5_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q5_0
;
const
int
mmq_y
=
MMQ_Y_Q5_0
;
const
int
nwarps
=
NWARPS_Q5_0
;
moe_q
<
scalar_t
,
QK5_0
,
QR5_0
,
QI5_0
,
false
,
block_q5_0
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q5_0
<
mmq_y
>
,
load_tiles_q5_0
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q5_0_Q8_1_MMQ
,
vec_dot_q5_0_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q5_0_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q5_0
;
const
int
mmq_y
=
MMQ_Y_Q5_0
;
const
int
nwarps
=
NWARPS_Q5_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q5_0
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q5_0
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_1 64
#define MMQ_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MMQ_X_Q5_1 4
#define MMQ_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q5_1
,
2
)
#endif
moe_q5_1
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q5_1
;
const
int
mmq_y
=
MMQ_Y_Q5_1
;
const
int
nwarps
=
NWARPS_Q5_1
;
moe_q
<
scalar_t
,
QK5_1
,
QR5_1
,
QI5_1
,
true
,
block_q5_1
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q5_1
<
mmq_y
>
,
load_tiles_q5_1
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q5_1_Q8_1_MMQ
,
vec_dot_q5_1_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q5_1_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q5_1
;
const
int
mmq_y
=
MMQ_Y_Q5_1
;
const
int
nwarps
=
NWARPS_Q5_1
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q5_1
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q5_1
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q8_0 64
#define MMQ_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MMQ_X_Q8_0 4
#define MMQ_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q8_0
,
2
)
#endif
moe_q8_0
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q8_0
;
const
int
mmq_y
=
MMQ_Y_Q8_0
;
const
int
nwarps
=
NWARPS_Q8_0
;
moe_q
<
scalar_t
,
QK8_0
,
QR8_0
,
QI8_0
,
false
,
block_q8_0
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q8_0
<
mmq_y
>
,
load_tiles_q8_0
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q8_0_Q8_1_MMQ
,
vec_dot_q8_0_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q8_0_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q8_0
;
const
int
mmq_y
=
MMQ_Y_Q8_0
;
const
int
nwarps
=
NWARPS_Q8_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q8_0
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q8_0
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q2_K 64
#define MMQ_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MMQ_X_Q2_K 4
#define MMQ_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q2_K
,
2
)
#endif
moe_q2_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q2_K
;
const
int
mmq_y
=
MMQ_Y_Q2_K
;
const
int
nwarps
=
NWARPS_Q2_K
;
moe_q
<
scalar_t
,
QK_K
,
QR2_K
,
QI2_K
,
false
,
block_q2_K
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q2_K
<
mmq_y
>
,
load_tiles_q2_K
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q2_K_Q8_1_MMQ
,
vec_dot_q2_K_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q2_K_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q2_K
;
const
int
mmq_y
=
MMQ_Y_Q2_K
;
const
int
nwarps
=
NWARPS_Q2_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q2_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q2_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q3_K 64
#define MMQ_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MMQ_X_Q3_K 4
#define MMQ_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q3_K
,
2
)
#endif
moe_q3_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q3_K
;
const
int
mmq_y
=
MMQ_Y_Q3_K
;
const
int
nwarps
=
NWARPS_Q3_K
;
moe_q
<
scalar_t
,
QK_K
,
QR3_K
,
QI3_K
,
false
,
block_q3_K
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q3_K
<
mmq_y
>
,
load_tiles_q3_K
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q3_K_Q8_1_MMQ
,
vec_dot_q3_K_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q3_K_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q3_K
;
const
int
mmq_y
=
MMQ_Y_Q3_K
;
const
int
nwarps
=
NWARPS_Q3_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q3_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q3_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_K 64
#define MMQ_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MMQ_X_Q4_K 4
#define MMQ_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q4_K
,
2
)
#endif
moe_q4_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q4_K
;
const
int
mmq_y
=
MMQ_Y_Q4_K
;
const
int
nwarps
=
NWARPS_Q4_K
;
moe_q
<
scalar_t
,
QK_K
,
QR4_K
,
QI4_K
,
true
,
block_q4_K
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q4_K
<
mmq_y
>
,
load_tiles_q4_K
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q4_K_Q8_1_MMQ
,
vec_dot_q4_K_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q4_K_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q4_K
;
const
int
mmq_y
=
MMQ_Y_Q4_K
;
const
int
nwarps
=
NWARPS_Q4_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q4_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q4_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_K 64
#define MMQ_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MMQ_X_Q5_K 4
#define MMQ_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q5_K
,
2
)
#endif
moe_q5_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q5_K
;
const
int
mmq_y
=
MMQ_Y_Q5_K
;
const
int
nwarps
=
NWARPS_Q5_K
;
moe_q
<
scalar_t
,
QK_K
,
QR5_K
,
QI5_K
,
true
,
block_q5_K
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q5_K
<
mmq_y
>
,
load_tiles_q5_K
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q5_K_Q8_1_MMQ
,
vec_dot_q5_K_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q5_K_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q5_K
;
const
int
mmq_y
=
MMQ_Y_Q5_K
;
const
int
nwarps
=
NWARPS_Q5_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q5_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q5_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
#if defined(USE_ROCM)
#define MMQ_X_Q6_K 64
#define MMQ_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MMQ_X_Q6_K 4
#define MMQ_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
template
<
typename
scalar_t
,
bool
need_check
>
static
__global__
void
#if defined(USE_ROCM)
__launch_bounds__
(
WARP_SIZE_GGUF
*
NWARPS_Q6_K
,
2
)
#endif
moe_q6_K
(
const
void
*
__restrict__
vx
,
const
void
*
__restrict__
vy
,
scalar_t
*
__restrict__
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
MMQ_X_Q6_K
;
const
int
mmq_y
=
MMQ_Y_Q6_K
;
const
int
nwarps
=
NWARPS_Q6_K
;
moe_q
<
scalar_t
,
QK_K
,
QR6_K
,
QI6_K
,
false
,
block_q6_K
,
mmq_x
,
mmq_y
,
nwarps
,
allocate_tiles_q6_K
<
mmq_y
>
,
load_tiles_q6_K
<
mmq_y
,
nwarps
,
need_check
>
,
VDR_Q6_K_Q8_1_MMQ
,
vec_dot_q6_K_q8_1_mul_mat
>
(
vx
,
vy
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
template
<
typename
scalar_t
>
static
void
ggml_moe_q6_K_q8_1_cuda
(
const
void
*
inp
,
const
void
*
w
,
scalar_t
*
dst
,
const
int
*
sorted_token_ids
,
const
int
*
expert_ids
,
const
int
*
num_tokens_post_padded
,
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
MMQ_X_Q6_K
;
const
int
mmq_y
=
MMQ_Y_Q6_K
;
const
int
nwarps
=
NWARPS_Q6_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
const
int
block_num_y
=
(
tokens_post_padded
)
/
mmq_x
;
const
dim3
block_nums
(
block_num_x
,
block_num_y
,
1
);
const
dim3
block_dims
(
WARP_SIZE_GGUF
,
nwarps
,
1
);
if
(
nrows_x
%
mmq_y
==
0
)
{
constexpr
bool
need_check
=
false
;
moe_q6_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
else
{
constexpr
bool
need_check
=
true
;
moe_q6_K
<
scalar_t
,
need_check
><<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
w
,
inp
,
dst
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
exp_stride
,
ncols_x
,
nrows_x
,
ncols_y
,
nrows_y
,
nrows_dst
,
top_k
);
}
}
csrc/torch_bindings.cpp
View file @
e22ee1e7
...
...
@@ -305,6 +305,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"
);
ops
.
impl
(
"ggml_mul_mat_a8"
,
torch
::
kCUDA
,
&
ggml_mul_mat_a8
);
// moe kernel for GGML.
ops
.
def
(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"
);
ops
.
impl
(
"ggml_moe_a8"
,
torch
::
kCUDA
,
&
ggml_moe_a8
);
ops
.
def
(
"ggml_moe_get_block_size"
,
&
ggml_moe_get_block_size
);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
...
...
tests/kernels/test_ggml.py
View file @
e22ee1e7
...
...
@@ -22,3 +22,16 @@ def test_ggml_opcheck(quant_type):
(
qweight
,
x
,
quant_type
,
qweight
.
shape
[
0
]))
opcheck
(
torch
.
ops
.
_C
.
ggml_mul_mat_vec_a8
,
(
qweight
,
x
,
quant_type
,
qweight
.
shape
[
0
]))
shape
=
[
256
,
1024
,
336
]
qweight
=
torch
.
randint
(
0
,
100
,
shape
,
device
=
'cuda'
,
dtype
=
torch
.
uint8
)
x
=
torch
.
rand
((
1
,
1024
),
device
=
'cuda'
,
dtype
=
torch
.
float16
)
sorted_token_ids
=
torch
.
arange
(
776
,
device
=
'cuda'
)
expert_ids
=
torch
.
randint
(
0
,
256
,
(
194
,
),
device
=
'cuda'
)
num_tokens_post_padded
=
torch
.
tensor
([
1
],
dtype
=
torch
.
int64
,
device
=
'cuda'
)
opcheck
(
torch
.
ops
.
_C
.
ggml_moe_a8
,
(
x
,
qweight
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
quant_type
,
qweight
.
shape
[
0
],
1
,
x
.
shape
[
0
]))
tests/kernels/test_gguf.py
View file @
e22ee1e7
...
...
@@ -8,9 +8,13 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from
huggingface_hub
import
snapshot_download
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.gguf
import
_fused_moe_gguf
from
vllm.platforms
import
current_platform
GGUF_SAMPLE
=
snapshot_download
(
"Isotr0py/test-gguf-sample"
)
GGUF_SAMPLE_MOE
=
snapshot_download
(
"SzymonOzog/test-gguf-moe-sample"
)
def
get_gguf_sample_tensors
(
...
...
@@ -22,6 +26,15 @@ def get_gguf_sample_tensors(
return
GGUFReader
(
sample_file
).
tensors
def
get_gguf_MoE_tensors
(
hidden_size
:
int
,
quant_type
:
GGMLQuantizationType
)
->
list
[
ReaderTensor
]:
sample_dir
=
GGUF_SAMPLE_MOE
filename
=
f
"Quant_
{
quant_type
.
name
}
_
{
hidden_size
}
.gguf"
sample_file
=
Path
(
sample_dir
)
/
filename
return
GGUFReader
(
sample_file
).
tensors
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float32
]
# Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently.
...
...
@@ -132,3 +145,54 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
ref_output
,
atol
=
atols
[
dtype
],
rtol
=
rtols
[
dtype
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"top_k"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
# k-quants
GGMLQuantizationType
.
Q2_K
,
GGMLQuantizationType
.
Q3_K
,
GGMLQuantizationType
.
Q4_K
,
GGMLQuantizationType
.
Q5_K
,
GGMLQuantizationType
.
Q6_K
,
# standard quants
GGMLQuantizationType
.
Q4_0
,
GGMLQuantizationType
.
Q5_0
,
GGMLQuantizationType
.
Q8_0
,
])
@
torch
.
inference_mode
()
def
test_moe
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
quant_type
:
GGMLQuantizationType
,
top_k
:
int
):
current_platform
.
seed_everything
(
0
)
H
,
E
=
1024
,
256
x
=
torch
.
rand
((
num_tokens
,
H
),
dtype
=
dtype
,
device
=
"cuda"
)
topk_weights
=
torch
.
rand
(
num_tokens
,
top_k
,
device
=
"cuda"
,
dtype
=
dtype
)
topk_ids
=
torch
.
randint
(
0
,
E
,
(
num_tokens
,
top_k
),
device
=
"cuda"
)
tensors
=
get_gguf_MoE_tensors
(
hidden_size
,
quant_type
)
w13
=
tensors
[
0
]
w2
=
tensors
[
1
]
w13_dequant
=
torch
.
tensor
(
dequantize
(
w13
.
data
,
quant_type
),
device
=
"cuda"
).
to
(
dtype
)
w2_dequant
=
torch
.
tensor
(
dequantize
(
w2
.
data
,
quant_type
),
device
=
"cuda"
).
to
(
dtype
)
act
=
SiluAndMul
()
output
=
_fused_moe_gguf
(
x
,
torch
.
tensor
(
w13
.
data
,
device
=
"cuda"
),
torch
.
tensor
(
w2
.
data
,
device
=
"cuda"
),
topk_weights
,
topk_ids
,
quant_type
,
quant_type
,
act
)
ref_output
=
fused_experts
(
x
,
w13_dequant
,
w2_dequant
,
topk_weights
,
topk_ids
).
reshape
(
output
.
shape
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1
,
rtol
=
1e-1
)
vllm/_custom_ops.py
View file @
e22ee1e7
...
...
@@ -448,6 +448,23 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
batch
=
X
.
size
(
0
)
return
torch
.
empty
((
batch
,
row
),
dtype
=
X
.
dtype
,
device
=
W
.
device
)
@
register_fake
(
"_C::ggml_moe_a8"
)
def
_ggml_moe_a8_fake
(
X
:
torch
.
Tensor
,
W
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
torch
.
SymInt
,
top_k
:
torch
.
SymInt
,
tokens
:
torch
.
SymInt
,
)
->
torch
.
Tensor
:
tokens
=
X
.
size
(
0
)
return
torch
.
empty
((
tokens
*
top_k
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
# cutlass
def
cutlass_scaled_fp4_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
...
@@ -1034,6 +1051,26 @@ def ggml_mul_mat_a8(
return
torch
.
ops
.
_C
.
ggml_mul_mat_a8
(
W
,
X
,
quant_type
,
row
)
def
ggml_moe_a8
(
X
:
torch
.
Tensor
,
W
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
quant_type
:
int
,
row
:
int
,
top_k
:
int
,
tokens
:
int
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
ggml_moe_a8
(
X
,
W
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
quant_type
,
row
,
top_k
,
tokens
)
def
ggml_moe_get_block_size
(
quant_type
:
int
)
->
int
:
return
torch
.
ops
.
_C
.
ggml_moe_get_block_size
(
quant_type
)
# mamba
def
causal_conv1d_fwd
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
e22ee1e7
...
...
@@ -8,7 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe.fused_moe
import
moe_align_block_size
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
...
...
@@ -18,6 +20,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
class
GGUFConfig
(
QuantizationConfig
):
"""Config class for GGUF."""
...
...
@@ -119,6 +123,59 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return
y
def
_fused_moe_gguf
(
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
qweight_type
:
int
,
qweight_type2
:
int
,
act
,
)
->
torch
.
Tensor
:
out_hidden_states
=
torch
.
empty_like
(
x
)
if
qweight_type2
in
MMQ_QUANT_TYPES
and
qweight_type
in
MMQ_QUANT_TYPES
:
num_tokens
,
_
=
x
.
shape
E
,
N
,
_
=
w1
.
shape
top_k
=
topk_ids
.
shape
[
1
]
BLOCK_SIZE
=
ops
.
ggml_moe_get_block_size
(
qweight_type
)
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
\
moe_align_block_size
(
topk_ids
,
BLOCK_SIZE
,
E
)
out
=
ops
.
ggml_moe_a8
(
x
,
w1
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
qweight_type
,
N
,
top_k
,
num_tokens
)
out
=
act
(
out
)
out
=
ops
.
ggml_moe_a8
(
out
,
w2
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
qweight_type2
,
w2
.
shape
[
1
],
1
,
num_tokens
*
top_k
)
out
=
out
.
reshape
(
num_tokens
,
top_k
,
w2
.
shape
[
1
]).
mul_
(
topk_weights
.
view
(
num_tokens
,
top_k
,
1
))
ops
.
moe_sum
(
out
,
out_hidden_states
)
else
:
logger
.
warning_once
(
"There is no support for fast MoE kernel "
"for current quantization method. "
"Falling back to slow implementation. "
)
for
tok
,
(
w
,
idx
)
in
enumerate
(
zip
(
topk_weights
,
topk_ids
)):
inp
=
x
[
tok
].
reshape
((
1
,
)
+
x
.
shape
[
1
:])
current_hidden_state
=
None
for
ww
,
ii
in
zip
(
w
,
idx
):
expert_up
=
w1
[
ii
]
out
=
_fuse_mul_mat
(
inp
,
expert_up
,
qweight_type
)
out
=
act
(
out
)
expert_down
=
w2
[
ii
]
current_state
=
_fuse_mul_mat
(
out
,
expert_down
,
qweight_type2
).
mul_
(
ww
)
if
current_hidden_state
is
None
:
current_hidden_state
=
current_state
else
:
current_hidden_state
.
add_
(
current_state
)
out_hidden_states
[
tok
]
=
current_hidden_state
return
out_hidden_states
class
GGUFLinearMethod
(
LinearMethodBase
):
"""Linear method for GGUF.
...
...
@@ -285,27 +342,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
final_hidden_states
=
torch
.
empty_like
(
x
)
for
tok
,
(
w
,
idx
)
in
enumerate
(
zip
(
topk_weights
,
topk_ids
)):
inp
=
x
[
tok
].
reshape
((
1
,
)
+
x
.
shape
[
1
:])
current_hidden_state
=
None
for
ww
,
ii
in
zip
(
w
,
idx
):
expert_up
=
layer
.
w13_qweight
[
ii
]
out
=
_fuse_mul_mat
(
inp
,
expert_up
,
layer
.
w13_qweight_type
.
weight_type
)
out
=
self
.
act
(
out
)
expert_down
=
layer
.
w2_qweight
[
ii
]
current_state
=
_fuse_mul_mat
(
out
,
expert_down
,
layer
.
w2_qweight_type
.
weight_type
).
mul_
(
ww
)
if
current_hidden_state
is
None
:
current_hidden_state
=
current_state
else
:
current_hidden_state
.
add_
(
current_state
)
final_hidden_states
[
tok
]
=
current_hidden_state
return
final_hidden_states
return
_fused_moe_gguf
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
topk_weights
,
topk_ids
,
layer
.
w13_qweight_type
.
weight_type
,
layer
.
w2_qweight_type
.
weight_type
,
self
.
act
)
class
GGUFEmbeddingMethod
(
GGUFLinearMethod
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment