Unverified Commit e22ee1e7 authored by Szymon Ożóg's avatar Szymon Ożóg Committed by GitHub
Browse files

[Kernel] GGUF MoE kernel (#14613)


Signed-off-by: default avatarSzymonOzog <szymon.ozog@aleph-alpha.com>
parent e392d858
......@@ -151,6 +151,14 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded, int64_t type,
int64_t row, int64_t top_k, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
......
......@@ -12,6 +12,7 @@
#include "dequantize.cuh"
#include "mmvq.cuh"
#include "mmq.cuh"
#include "moe.cuh"
// Q8 gemv
template <typename scalar_t>
......@@ -59,10 +60,14 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
const int block_num_x =
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ky, 1);
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<scalar_t>
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
constexpr int MAX_BLOCK_SIZE = 65535;
for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) {
const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off;
const dim3 num_blocks(block_num_x, num_blocks_y, 1);
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
&x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded);
}
}
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
......@@ -263,3 +268,132 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
});
return Y;
}
torch::Tensor ggml_moe_a8(torch::Tensor X, // input
torch::Tensor W, // expert weights
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_padded, int64_t type,
int64_t row, int64_t top_k, int64_t tokens) {
int col = X.sizes()[1];
int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::empty({tokens * top_k, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
col, tokens, stream);
switch (type) {
case 2:
ggml_moe_q4_0_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 3:
ggml_moe_q4_1_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 6:
ggml_moe_q5_0_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 7:
ggml_moe_q5_1_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 8:
ggml_moe_q8_0_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 10:
ggml_moe_q2_K_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 11:
ggml_moe_q3_K_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 12:
ggml_moe_q4_K_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 13:
ggml_moe_q5_K_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
case 14:
ggml_moe_q6_K_q8_1_cuda(
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
(int*)expert_ids.data_ptr(),
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
break;
}
});
return Y;
}
int64_t ggml_moe_get_block_size(int64_t type) {
switch (type) {
case 2:
return MMQ_X_Q4_0;
case 3:
return MMQ_X_Q4_1;
case 6:
return MMQ_X_Q5_0;
case 7:
return MMQ_X_Q5_1;
case 8:
return MMQ_X_Q8_0;
case 10:
return MMQ_X_Q2_K;
case 11:
return MMQ_X_Q3_K;
case 12:
return MMQ_X_Q4_K;
case 13:
return MMQ_X_Q5_K;
case 14:
return MMQ_X_Q6_K;
}
return 0;
}
This diff is collapsed.
......@@ -305,6 +305,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
// moe kernel for GGML.
ops.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, Tensor "
"num_tokens_post_padded, "
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def(
......
......@@ -22,3 +22,16 @@ def test_ggml_opcheck(quant_type):
(qweight, x, quant_type, qweight.shape[0]))
opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
(qweight, x, quant_type, qweight.shape[0]))
shape = [256, 1024, 336]
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
x = torch.rand((1, 1024), device='cuda', dtype=torch.float16)
sorted_token_ids = torch.arange(776, device='cuda')
expert_ids = torch.randint(0, 256, (194, ), device='cuda')
num_tokens_post_padded = torch.tensor([1],
dtype=torch.int64,
device='cuda')
opcheck(torch.ops._C.ggml_moe_a8,
(x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
quant_type, qweight.shape[0], 1, x.shape[0]))
......@@ -8,9 +8,13 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download
import vllm._custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
from vllm.platforms import current_platform
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
def get_gguf_sample_tensors(
......@@ -22,6 +26,15 @@ def get_gguf_sample_tensors(
return GGUFReader(sample_file).tensors
def get_gguf_MoE_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE_MOE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors
DTYPES = [torch.half, torch.bfloat16, torch.float32]
# Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently.
......@@ -132,3 +145,54 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
ref_output,
atol=atols[dtype],
rtol=rtols[dtype])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", [512])
@pytest.mark.parametrize("top_k", [4, 8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"quant_type",
[
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quants
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
])
@torch.inference_mode()
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType, top_k: int):
current_platform.seed_everything(0)
H, E = 1024, 256
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda")
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
w13 = tensors[0]
w2 = tensors[1]
w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
device="cuda").to(dtype)
w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
device="cuda").to(dtype)
act = SiluAndMul()
output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data,
device="cuda"), topk_weights,
topk_ids, quant_type, quant_type, act)
ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
topk_ids).reshape(output.shape)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
......@@ -448,6 +448,23 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
batch = X.size(0)
return torch.empty((batch, row), dtype=X.dtype, device=W.device)
@register_fake("_C::ggml_moe_a8")
def _ggml_moe_a8_fake(
X: torch.Tensor,
W: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
quant_type: int,
row: torch.SymInt,
top_k: torch.SymInt,
tokens: torch.SymInt,
) -> torch.Tensor:
tokens = X.size(0)
return torch.empty((tokens * top_k, row),
dtype=torch.float16,
device=W.device)
# cutlass
def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
......@@ -1034,6 +1051,26 @@ def ggml_mul_mat_a8(
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
def ggml_moe_a8(
X: torch.Tensor,
W: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
quant_type: int,
row: int,
top_k: int,
tokens: int,
) -> torch.Tensor:
return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids,
num_tokens_post_padded, quant_type, row,
top_k, tokens)
def ggml_moe_get_block_size(quant_type: int) -> int:
return torch.ops._C.ggml_moe_get_block_size(quant_type)
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
......
......@@ -8,7 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
......@@ -18,6 +20,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
......@@ -119,6 +123,59 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return y
def _fused_moe_gguf(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
act,
) -> torch.Tensor:
out_hidden_states = torch.empty_like(x)
if qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES:
num_tokens, _ = x.shape
E, N, _ = w1.shape
top_k = topk_ids.shape[1]
BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type)
sorted_token_ids, expert_ids, num_tokens_post_padded = \
moe_align_block_size(topk_ids, BLOCK_SIZE, E)
out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids,
num_tokens_post_padded, qweight_type, N, top_k,
num_tokens)
out = act(out)
out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids,
num_tokens_post_padded, qweight_type2,
w2.shape[1], 1, num_tokens * top_k)
out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_(
topk_weights.view(num_tokens, top_k, 1))
ops.moe_sum(out, out_hidden_states)
else:
logger.warning_once("There is no support for fast MoE kernel "
"for current quantization method. "
"Falling back to slow implementation. ")
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = w1[ii]
out = _fuse_mul_mat(inp, expert_up, qweight_type)
out = act(out)
expert_down = w2[ii]
current_state = _fuse_mul_mat(out, expert_down,
qweight_type2).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
out_hidden_states[tok] = current_hidden_state
return out_hidden_states
class GGUFLinearMethod(LinearMethodBase):
"""Linear method for GGUF.
......@@ -285,27 +342,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
final_hidden_states = torch.empty_like(x)
for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
inp = x[tok].reshape((1, ) + x.shape[1:])
current_hidden_state = None
for ww, ii in zip(w, idx):
expert_up = layer.w13_qweight[ii]
out = _fuse_mul_mat(inp, expert_up,
layer.w13_qweight_type.weight_type)
out = self.act(out)
expert_down = layer.w2_qweight[ii]
current_state = _fuse_mul_mat(
out, expert_down,
layer.w2_qweight_type.weight_type).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
current_hidden_state.add_(current_state)
final_hidden_states[tok] = current_hidden_state
return final_hidden_states
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, self.act)
class GGUFEmbeddingMethod(GGUFLinearMethod):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment