Commit 57eb1192 authored by laibao's avatar laibao
Browse files

feat(moe):新增 VLLM_USE_MOE_W16A16_TRTION 强制 Triton MoE

增加环境变量开关,禁用 Marlin W16A16 MoE 路径
强制 Triton 且权重已是 Marlin packed 时给出明确报错
Marlin 支持探测改为 best-effort(不再依赖 VLLM_USE_LIGHTOP)
parent 8348926e
...@@ -247,6 +247,7 @@ if TYPE_CHECKING: ...@@ -247,6 +247,7 @@ if TYPE_CHECKING:
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_W8A8_BACKEND: int = 3 VLLM_W8A8_BACKEND: int = 3
VLLM_USE_MOE_W16A16_TRITON: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1704,6 +1705,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1704,6 +1705,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# blaslt: 3 (default) # blaslt: 3 (default)
# rocblas: others # rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")), "VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1984,6 +1984,11 @@ def fused_experts_impl( ...@@ -1984,6 +1984,11 @@ def fused_experts_impl(
or getattr(w2, "marlin_w16a16_packed", False) or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)) or _is_marlin_w16a16_packed(w1, w2))
if is_packed: if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0.")
try: try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin) fused_experts_impl_w16a16_marlin)
......
...@@ -120,15 +120,17 @@ def _is_marlin_w16a16_moe_supported( ...@@ -120,15 +120,17 @@ def _is_marlin_w16a16_moe_supported(
return False return False
try: try:
import lmslim.envs as lsenvs
from lightop import get_moe_cuda_marlin_config_w16a16 from lightop import get_moe_cuda_marlin_config_w16a16
device_name = lsenvs.LMSLIM_GPU_NAME props = torch.cuda.get_device_properties(torch.cuda.current_device())
if not device_name: arch_name = getattr(props, "gcnArchName", None)
if isinstance(arch_name, str) and arch_name:
arch_name = arch_name.split(":")[0]
else:
arch_name = getattr(props, "name", None)
if not isinstance(arch_name, str) or not arch_name:
return False return False
num_cus = torch.cuda.get_device_properties( arch_cu = props.multi_processor_count
torch.cuda.current_device()).multi_processor_count
twoN = 2 * N twoN = 2 * N
for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
_, _, status = get_moe_cuda_marlin_config_w16a16( _, _, status = get_moe_cuda_marlin_config_w16a16(
...@@ -139,8 +141,8 @@ def _is_marlin_w16a16_moe_supported( ...@@ -139,8 +141,8 @@ def _is_marlin_w16a16_moe_supported(
K, K,
N, N,
top_k, top_k,
device_name, arch_name,
num_cus, arch_cu,
dtype, dtype,
) )
if not status: if not status:
...@@ -1304,7 +1306,9 @@ class FusedMoE(CustomOp): ...@@ -1304,7 +1306,9 @@ class FusedMoE(CustomOp):
if quant_config is None: if quant_config is None:
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
self._marlin_w16a16_moe_enabled = ( self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and not self.moe_config.has_bias not envs.VLLM_USE_MOE_W16A16_TRITON
and params_dtype == moe_in_dtype
and not self.moe_config.has_bias
and self.activation == "silu" and self.activation == "silu"
and not self.apply_router_weight_on_input and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported( and _is_marlin_w16a16_moe_supported(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment