Unverified Commit a091e2da authored by ElizaWszola's avatar ElizaWszola Committed by GitHub
Browse files

[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)


Co-authored-by: default avatarDipika <dipikasikka1@gmail.com>
parent fc990f97
This diff is collapsed.
...@@ -2,11 +2,14 @@ ...@@ -2,11 +2,14 @@
#include <torch/all.h> #include <torch/all.h>
#include "core/scalar_type.hpp"
torch::Tensor marlin_gemm_moe( torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm, const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights); bool replicate_input, bool apply_weights);
...@@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " "g_idx, Tensor! perm, Tensor! workspace, "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"bool replicate_input, bool apply_weights) -> Tensor"); "int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif #endif
} }
......
...@@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref): ...@@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe( def test_fused_marlin_moe(
m: int, m: int,
n: int, n: int,
...@@ -148,6 +149,7 @@ def test_fused_marlin_moe( ...@@ -148,6 +149,7 @@ def test_fused_marlin_moe(
topk: int, topk: int,
group_size: int, group_size: int,
act_order: bool, act_order: bool,
num_bits: int,
): ):
torch.manual_seed(7) torch.manual_seed(7)
...@@ -161,13 +163,12 @@ def test_fused_marlin_moe( ...@@ -161,13 +163,12 @@ def test_fused_marlin_moe(
if group_size in (k, n): if group_size in (k, n):
return return
quant_type = scalar_types.uint4b8 quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16 dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)
w_ref1_l = [] w_ref1_l = []
qweight1_l = [] qweight1_l = []
...@@ -240,6 +241,7 @@ def test_fused_marlin_moe( ...@@ -240,6 +241,7 @@ def test_fused_marlin_moe(
topk_ids, topk_ids,
w1_scale=scales1, w1_scale=scales1,
w2_scale=scales2, w2_scale=scales2,
num_bits=num_bits,
) )
assert compute_max_diff(marlin_output, triton_output) < 4e-2 assert compute_max_diff(marlin_output, triton_output) < 4e-2
...@@ -254,7 +256,8 @@ def test_fused_marlin_moe( ...@@ -254,7 +256,8 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm( @pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply(
m: int, m: int,
n: int, n: int,
k: int, k: int,
...@@ -262,6 +265,7 @@ def test_marlin_moe_mmm( ...@@ -262,6 +265,7 @@ def test_marlin_moe_mmm(
topk: int, topk: int,
group_size: int, group_size: int,
act_order: bool, act_order: bool,
num_bits: int,
): ):
if topk > e: if topk > e:
return return
...@@ -273,7 +277,8 @@ def test_marlin_moe_mmm( ...@@ -273,7 +277,8 @@ def test_marlin_moe_mmm(
if group_size == k: if group_size == k:
return return
quant_type = scalar_types.uint4b8 quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16 dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
...@@ -308,7 +313,8 @@ def test_marlin_moe_mmm( ...@@ -308,7 +313,8 @@ def test_marlin_moe_mmm(
g_idx, g_idx,
sort_indices, sort_indices,
topk, topk,
renormalize=False) renormalize=False,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2 assert compute_max_diff(marlin_output, torch_output) < 1e-2
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
\ No newline at end of file gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
File mode changed from 100644 to 100755
...@@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, ...@@ -559,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
num_bits: int) -> torch.Tensor: num_bits: int) -> torch.Tensor:
num_experts = b_q_weight.shape[0] num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0 assert size_k % 16 == 0
output = torch.empty((num_experts, size_k // 16, size_n * 2), output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device, device=b_q_weight.device,
dtype=b_q_weight.dtype) dtype=b_q_weight.dtype)
for e in range(num_experts): for e in range(num_experts):
......
...@@ -7,18 +7,21 @@ import torch ...@@ -7,18 +7,21 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config) fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
def single_marlin_moe( def single_marlin_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w: torch.Tensor, w: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
g_idx: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, perm: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
) -> torch.Tensor:
""" """
This function computes the multiplication of hidden_states with expert This function computes the multiplication of hidden_states with expert
weights used in Marlin MoE, using weights w and top-k gating mechanism. weights used in Marlin MoE, using weights w and top-k gating mechanism.
...@@ -36,6 +39,7 @@ def single_marlin_moe( ...@@ -36,6 +39,7 @@ def single_marlin_moe(
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- override_config (Optional[Dict[str, Any]]): Optional override - override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration. for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
...@@ -48,10 +52,11 @@ def single_marlin_moe( ...@@ -48,10 +52,11 @@ def single_marlin_moe(
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w.is_contiguous(), "Expert weights must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous"
assert hidden_states.dtype == torch.float16 assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
M, K = hidden_states.shape M, K = hidden_states.shape
E = w.shape[0] E = w.shape[0]
N = w.shape[2] // 2 N = w.shape[2] // (num_bits // 2)
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize) renormalize)
...@@ -76,10 +81,13 @@ def single_marlin_moe( ...@@ -76,10 +81,13 @@ def single_marlin_moe(
device="cuda", device="cuda",
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
False) block_size_m, True, False)
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
...@@ -98,6 +106,7 @@ def fused_marlin_moe( ...@@ -98,6 +106,7 @@ def fused_marlin_moe(
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
num_bits: int = 8,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -122,6 +131,7 @@ def fused_marlin_moe( ...@@ -122,6 +131,7 @@ def fused_marlin_moe(
w1. w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2. w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
...@@ -131,13 +141,14 @@ def fused_marlin_moe( ...@@ -131,13 +141,14 @@ def fused_marlin_moe(
0], "Number of tokens mismatch" 0], "Number of tokens mismatch"
assert hidden_states.shape[ assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1" 1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[ assert hidden_states.shape[1] == w2.shape[2] // (
1] == w2.shape[2] // 2, "Hidden size mismatch w2" num_bits // 2), "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype == torch.float16 assert hidden_states.dtype == torch.float16
assert num_bits in [4, 8]
M, K = hidden_states.shape M, K = hidden_states.shape
E = w1.shape[0] E = w1.shape[0]
...@@ -165,6 +176,9 @@ def fused_marlin_moe( ...@@ -165,6 +176,9 @@ def fused_marlin_moe(
device="cuda", device="cuda",
requires_grad=False) requires_grad=False)
scalar_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
...@@ -181,6 +195,7 @@ def fused_marlin_moe( ...@@ -181,6 +195,7 @@ def fused_marlin_moe(
g_idx1, g_idx1,
perm1, perm1,
workspace, workspace,
scalar_type,
M, M,
2 * N, 2 * N,
K, K,
...@@ -204,6 +219,7 @@ def fused_marlin_moe( ...@@ -204,6 +219,7 @@ def fused_marlin_moe(
g_idx2, g_idx2,
perm2, perm2,
workspace, workspace,
scalar_type,
M, M,
K, K,
N, N,
......
...@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -445,7 +445,7 @@ def grouped_topk(hidden_states: torch.Tensor,
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def get_config_dtype_str(dtype: torch.dtype, def get_config_dtype_str(dtype: torch.dtype,
......
...@@ -6,6 +6,8 @@ import torch ...@@ -6,6 +6,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat) CompressionFormat)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -38,10 +40,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -38,10 +40,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if not (self.quant_config.quant_format if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value == CompressionFormat.pack_quantized.value
and self.num_bits == 4): and self.num_bits in WNA16_SUPPORTED_BITS):
raise ValueError("For Fused MoE layers, only ", raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ", f"{CompressionFormat.pack_quantized.value} ",
"is supported for 4 bits") "is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int, hidden_size: int, intermediate_size: int,
...@@ -292,4 +295,5 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): ...@@ -292,4 +295,5 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
topk_ids, topk_ids,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
num_bits=self.num_bits,
) )
...@@ -611,4 +611,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -611,4 +611,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids, topk_ids,
w1_scale=layer.w13_scales, w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales, w2_scale=layer.w2_scales,
num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype) ).to(orig_dtype)
...@@ -23,13 +23,7 @@ def get_model_architecture( ...@@ -23,13 +23,7 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
mixtral_supported = ["fp8", "compressed-tensors"] mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]
# for gptq_marlin, only run fused MoE for int4
if model_config.quantization == "gptq_marlin":
hf_quant_config = getattr(model_config.hf_config,
"quantization_config", None)
if hf_quant_config and hf_quant_config.get("bits") == 4:
mixtral_supported.append("gptq_marlin")
if (model_config.quantization is not None if (model_config.quantization is not None
and model_config.quantization not in mixtral_supported and model_config.quantization not in mixtral_supported
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment