Unverified commit 05d68643, authored by ElizaWszola, committed by GitHub
Browse files

[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)


Co-authored-by: Dipika <dipikasikka1@gmail.com>
Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
parent 0dcc8cbe
...@@ -433,6 +433,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -433,6 +433,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
"csrc/moe/marlin_moe_ops.cu") "csrc/moe/marlin_moe_ops.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
......
This diff is collapsed.
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// Host-side dispatcher for the fused Marlin MoE kernel specialised for the
// vllm::kU4 quantization type, with the local has_zp flag fixed to true.
//
// Returns false when none of the AWQ_CALL_IF_MOE entries below handled the
// request (the trailing `else` branch), and true otherwise. We return bool so
// that callers can chain several of these dispatchers as a sequence of
// if-elseif's.
//
// NOTE(review): AWQ_CALL_IF_MOE is not defined in this file (presumably it
// comes in through marlin_moe_kernel_ku4.h). The function parameters and the
// local has_zp are not referenced directly here, so they are presumably
// consumed by the macro expansion -- confirm against the macro definition.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
// Not used by name in this body; presumably read by AWQ_CALL_IF_MOE.
bool has_zp = true;
// Empty always-false branch: it only opens the if/else chain so that every
// macro entry below can continue it and the final `else` can close it.
if (false) {
}
// One entry per supported configuration.
// NOTE(review): the three integers look like
// (thread_n_blocks, thread_k_blocks, num_threads) -- confirm in the macro.
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
// No entry above matched: tell the caller to try another dispatcher.
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// Declaration of the fused Marlin MoE kernel dispatcher for the vllm::kU4
// quantization type. Unlike the sibling ku4b8 / ku8b128 dispatchers before
// this change, the signature carries a zero-point pointer (zp_ptr) right
// after the scales pointer (s_ptr).
//
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's: the caller tries each dispatcher in turn and stops at the
// first one that returns true.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe
...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8( ...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks) { int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) { if (false) {
} }
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
......
...@@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8( ...@@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks); int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe } // namespace marlin_moe
...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128( ...@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks) { int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) { if (false) {
} }
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
......
...@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128( ...@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks, bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
bool replicate_input, bool apply_weights, int m_block, int max_par, int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int cfg_max_m_blocks); int m_block, int max_par, int cfg_max_m_blocks);
} }
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "core/registration.h" #include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template <typename T> template <typename T>
inline std::string str(T x) { inline std::string str(T x) {
...@@ -157,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = { ...@@ -157,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
{128, 64, 128}, // Reduce N 2X, same K {128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X {64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N {64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
}; };
thread_config_t large_batch_thread_configs[] = { thread_config_t large_batch_thread_configs[] = {
...@@ -167,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = { ...@@ -167,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
{128, 128, 256}, // Reduce N 2X, increase K 2X {128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K {64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X {128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
}; };
int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
...@@ -312,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, ...@@ -312,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}}; return exec_config_t{0, {-1, -1, -1}};
} }
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ #define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks, \ else if (KERNEL_FUNCTION( \
has_act_order, group_blocks, num_threads, blocks, \ q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \ group_blocks, num_threads, blocks, max_shared_mem, stream, \
sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
expert_offsets_ptr, num_groups, expert_idx, \ zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
locks, replicate_input, apply_weights, m_block, \ replicate_input, apply_weights, m_block, max_par, \
max_par, exec_cfg.max_m_blocks)) { \ exec_cfg.max_m_blocks)) { \
} }
void marlin_mm_moe(const void* A, const void* B, void* C, void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights, const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, const void* g_idx, const void* topk_ids, const void* s, void* zp,
const void* perm, void* a_tmp, void* expert_offsets, const void* g_idx, const void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, void* workspace, void* expert_offsets, int prob_m, int prob_n, int prob_k,
vllm::ScalarType const& q_type, bool has_act_order, void* workspace, vllm::ScalarType const& q_type,
bool is_k_full, int num_groups, int group_size, bool has_act_order, bool is_k_full, bool has_zp,
int num_experts, int topk, int moe_block_size, int dev, int num_groups, int group_size, int num_experts, int topk,
cudaStream_t stream, int thread_k, int thread_n, int sms, int moe_block_size, int dev, cudaStream_t stream,
int max_par, bool replicate_input, bool apply_weights) { int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
...@@ -436,6 +440,8 @@ void marlin_mm_moe(const void* A, const void* B, void* C, ...@@ -436,6 +440,8 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
const float* topk_weights_ptr = (const float*)topk_weights; const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids; const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace; int* locks = (int*)workspace;
...@@ -456,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C, ...@@ -456,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
} }
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else { else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" + str(prob_n) + ", " + str(prob_k) + "]" +
...@@ -475,13 +482,21 @@ torch::Tensor marlin_gemm_moe( ...@@ -475,13 +482,21 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& b_zeros, const torch::Tensor& g_idx,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, const torch::Tensor& perm, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
int64_t num_experts, int64_t topk, int64_t moe_block_size, int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
bool replicate_input, bool apply_weights) { int64_t moe_block_size, bool replicate_input, bool apply_weights) {
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, bool has_zp = b_zeros.size(1) != 0;
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); if (has_zp) {
TORCH_CHECK(
*b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
}
int pack_factor = 32 / b_q_type->size_bits(); int pack_factor = 32 / b_q_type->size_bits();
...@@ -543,14 +558,27 @@ torch::Tensor marlin_gemm_moe( ...@@ -543,14 +558,27 @@ torch::Tensor marlin_gemm_moe(
} }
} }
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
marlin_moe::marlin_mm_moe( marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
*b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, num_experts, topk, moe_block_size, dev,
thread_n, sms, max_par, replicate_input, apply_weights); at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
return c; return c;
} }
......
...@@ -12,7 +12,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -12,7 +12,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, " "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)" "int moe_block_size, bool replicate_input, bool apply_weights)"
......
...@@ -2260,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -2260,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 0 = ", b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups); " is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1), "b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor); " is not size_n / pack_factor = ", size_n / pack_factor);
} }
......
"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_fused_marlin_moe_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    """Check fused_marlin_moe with AWQ zero points against torch_moe.

    Random fp16 activations (m, k) and expert weights w1 (e, 2n, k) and
    w2 (e, k, n) are generated on CUDA, every expert is quantized with
    awq_marlin_quantize as uint4, and the fused Marlin output must stay within
    a relative mean error of 4e-2 of the torch reference computed from the
    dequantized reference weights.
    """
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16

    def quantize_experts(weights):
        # Quantize the experts in index order; every entry is the 4-tuple
        # (w_ref, qweight, scales, zp) returned by awq_marlin_quantize.
        per_expert = [
            awq_marlin_quantize(weights[idx].transpose(1, 0), quant_type,
                                group_size)
            for idx in range(weights.shape[0])
        ]
        refs = stack_and_dev([item[0] for item in per_expert])
        # Only the packed weights are made contiguous after stacking.
        packed = stack_and_dev([item[1] for item in per_expert]).contiguous()
        scales = stack_and_dev([item[2] for item in per_expert])
        zero_points = stack_and_dev([item[3] for item in per_expert])
        return refs, packed, scales, zero_points

    # Keep the random draws in this order so the seeded values are unchanged.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    w_ref1, qweight1, scales1, zp1 = quantize_experts(w1)
    w_ref2, qweight2, scales2, zp2 = quantize_experts(w2)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    marlin_output = fused_marlin_moe(a,
                                     qweight1,
                                     qweight2,
                                     scales1,
                                     scales2,
                                     score,
                                     topk_weights,
                                     topk_ids,
                                     w1_zeros=zp1,
                                     w2_zeros=zp2,
                                     num_bits=num_bits)

    torch_output = torch_moe(a, w_ref1.transpose(1, 2),
                             w_ref2.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    """Check single_marlin_moe with AWQ zero points against torch_moe_single.

    Debug-only test (skipped by the decorator above). Random fp16 activations
    (m, k) and expert weights (e, n, k) are generated on CUDA, every expert is
    quantized with awq_marlin_quantize as uint4, and the Marlin output must
    stay within a relative mean error of 1e-2 of the torch reference.
    """
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16

    # Keep the random draws in this order so the seeded values are unchanged.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    # Quantize the experts in index order; every entry is the 4-tuple
    # (w_ref, qweight, scales, zp) returned by awq_marlin_quantize.
    per_expert = [
        awq_marlin_quantize(w[idx].transpose(1, 0), quant_type, group_size)
        for idx in range(w.shape[0])
    ]
    w_ref = stack_and_dev([item[0] for item in per_expert])
    # Everything handed to the kernel is made contiguous; w_ref is not.
    qweight = stack_and_dev([item[1] for item in per_expert]).contiguous()
    scales = stack_and_dev([item[2] for item in per_expert]).contiguous()
    zp = stack_and_dev([item[3] for item in per_expert]).contiguous()

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    marlin_output = single_marlin_moe(
        a,
        qweight,
        scales,
        score,
        topk,
        renormalize=False,
        w_zeros=zp,
        num_bits=num_bits,
    )

    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2
...@@ -2,16 +2,14 @@ ...@@ -2,16 +2,14 @@
Run `pytest tests/kernels/test_moe.py`. Run `pytest tests/kernels/test_moe.py`.
""" """
from typing import List
import pytest import pytest
import torch import torch
from transformers import MixtralConfig from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from tests.kernels.utils import opcheck from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe) fused_marlin_moe, single_marlin_moe)
...@@ -24,37 +22,6 @@ from vllm.scalar_type import scalar_types ...@@ -24,37 +22,6 @@ from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything from vllm.utils import seed_everything
def torch_moe(a, w1, w2, score, topk):
    """Reference mixture-of-experts forward pass written in plain torch.

    Each row of ``a`` is replicated ``topk`` times, the replicas are routed to
    the experts chosen by a float32 softmax + top-k over ``score``, pushed
    through ``w1[i]`` -> SiluAndMul -> ``w2[i]``, and the per-token results
    are combined with the top-k softmax weights.
    """
    num_tokens, hidden = a.shape
    out_features = w2.shape[1]

    # One row per (token, top-k slot) pair.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                   1).reshape(-1, hidden)
    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=tokens.dtype,
                         device=tokens.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    routing_weights, routing_ids = torch.topk(probs, topk)
    routing_weights = routing_weights.view(-1)
    routing_ids = routing_ids.view(-1)

    for expert in range(w1.shape[0]):
        selected = routing_ids == expert
        if not selected.sum():
            # Nothing was routed to this expert.
            continue
        hidden_states = tokens[selected] @ w1[expert].transpose(0, 1)
        result[selected] = SiluAndMul()(hidden_states) @ w2[expert].transpose(
            0, 1)

    per_slot = result.view(num_tokens, -1, out_features)
    slot_weights = routing_weights.view(num_tokens, -1, 1).to(result.dtype)
    return (per_slot * slot_weights).sum(dim=1)
def torch_moe_single(a, w, score, topk):
    """Reference single-matmul mixture-of-experts written in plain torch.

    Each row of ``a`` is replicated ``topk`` times, the replicas are routed to
    the experts chosen by a float32 softmax + top-k over ``score``, multiplied
    by the transposed expert weight, and the ``topk`` results of every token
    are summed without applying the routing weights.
    """
    num_tokens, hidden = a.shape
    out_features = w.shape[1]

    # One row per (token, top-k slot) pair.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                   1).reshape(-1, hidden)
    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=tokens.dtype,
                         device=tokens.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    routing_ids = torch.topk(probs, topk)[1].view(-1)

    for expert in range(w.shape[0]):
        selected = routing_ids == expert
        if not selected.sum():
            # Nothing was routed to this expert.
            continue
        result[selected] = tokens[selected] @ w[expert].transpose(0, 1)

    return result.view(num_tokens, -1, out_features).sum(dim=1)
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024])
...@@ -127,20 +94,10 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -127,20 +94,10 @@ def test_mixtral_moe(dtype: torch.dtype):
atol=mixtral_moe_tol[dtype]) atol=mixtral_moe_tol[dtype])
def stack_and_dev(tensors: List[torch.Tensor]):
    """Stack ``tensors`` along a new leading dimension.

    The stacked result is moved to the device of the first tensor in the list.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
def compute_max_diff(output, output_ref):
    """Return the mean absolute error relative to the reference magnitude.

    Despite the name this is not a maximum: it is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.
    """
    mean_abs_error = torch.mean(torch.abs(output - output_ref))
    mean_abs_ref = torch.mean(torch.abs(output_ref))
    return mean_abs_error / mean_abs_ref
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512]) @pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64]) @pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
...@@ -159,9 +116,6 @@ def test_fused_marlin_moe( ...@@ -159,9 +116,6 @@ def test_fused_marlin_moe(
): ):
seed_everything(7) seed_everything(7)
if topk > e:
return
# Filter act_order # Filter act_order
if act_order: if act_order:
if group_size == -1: if group_size == -1:
...@@ -241,15 +195,15 @@ def test_fused_marlin_moe( ...@@ -241,15 +195,15 @@ def test_fused_marlin_moe(
a, a,
qweight1, qweight1,
qweight2, qweight2,
scales1,
scales2,
score, score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale=scales1, g_idx1=g_idx1,
w2_scale=scales2, g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
num_bits=num_bits, num_bits=num_bits,
is_k_full=is_k_full, is_k_full=is_k_full,
) )
...@@ -280,9 +234,13 @@ def test_fused_marlin_moe( ...@@ -280,9 +234,13 @@ def test_fused_marlin_moe(
device="cuda", device="cuda",
requires_grad=False) requires_grad=False)
zp = torch.empty((0, 0),
dtype=dtype,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe, opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids, (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, g_idx1, sort_indices1, workspace, quant_type, m, scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False)) 2 * n, k, True, e, topk, block_size_m, True, False))
...@@ -291,7 +249,7 @@ def test_fused_marlin_moe( ...@@ -291,7 +249,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512]) @pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64]) @pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
...@@ -308,8 +266,6 @@ def test_single_marlin_moe_multiply( ...@@ -308,8 +266,6 @@ def test_single_marlin_moe_multiply(
num_bits: int, num_bits: int,
is_k_full: bool, is_k_full: bool,
): ):
if topk > e:
return
# Filter act_order # Filter act_order
if act_order: if act_order:
...@@ -355,13 +311,14 @@ def test_single_marlin_moe_multiply( ...@@ -355,13 +311,14 @@ def test_single_marlin_moe_multiply(
qweight, qweight,
scales, scales,
score, score,
g_idx,
sort_indices,
topk, topk,
renormalize=False, renormalize=False,
g_idx=g_idx,
sort_indices=sort_indices,
num_bits=num_bits, num_bits=num_bits,
is_k_full=is_k_full, is_k_full=is_k_full,
) )
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2 assert compute_max_diff(marlin_output, torch_output) < 1e-2
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from torch._prims_common import TensorLikeType from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad) make_tensor_with_pad)
...@@ -974,6 +975,50 @@ def fp8_allclose( ...@@ -974,6 +975,50 @@ def fp8_allclose(
equal_nan=equal_nan)).item()) equal_nan=equal_nan)).item())
# Marlin MoE test utils
def stack_and_dev(tensors: List[torch.Tensor]):
    """Stack ``tensors`` along a new leading dimension.

    The stacked result is moved to the device of the first tensor in the list.
    """
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
def compute_max_diff(output, output_ref):
    """Return the mean absolute error relative to the reference magnitude.

    Despite the name this is not a maximum: it is
    ``mean(|output - output_ref|) / mean(|output_ref|)``.
    """
    mean_abs_error = torch.mean(torch.abs(output - output_ref))
    mean_abs_ref = torch.mean(torch.abs(output_ref))
    return mean_abs_error / mean_abs_ref
def torch_moe(a, w1, w2, score, topk):
    """Reference mixture-of-experts forward pass written in plain torch.

    Each row of ``a`` is replicated ``topk`` times, the replicas are routed to
    the experts chosen by a float32 softmax + top-k over ``score``, pushed
    through ``w1[i]`` -> SiluAndMul -> ``w2[i]``, and the per-token results
    are combined with the top-k softmax weights.
    """
    num_tokens, hidden = a.shape
    out_features = w2.shape[1]

    # One row per (token, top-k slot) pair.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                   1).reshape(-1, hidden)
    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=tokens.dtype,
                         device=tokens.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    routing_weights, routing_ids = torch.topk(probs, topk)
    routing_weights = routing_weights.view(-1)
    routing_ids = routing_ids.view(-1)

    for expert in range(w1.shape[0]):
        selected = routing_ids == expert
        if not selected.sum():
            # Nothing was routed to this expert.
            continue
        hidden_states = tokens[selected] @ w1[expert].transpose(0, 1)
        result[selected] = SiluAndMul()(hidden_states) @ w2[expert].transpose(
            0, 1)

    per_slot = result.view(num_tokens, -1, out_features)
    slot_weights = routing_weights.view(num_tokens, -1, 1).to(result.dtype)
    return (per_slot * slot_weights).sum(dim=1)
def torch_moe_single(a, w, score, topk):
    """Reference single-matmul mixture-of-experts written in plain torch.

    Each row of ``a`` is replicated ``topk`` times, the replicas are routed to
    the experts chosen by a float32 softmax + top-k over ``score``, multiplied
    by the transposed expert weight, and the ``topk`` results of every token
    are summed without applying the routing weights.
    """
    num_tokens, hidden = a.shape
    out_features = w.shape[1]

    # One row per (token, top-k slot) pair.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                   1).reshape(-1, hidden)
    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=tokens.dtype,
                         device=tokens.device)

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    routing_ids = torch.topk(probs, topk)[1].view(-1)

    for expert in range(w.shape[0]):
        selected = routing_ids == expert
        if not selected.sum():
            # Nothing was routed to this expert.
            continue
        result[selected] = tokens[selected] @ w[expert].transpose(0, 1)

    return result.view(num_tokens, -1, out_features).sum(dim=1)
# A special version of op check that has a restricted default set of test_utils # A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types. # and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
......
...@@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantize ...@@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantize
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
\ No newline at end of file
#!/bin/bash #!/bin/bash
SUCCESS=0 SUCCESS=0
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" while getopts "c:" OPT; do
case ${OPT} in
c )
CONFIG="$OPTARG"
;;
\? )
usage
exit 1
;;
esac
done
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG
for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do do
......
...@@ -568,6 +568,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, ...@@ -568,6 +568,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
return output return output
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                          size_k: int, size_n: int,
                          num_bits: int) -> torch.Tensor:
    """Repack per-expert AWQ weights by calling ``awq_marlin_repack`` on each
    expert slice.

    Args:
        b_q_weight: Packed weights with the expert index on dim 0.
        perm: Accepted but not used by this function.
        size_k: Must be a multiple of 16.
        size_n: Used with ``num_bits`` to size the last output dimension.
        num_bits: Bit width forwarded to the repack op.

    Returns:
        A tensor of shape
        ``(num_experts, size_k // 16, size_n * (num_bits // 2))`` with the
        same dtype and device as ``b_q_weight``.
    """
    expert_count = b_q_weight.shape[0]
    assert size_k % 16 == 0
    repacked_shape = (expert_count, size_k // 16, size_n * (num_bits // 2))
    output = torch.empty(repacked_shape,
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for expert in range(expert_count):
        expert_weight = b_q_weight[expert]
        output[expert] = torch.ops._C.awq_marlin_repack(
            expert_weight, size_k, size_n, num_bits)
    return output
def gptq_marlin_gemm(a: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, b_scales: torch.Tensor,
...@@ -828,11 +842,12 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): ...@@ -828,11 +842,12 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
sorted_ids: torch.Tensor, sorted_ids: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, b_scales: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor, b_zero_points: torch.Tensor, g_idx: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType, perm: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int, size_k: int, b_q_type: ScalarType, size_m: int, size_n: int,
is_k_full: bool, num_experts: int, topk: int, size_k: int, is_k_full: bool, num_experts: int,
moe_block_size: int, replicate_input: bool, topk: int, moe_block_size: int,
replicate_input: bool,
apply_weights: bool) -> torch.Tensor: apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n), return torch.empty((size_m, topk, size_n),
dtype=a.dtype, dtype=a.dtype,
......
...@@ -10,15 +10,24 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -10,15 +10,24 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool):
    """Select the weight scalar type for the Marlin MoE kernels.

    With zero points only 4-bit weights are accepted (asserted) and
    ``uint4`` is returned. Without zero points, ``uint4b8`` is returned for
    4 bits and ``uint8b128`` for any other ``num_bits``.
    """
    if not has_zp:
        if num_bits == 4:
            return scalar_types.uint4b8
        return scalar_types.uint8b128
    assert num_bits == 4
    return scalar_types.uint4
def single_marlin_moe( def single_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w: torch.Tensor, w: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True, is_k_full: bool = True,
...@@ -34,10 +43,12 @@ def single_marlin_moe( ...@@ -34,10 +43,12 @@ def single_marlin_moe(
- scales (torch.Tensor): The quantization scales. - scales (torch.Tensor): The quantization scales.
- gating_output (torch.Tensor): The output of the gating operation - gating_output (torch.Tensor): The output of the gating operation
(before softmax). (before softmax).
- g_idx (torch.Tensor): The act_order indices. - g_idx (Optional[torch.Tensor]): Optional act_order indices.
- perm (torch.Tensor): The act_order input permutation. - sort_indices (Optional[torch.Tensor]): Optional act_order input
permutation.
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization. - num_bits (bool): The number of bits in expert weights quantization.
...@@ -79,16 +90,34 @@ def single_marlin_moe( ...@@ -79,16 +90,34 @@ def single_marlin_moe(
max_workspace_size = (N // 64) * 16 max_workspace_size = (N // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device=hidden_states.device,
requires_grad=False)
has_zero_point = w_zeros is not None
if w_zeros is None:
w_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if g_idx is None:
g_idx = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8 if sort_indices is None:
if num_bits == 4 else scalar_types.uint8b128) sort_indices = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type = get_scalar_type(num_bits, has_zero_point)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk, w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
block_size_m, True, False) is_k_full, E, topk, block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
...@@ -97,16 +126,18 @@ def fused_marlin_moe( ...@@ -97,16 +126,18 @@ def fused_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
perm1: torch.Tensor,
perm2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8, num_bits: int = 8,
is_k_full: bool = True, is_k_full: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -118,21 +149,22 @@ def fused_marlin_moe( ...@@ -118,21 +149,22 @@ def fused_marlin_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation - gating_output (torch.Tensor): The output of the gating operation
(before softmax). (before softmax).
- g_idx1 (torch.Tensor): The first set of act_order indices. - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (torch.Tensor): The second set of act_order indices. - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- perm1 (torch.Tensor): The first act_order input permutation. - sort_indices1 (Optional[torch.Tensor]): The first act_order input
- perm2 (torch.Tensor): The second act_order input permutation. permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights. - topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements. - topk_ids (torch.Tensor): Indices of topk-k elements.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
w1. - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- num_bits (bool): The number of bits in expert weights quantization. - num_bits (bool): The number of bits in expert weights quantization.
Returns: Returns:
...@@ -152,6 +184,20 @@ def fused_marlin_moe( ...@@ -152,6 +184,20 @@ def fused_marlin_moe(
assert hidden_states.dtype == torch.float16 assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8] assert num_bits in [4, 8]
has_no_act_order = (g_idx1 is None and g_idx2 is None
and sort_indices1 is None and sort_indices2 is None)
has_all_act_order = (g_idx1 is not None and g_idx2 is not None
and sort_indices1 is not None
and sort_indices2 is not None)
assert has_no_act_order or has_all_act_order, (
"g_idx and sorted_indices "
"must be all not None or must be all None")
has_no_zp = w1_zeros is None and w2_zeros is None
has_all_zp = w1_zeros is not None and w2_zeros is not None
assert has_no_zp or has_all_zp, ("zero points must be both not None or "
"must be both None")
M, K = hidden_states.shape M, K = hidden_states.shape
E = w1.shape[0] E = w1.shape[0]
N = w2.shape[1] * 16 N = w2.shape[1] * 16
...@@ -172,14 +218,42 @@ def fused_marlin_moe( ...@@ -172,14 +218,42 @@ def fused_marlin_moe(
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device="cuda",
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8 if has_no_zp:
if num_bits == 4 else scalar_types.uint8b128) w1_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
w2_zeros = torch.empty((0, 0),
dtype=hidden_states.dtype,
device=hidden_states.device,
requires_grad=False)
if has_no_act_order:
g_idx1 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
g_idx2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices1 = torch.empty((0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
sort_indices2 = torch.empty((0, 0),
dtype=torch.int32,
device=hidden_states.device,
requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, has_all_zp)
scalar_type2 = get_scalar_type(num_bits, has_all_zp)
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk_ids.shape[1], N),
...@@ -194,10 +268,11 @@ def fused_marlin_moe( ...@@ -194,10 +268,11 @@ def fused_marlin_moe(
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale, w1_scale,
w1_zeros,
g_idx1, g_idx1,
perm1, sort_indices1,
workspace, workspace,
scalar_type, scalar_type1,
M, M,
2 * N, 2 * N,
K, K,
...@@ -218,10 +293,11 @@ def fused_marlin_moe( ...@@ -218,10 +293,11 @@ def fused_marlin_moe(
topk_weights, topk_weights,
topk_ids, topk_ids,
w2_scale, w2_scale,
w2_zeros,
g_idx2, g_idx2,
perm2, sort_indices2,
workspace, workspace,
scalar_type, scalar_type2,
M, M,
K, K,
N, N,
......
from typing import Any, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape) verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
...@@ -35,12 +40,13 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -35,12 +40,13 @@ class AWQMarlinConfig(QuantizationConfig):
self.group_size = group_size self.group_size = group_size
self.has_zp = has_zp self.has_zp = has_zp
self.lm_head_quantized = lm_head_quantized self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits
if weight_bits not in self.TYPE_MAP: if self.weight_bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {weight_bits}. " raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
f"Supported num_bits = {self.TYPE_MAP.keys()}") f"Supported num_bits = {self.TYPE_MAP.keys()}")
self.quant_type = self.TYPE_MAP[weight_bits] self.quant_type = self.TYPE_MAP[self.weight_bits]
verify_marlin_supported(self.quant_type, verify_marlin_supported(self.quant_type,
group_size=self.group_size, group_size=self.group_size,
...@@ -98,10 +104,12 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -98,10 +104,12 @@ class AWQMarlinConfig(QuantizationConfig):
return None return None
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMarlinLinearMethod"]: prefix: str) -> Optional["QuantizeMethodBase"]:
if (isinstance(layer, LinearBase) or if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMarlinLinearMethod(self) return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -271,4 +279,182 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -271,4 +279,182 @@ class AWQMarlinLinearMethod(LinearMethodBase):
quant_type=self.quant_config.quant_type, quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition, output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition, input_size_per_partition=layer.input_size_per_partition,
bias=bias) bias=bias)
\ No newline at end of file
class AWQMoEMethod(FusedMoEMethodBase):
    """FusedMoE quantization method for AWQ weights run through
    ``fused_marlin_moe``."""

    def __init__(self, quant_config: AWQMarlinConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the packed weights, scales and packed zero points for the
        fused w13 projection and the w2 projection on ``layer``."""
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
        })

        def _add_param(name: str, shape, dtype: torch.dtype) -> None:
            # Allocate an uninitialised, non-trainable parameter, attach it
            # to the layer and tag it with the weight-loading attributes.
            param = Parameter(torch.empty(*shape, dtype=dtype),
                              requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pack_factor = self.quant_config.pack_factor

        # Packed quantized weights (int32 words).
        _add_param(
            "w13_qweight",
            (num_experts, hidden_size, 2 * intermediate_size // pack_factor),
            torch.int32)
        _add_param(
            "w2_qweight",
            (num_experts, intermediate_size, hidden_size // pack_factor),
            torch.int32)

        group_size = self.quant_config.group_size
        num_groups_w13 = hidden_size // group_size
        num_groups_w2 = intermediate_size // group_size

        # WEIGHT_SCALES
        # w13 holds the scales for w1 and w3, hence the doubled width.
        _add_param("w13_scales",
                   (num_experts, num_groups_w13, intermediate_size * 2),
                   params_dtype)
        _add_param("w2_scales", (num_experts, num_groups_w2, hidden_size),
                   params_dtype)

        # WEIGHT_ZERO_POINT
        # w13 holds the zero points for w1 and w3, hence the doubled width.
        _add_param("w13_qzeros",
                   (num_experts, num_groups_w13,
                    2 * intermediate_size // pack_factor), torch.int32)
        _add_param("w2_qzeros",
                   (num_experts, num_groups_w2, hidden_size // pack_factor),
                   torch.int32)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded AWQ tensors on ``layer`` in place: repack the
        weights, permute the scales and convert the zero points."""
        cfg = self.quant_config
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device
        prefixes = ("w13", "w2")

        # Empty sort indices for both projections.
        for prefix in prefixes:
            setattr(
                layer,
                f"{prefix}_g_idx_sort_indices",
                Parameter(
                    torch.empty((num_experts, 0),
                                dtype=torch.int32,
                                device=device),
                    requires_grad=False,
                ),
            )

        # Repack the quantized weights.
        for prefix in prefixes:
            name = f"{prefix}_qweight"
            qweight = getattr(layer, name)
            repacked = ops.awq_marlin_moe_repack(
                qweight,
                getattr(layer, f"{prefix}_g_idx_sort_indices"),
                size_k=qweight.shape[1],
                size_n=qweight.shape[2] * cfg.pack_factor,
                num_bits=cfg.weight_bits,
            )
            replace_parameter(layer, name, repacked)

        # NOTE(review): the original author asks why size_k is the
        # intermediate size here (for both w13 and w2) -- confirm.
        for prefix in prefixes:
            name = f"{prefix}_scales"
            scales = getattr(layer, name)
            permuted = marlin_moe_permute_scales(
                s=scales,
                size_k=layer.intermediate_size_per_partition,
                size_n=scales.shape[2],
                group_size=cfg.group_size,
            )
            replace_parameter(layer, name, permuted)

        # Convert the packed AWQ zero points.
        for prefix in prefixes:
            name = f"{prefix}_qzeros"
            qzeros = getattr(layer, name)
            converted = moe_awq_to_marlin_zero_points(
                qzeros,
                size_k=qzeros.shape[1],
                size_n=qzeros.shape[2] * cfg.pack_factor,
                num_bits=cfg.weight_bits)
            replace_parameter(layer, name, converted)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the fused Marlin MoE kernel with
        this layer's weights, scales and zero points."""
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            fused_marlin_moe)

        routing = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)
        topk_weights, topk_ids = routing

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            num_bits=self.quant_config.weight_bits,
        )
...@@ -498,14 +498,14 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -498,14 +498,14 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
x, x,
layer.w13_weight_packed, layer.w13_weight_packed,
layer.w2_weight_packed, layer.w2_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits, router_logits,
layer.w13_g_idx,
layer.w2_g_idx,
layer.w13_g_idx_sort_indices,
layer.w2_g_idx_sort_indices,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale=layer.w13_weight_scale, g_idx1=layer.w13_g_idx,
w2_scale=layer.w2_weight_scale, g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits, num_bits=self.num_bits,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment