"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "e3dc17930426f475a47c45be3ae3d973a070586b"
Commit 8cbcac5d authored by zhuwenwen
Browse files

set VLLM_USE_MARLIN_W16A16_MOE=0 on bw

parent f1bc9890
...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop( ...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor, w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor, w2_marlin: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
...@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
): ):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# 当前只支持 bf16 fp16 # 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16] assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype compute_type = hidden_states.dtype
...@@ -243,11 +241,24 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -243,11 +241,24 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE") "only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2 N = twoN // 2
E2, K_w2, N2_w2 = w2.shape E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
......
...@@ -15,11 +15,11 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -15,11 +15,11 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.utils import SUPPORT_MOE_MARLIN_W16A16 # from vllm.utils import SUPPORT_MOE_MARLIN_W16A16
if SUPPORT_MOE_MARLIN_W16A16: # if SUPPORT_MOE_MARLIN_W16A16:
os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1' # os.environ['VLLM_USE_MARLIN_W16A16_MOE'] = '1'
os.environ['MOE_NN'] = '0' # os.environ['MOE_NN'] = '0'
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig from vllm.attention.selector import AttentionSelectorConfig
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment