Commit f35ea024 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-tc_opt' into 'v0.9.2-dev'

feat(moe):新增 VLLM_USE_MOE_W16A16_TRTION 强制 Triton MoE

See merge request dcutoolkit/deeplearing/vllm!396
parents 19d458ec cedfe391
...@@ -216,6 +216,7 @@ if TYPE_CHECKING: ...@@ -216,6 +216,7 @@ if TYPE_CHECKING:
VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0 VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS: int = 0
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1 VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT: int = -1
VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1 VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT: int = -1
VLLM_USE_MOE_W16A16_TRITON: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1383,6 +1384,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1383,6 +1384,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Only capture when num_tokens < N (0 disables). # Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT": "VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")), lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1711,6 +1711,11 @@ def fused_experts_impl( ...@@ -1711,6 +1711,11 @@ def fused_experts_impl(
or getattr(w2, "marlin_w16a16_packed", False) or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)) or _is_marlin_w16a16_packed(w1, w2))
if is_packed: if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0.")
try: try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin) fused_experts_impl_w16a16_marlin)
......
...@@ -101,9 +101,6 @@ def _is_marlin_w16a16_moe_supported( ...@@ -101,9 +101,6 @@ def _is_marlin_w16a16_moe_supported(
return False return False
if E <= 0 or N <= 0 or K <= 0 or top_k <= 0: if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
return False return False
if not envs.VLLM_USE_LIGHTOP:
return False
try: try:
from lightop import get_moe_cuda_marlin_config_w16a16 from lightop import get_moe_cuda_marlin_config_w16a16
...@@ -1051,7 +1048,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1051,7 +1048,9 @@ class FusedMoE(torch.nn.Module):
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
moe_in_dtype = model_dtype moe_in_dtype = model_dtype
self._marlin_w16a16_moe_enabled = ( self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and self.activation == "silu" not envs.VLLM_USE_MOE_W16A16_TRITON
and params_dtype == moe_in_dtype
and self.activation == "silu"
and not self.apply_router_weight_on_input and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported( and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts, E=self.local_num_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment