Commit 8cbcac5d authored by zhuwenwen
Browse files

set VLLM_USE_MARLIN_W16A16_MOE=0 on bw

parent f1bc9890
......@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor,
topk_weights: torch.Tensor,
......@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# Currently only bf16 and fp16 are supported
assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype
......@@ -243,11 +241,24 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2
E2, K_w2, N2_w2 = w2.shape
E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1:
global_num_experts = E
......
......@@ -15,11 +15,11 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.utils import SUPPORT_MOE_MARLIN_W16A16
# from vllm.utils import SUPPORT_MOE_MARLIN_W16A16
if SUPPORT_MOE_MARLIN_W16A16:
os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1'
os.environ['MOE_NN'] = '0'
# if SUPPORT_MOE_MARLIN_W16A16:
# os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1'
# os.environ['MOE_NN'] = '0'
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment