Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e22ee1e7
Unverified
Commit
e22ee1e7
authored
Mar 12, 2025
by
Szymon Ożóg
Committed by
GitHub
Mar 12, 2025
Browse files
[Kernel] GGUF MoE kernel (#14613)
Signed-off-by:
SzymonOzog
<
szymon.ozog@aleph-alpha.com
>
parent
e392d858
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1070 additions
and
25 deletions
+1070
-25
csrc/ops.h
csrc/ops.h
+8
-0
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+138
-4
csrc/quantization/gguf/moe.cuh
csrc/quantization/gguf/moe.cuh
+739
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+10
-0
tests/kernels/test_ggml.py
tests/kernels/test_ggml.py
+13
-0
tests/kernels/test_gguf.py
tests/kernels/test_gguf.py
+64
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+37
-0
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+61
-21
No files found.
csrc/ops.h
View file @
e22ee1e7
...
...
@@ -151,6 +151,14 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
// Mixture-of-experts matmul over GGUF-quantized expert weights W for the
// activations X. The definition quantizes X with quantize_row_q8_1_cuda and
// then dispatches on `type` to a ggml_moe_*_q8_1_cuda kernel. The returned
// tensor is allocated as {tokens * top_k, row} with X's dtype on W's device.
// sorted_token_ids, expert_ids and num_tokens_post_padded are forwarded to
// the kernel as int pointers.
torch
::
Tensor
ggml_moe_a8
(
torch
::
Tensor
X
,
torch
::
Tensor
W
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
expert_ids
,
torch
::
Tensor
num_tokens_post_padded
,
int64_t
type
,
int64_t
row
,
int64_t
top_k
,
int64_t
tokens
);
// Maps a quantization type id to its MMQ_X_* constant; the definition
// returns 0 for ids it has no case for.
int64_t
ggml_moe_get_block_size
(
int64_t
type
);
#ifndef USE_ROCM
void
cutlass_scaled_fp4_mm
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
...
...
csrc/quantization/gguf/gguf_kernel.cu
View file @
e22ee1e7
...
...
@@ -12,6 +12,7 @@
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
#include "moe.cuh"
// Q8 gemv
template
<
typename
scalar_t
>
...
...
@@ -59,10 +60,14 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
const
int64_t
kx_padded
=
(
kx
+
512
-
1
)
/
512
*
512
;
const
int
block_num_x
=
(
kx_padded
+
CUDA_QUANTIZE_BLOCK_SIZE
-
1
)
/
CUDA_QUANTIZE_BLOCK_SIZE
;
const
dim3
num_blocks
(
block_num_x
,
ky
,
1
);
const
dim3
block_size
(
CUDA_DEQUANTIZE_BLOCK_SIZE
,
1
,
1
);
quantize_q8_1
<
scalar_t
>
<<<
num_blocks
,
block_size
,
0
,
stream
>>>
(
x
,
vy
,
kx
,
kx_padded
);
constexpr
int
MAX_BLOCK_SIZE
=
65535
;
for
(
int
off
=
0
;
off
<
ky
;
off
+=
MAX_BLOCK_SIZE
)
{
const
int
num_blocks_y
=
std
::
min
(
ky
,
off
+
MAX_BLOCK_SIZE
)
-
off
;
const
dim3
num_blocks
(
block_num_x
,
num_blocks_y
,
1
);
const
dim3
block_size
(
CUDA_DEQUANTIZE_BLOCK_SIZE
,
1
,
1
);
quantize_q8_1
<<<
num_blocks
,
block_size
,
0
,
stream
>>>
(
&
x
[
off
*
kx
],
(
int32_t
*
)
vy
+
off
*
(
kx_padded
/
32
*
9
),
kx
,
kx_padded
);
}
}
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
// quant weight
...
...
@@ -263,3 +268,132 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
});
return
Y
;
}
// Mixture-of-experts matmul over GGUF-quantized expert weights.
//
// X is quantized row-wise with quantize_row_q8_1_cuda into a temporary int32
// buffer, then the ggml_moe_*_q8_1_cuda kernel matching `type` is launched on
// the current CUDA stream.
//
// Parameters:
//   X                       input activations; the second dimension is the
//                           reduction width (`col`)
//   W                       quantized expert weights; W.stride(0) is passed to
//                           the kernel as the per-expert stride
//   sorted_token_ids,
//   expert_ids,
//   num_tokens_post_padded  routing metadata, read by the kernel through
//                           `int*` (so they must hold 32-bit integers)
//   type                    quantization type id selecting the kernel
//                           (2, 3, 6, 7, 8, 10, 11, 12, 13, 14)
//   row                     number of output features per expert
//   top_k, tokens           the output has tokens * top_k rows
//
// Returns a {tokens * top_k, row} tensor with X's dtype on W's device.
// Raises (TORCH_CHECK) when `type` has no MoE kernel; previously this case
// silently returned the uninitialized output buffer.
torch::Tensor ggml_moe_a8(torch::Tensor X,  // input
                          torch::Tensor W,  // expert weights
                          torch::Tensor sorted_token_ids,
                          torch::Tensor expert_ids,
                          torch::Tensor num_tokens_post_padded, int64_t type,
                          int64_t row, int64_t top_k, int64_t tokens) {
  int col = X.sizes()[1];
  // The quantized activation rows are padded to a multiple of 512 values.
  int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::empty({tokens * top_k, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  // 9 int32 words are reserved per 32 padded input values.
  at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);

  // Kernel arguments that do not depend on the dispatched scalar type.
  void* const x_q8 = static_cast<void*>(quant_X.data_ptr());
  void* const w_q = static_cast<void*>(W.data_ptr());
  int* const sorted_ids = static_cast<int*>(sorted_token_ids.data_ptr());
  int* const experts = static_cast<int*>(expert_ids.data_ptr());
  int* const n_post_padded =
      static_cast<int*>(num_tokens_post_padded.data_ptr());
  const auto expert_stride = W.stride(0);
  const auto n_sorted = sorted_token_ids.sizes()[0];

  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
    quantize_row_q8_1_cuda(static_cast<scalar_t*>(X.data_ptr()), x_q8, col,
                           tokens, stream);
    scalar_t* const y = static_cast<scalar_t*>(Y.data_ptr());
    switch (type) {
      case 2:
        ggml_moe_q4_0_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 3:
        ggml_moe_q4_1_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 6:
        ggml_moe_q5_0_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 7:
        ggml_moe_q5_1_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 8:
        ggml_moe_q8_0_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 10:
        ggml_moe_q2_K_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 11:
        ggml_moe_q3_K_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 12:
        ggml_moe_q4_K_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 13:
        ggml_moe_q5_K_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      case 14:
        ggml_moe_q6_K_q8_1_cuda(x_q8, w_q, y, sorted_ids, experts,
                                n_post_padded, expert_stride, col, row, tokens,
                                padded, row, top_k, n_sorted, stream);
        break;
      default:
        // Y would otherwise be returned without ever being written.
        TORCH_CHECK(false, "ggml_moe_a8: unsupported quantization type ",
                    type);
    }
  });
  return Y;
}
// Looks up the MMQ_X_* constant associated with a quantization type id.
// Ids without an entry yield 0.
int64_t ggml_moe_get_block_size(int64_t type) {
  int64_t block_size = 0;
  switch (type) {
    case 2:
      block_size = MMQ_X_Q4_0;
      break;
    case 3:
      block_size = MMQ_X_Q4_1;
      break;
    case 6:
      block_size = MMQ_X_Q5_0;
      break;
    case 7:
      block_size = MMQ_X_Q5_1;
      break;
    case 8:
      block_size = MMQ_X_Q8_0;
      break;
    case 10:
      block_size = MMQ_X_Q2_K;
      break;
    case 11:
      block_size = MMQ_X_Q3_K;
      break;
    case 12:
      block_size = MMQ_X_Q4_K;
      break;
    case 13:
      block_size = MMQ_X_Q5_K;
      break;
    case 14:
      block_size = MMQ_X_Q6_K;
      break;
    default:
      break;
  }
  return block_size;
}
csrc/quantization/gguf/moe.cuh
0 → 100644
View file @
e22ee1e7
This diff is collapsed.
Click to expand it.
csrc/torch_bindings.cpp
View file @
e22ee1e7
...
...
@@ -305,6 +305,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"
);
ops
.
impl
(
"ggml_mul_mat_a8"
,
torch
::
kCUDA
,
&
ggml_mul_mat_a8
);
// moe kernel for GGML.
ops
.
def
(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"
);
ops
.
impl
(
"ggml_moe_a8"
,
torch
::
kCUDA
,
&
ggml_moe_a8
);
ops
.
def
(
"ggml_moe_get_block_size"
,
&
ggml_moe_get_block_size
);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops
.
def
(
...
...
tests/kernels/test_ggml.py
View file @
e22ee1e7
...
...
@@ -22,3 +22,16 @@ def test_ggml_opcheck(quant_type):
(
qweight
,
x
,
quant_type
,
qweight
.
shape
[
0
]))
opcheck
(
torch
.
ops
.
_C
.
ggml_mul_mat_vec_a8
,
(
qweight
,
x
,
quant_type
,
qweight
.
shape
[
0
]))
shape
=
[
256
,
1024
,
336
]
qweight
=
torch
.
randint
(
0
,
100
,
shape
,
device
=
'cuda'
,
dtype
=
torch
.
uint8
)
x
=
torch
.
rand
((
1
,
1024
),
device
=
'cuda'
,
dtype
=
torch
.
float16
)
sorted_token_ids
=
torch
.
arange
(
776
,
device
=
'cuda'
)
expert_ids
=
torch
.
randint
(
0
,
256
,
(
194
,
),
device
=
'cuda'
)
num_tokens_post_padded
=
torch
.
tensor
([
1
],
dtype
=
torch
.
int64
,
device
=
'cuda'
)
opcheck
(
torch
.
ops
.
_C
.
ggml_moe_a8
,
(
x
,
qweight
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
quant_type
,
qweight
.
shape
[
0
],
1
,
x
.
shape
[
0
]))
tests/kernels/test_gguf.py
View file @
e22ee1e7
...
...
@@ -8,9 +8,13 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from
huggingface_hub
import
snapshot_download
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.gguf
import
_fused_moe_gguf
from
vllm.platforms
import
current_platform
GGUF_SAMPLE
=
snapshot_download
(
"Isotr0py/test-gguf-sample"
)
GGUF_SAMPLE_MOE
=
snapshot_download
(
"SzymonOzog/test-gguf-moe-sample"
)
def
get_gguf_sample_tensors
(
...
...
@@ -22,6 +26,15 @@ def get_gguf_sample_tensors(
return
GGUFReader
(
sample_file
).
tensors
def get_gguf_MoE_tensors(
        hidden_size: int,
        quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
    """Read the tensors of the MoE sample file for one quantization type.

    The file is looked up inside the downloaded ``GGUF_SAMPLE_MOE`` snapshot
    under the name ``Quant_<quant_type.name>_<hidden_size>.gguf``.
    """
    moe_sample_path = (Path(GGUF_SAMPLE_MOE) /
                       f"Quant_{quant_type.name}_{hidden_size}.gguf")
    reader = GGUFReader(moe_sample_path)
    return reader.tensors
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float32
]
# Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently.
...
...
@@ -132,3 +145,54 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
ref_output
,
atol
=
atols
[
dtype
],
rtol
=
rtols
[
dtype
])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", [512])
@pytest.mark.parametrize("top_k", [4, 8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
    "quant_type",
    [
        # k-quants
        GGMLQuantizationType.Q2_K,
        GGMLQuantizationType.Q3_K,
        GGMLQuantizationType.Q4_K,
        GGMLQuantizationType.Q5_K,
        GGMLQuantizationType.Q6_K,
        # standard quants
        GGMLQuantizationType.Q4_0,
        GGMLQuantizationType.Q5_0,
        GGMLQuantizationType.Q8_0,
    ])
@torch.inference_mode()
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
             quant_type: GGMLQuantizationType, top_k: int):
    """Compare the GGUF fused MoE path against ``fused_experts``.

    The quantized sample weights go through ``_fused_moe_gguf``; the same
    weights, dequantized, go through ``fused_experts`` as the reference.
    ``hidden_size`` only selects the sample file; the activations use a
    fixed width of 1024 and routing picks among 256 experts.
    """
    current_platform.seed_everything(0)
    input_width, num_experts = 1024, 256

    # Random activations and routing (generation order matters for the seed).
    x = torch.rand((num_tokens, input_width), dtype=dtype, device="cuda")
    topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
    topk_ids = torch.randint(0,
                             num_experts, (num_tokens, top_k),
                             device="cuda")

    # First tensor in the file is the fused up/gate weight, second the down
    # weight.
    sample_tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
    w13 = sample_tensors[0]
    w2 = sample_tensors[1]

    w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
                               device="cuda").to(dtype)
    w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
                              device="cuda").to(dtype)

    act = SiluAndMul()
    w13_quant = torch.tensor(w13.data, device="cuda")
    w2_quant = torch.tensor(w2.data, device="cuda")
    output = _fused_moe_gguf(x, w13_quant, w2_quant, topk_weights, topk_ids,
                             quant_type, quant_type, act)

    ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
                               topk_ids)
    ref_output = ref_output.reshape(output.shape)

    torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
vllm/_custom_ops.py
View file @
e22ee1e7
...
...
@@ -448,6 +448,23 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
batch
=
X
.
size
(
0
)
return
torch
.
empty
((
batch
,
row
),
dtype
=
X
.
dtype
,
device
=
W
.
device
)
@register_fake("_C::ggml_moe_a8")
def _ggml_moe_a8_fake(
    X: torch.Tensor,
    W: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    quant_type: int,
    row: torch.SymInt,
    top_k: torch.SymInt,
    tokens: torch.SymInt,
) -> torch.Tensor:
    """Fake (meta) implementation of ``_C::ggml_moe_a8``.

    Mirrors the allocation done by the CUDA implementation, which creates
    its output as ``{tokens * top_k, row}`` with ``X``'s dtype on ``W``'s
    device. The ``tokens`` argument is used as-is rather than being
    re-derived from ``X``, and the dtype follows ``X`` instead of being
    fixed to float16, so bfloat16/float32 inputs report the right dtype.
    """
    return torch.empty((tokens * top_k, row),
                       dtype=X.dtype,
                       device=W.device)
# cutlass
def
cutlass_scaled_fp4_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
...
@@ -1034,6 +1051,26 @@ def ggml_mul_mat_a8(
return
torch
.
ops
.
_C
.
ggml_mul_mat_a8
(
W
,
X
,
quant_type
,
row
)
def ggml_moe_a8(
    X: torch.Tensor,
    W: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    quant_type: int,
    row: int,
    top_k: int,
    tokens: int,
) -> torch.Tensor:
    """Forward all arguments unchanged to the ``_C.ggml_moe_a8`` custom op."""
    moe_op = torch.ops._C.ggml_moe_a8
    result = moe_op(X, W, sorted_token_ids, expert_ids,
                    num_tokens_post_padded, quant_type, row, top_k, tokens)
    return result
def ggml_moe_get_block_size(quant_type: int) -> int:
    """Ask the ``_C.ggml_moe_get_block_size`` op for ``quant_type``'s block size."""
    block_size_op = torch.ops._C.ggml_moe_get_block_size
    return block_size_op(quant_type)
# mamba
def
causal_conv1d_fwd
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
...
...
vllm/model_executor/layers/quantization/gguf.py
View file @
e22ee1e7
...
...
@@ -8,7 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe.fused_moe
import
moe_align_block_size
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
...
...
@@ -18,6 +20,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
class
GGUFConfig
(
QuantizationConfig
):
"""Config class for GGUF."""
...
...
@@ -119,6 +123,59 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return
y
def _fused_moe_gguf(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    qweight_type: int,
    qweight_type2: int,
    act,
) -> torch.Tensor:
    """Apply GGUF-quantized MoE experts to ``x``.

    When both weight types are in ``MMQ_QUANT_TYPES`` the fused
    ``ggml_moe_a8`` op is used for the up and down projections. Otherwise a
    per-token, per-expert loop built on ``_fuse_mul_mat`` is used.

    Args:
        x: input activations (unpacked as two dimensions on the fused path).
        w1: quantized up-projection expert weights (three dimensions).
        w2: quantized down-projection expert weights.
        topk_weights: routing weight per (token, selected expert).
        topk_ids: selected expert index per (token, slot).
        qweight_type: quantization type of ``w1``.
        qweight_type2: quantization type of ``w2``.
        act: activation applied between the two projections.

    Returns:
        A tensor shaped like ``x`` holding the combined expert outputs.
    """
    out_hidden_states = torch.empty_like(x)
    use_fused_kernel = (qweight_type2 in MMQ_QUANT_TYPES
                        and qweight_type in MMQ_QUANT_TYPES)

    if not use_fused_kernel:
        logger.warning_once("There is no support for fast MoE kernel "
                            "for current quantization method. "
                            "Falling back to slow implementation. ")
        for tok, (tok_weights, tok_experts) in enumerate(
                zip(topk_weights, topk_ids)):
            # Keep a leading batch dimension of 1 for the matmul helper.
            inp = x[tok].reshape((1, ) + x.shape[1:])
            accumulated = None
            for expert_weight, expert_idx in zip(tok_weights, tok_experts):
                hidden = _fuse_mul_mat(inp, w1[expert_idx], qweight_type)
                hidden = act(hidden)
                contribution = _fuse_mul_mat(hidden, w2[expert_idx],
                                             qweight_type2).mul_(expert_weight)
                if accumulated is None:
                    accumulated = contribution
                else:
                    accumulated.add_(contribution)
            out_hidden_states[tok] = accumulated
        return out_hidden_states

    num_tokens, _ = x.shape
    E, N, _ = w1.shape
    top_k = topk_ids.shape[1]
    # NOTE(review): the block size is derived from qweight_type only and the
    # resulting alignment is reused for the w2 call with qweight_type2 --
    # confirm this is valid when the two types differ.
    BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)

    aligned = moe_align_block_size(topk_ids, BLOCK_SIZE, E)
    sorted_token_ids, expert_ids, num_tokens_post_padded = aligned

    out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
                          num_tokens_post_padded, qweight_type, N, top_k,
                          num_tokens)
    out = act(out)
    # The down projection is called with top_k=1 and num_tokens * top_k
    # tokens.
    out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
                          num_tokens_post_padded, qweight_type2, w2.shape[1],
                          1, num_tokens * top_k)
    routing_scale = topk_weights.view(num_tokens, top_k, 1)
    out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(routing_scale)
    ops.moe_sum(out, out_hidden_states)
    return out_hidden_states
class
GGUFLinearMethod
(
LinearMethodBase
):
"""Linear method for GGUF.
...
...
@@ -285,27 +342,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
final_hidden_states
=
torch
.
empty_like
(
x
)
for
tok
,
(
w
,
idx
)
in
enumerate
(
zip
(
topk_weights
,
topk_ids
)):
inp
=
x
[
tok
].
reshape
((
1
,
)
+
x
.
shape
[
1
:])
current_hidden_state
=
None
for
ww
,
ii
in
zip
(
w
,
idx
):
expert_up
=
layer
.
w13_qweight
[
ii
]
out
=
_fuse_mul_mat
(
inp
,
expert_up
,
layer
.
w13_qweight_type
.
weight_type
)
out
=
self
.
act
(
out
)
expert_down
=
layer
.
w2_qweight
[
ii
]
current_state
=
_fuse_mul_mat
(
out
,
expert_down
,
layer
.
w2_qweight_type
.
weight_type
).
mul_
(
ww
)
if
current_hidden_state
is
None
:
current_hidden_state
=
current_state
else
:
current_hidden_state
.
add_
(
current_state
)
final_hidden_states
[
tok
]
=
current_hidden_state
return
final_hidden_states
return
_fused_moe_gguf
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
topk_weights
,
topk_ids
,
layer
.
w13_qweight_type
.
weight_type
,
layer
.
w2_qweight_type
.
weight_type
,
self
.
act
)
class
GGUFEmbeddingMethod
(
GGUFLinearMethod
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment