Commit 0bf89b0c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev_tc_opt' into 'v0.11.0-dev'

feat(moe):新增 VLLM_USE_MOE_W16A16_TRTION 强制 Triton MoE

See merge request dcutoolkit/deeplearing/vllm!397
parents 0946f6c9 57eb1192
......@@ -247,6 +247,7 @@ if TYPE_CHECKING:
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_W8A8_BACKEND: int = 3
VLLM_USE_MOE_W16A16_TRITON: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1704,6 +1705,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1984,6 +1984,11 @@ def fused_experts_impl(
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0.")
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
......
......@@ -120,15 +120,17 @@ def _is_marlin_w16a16_moe_supported(
return False
try:
import lmslim.envs as lsenvs
from lightop import get_moe_cuda_marlin_config_w16a16
device_name = lsenvs.LMSLIM_GPU_NAME
if not device_name:
props = torch.cuda.get_device_properties(torch.cuda.current_device())
arch_name = getattr(props, "gcnArchName", None)
if isinstance(arch_name, str) and arch_name:
arch_name = arch_name.split(":")[0]
else:
arch_name = getattr(props, "name", None)
if not isinstance(arch_name, str) or not arch_name:
return False
num_cus = torch.cuda.get_device_properties(
torch.cuda.current_device()).multi_processor_count
arch_cu = props.multi_processor_count
twoN = 2 * N
for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
_, _, status = get_moe_cuda_marlin_config_w16a16(
......@@ -139,8 +141,8 @@ def _is_marlin_w16a16_moe_supported(
K,
N,
top_k,
device_name,
num_cus,
arch_name,
arch_cu,
dtype,
)
if not status:
......@@ -1304,7 +1306,9 @@ class FusedMoE(CustomOp):
if quant_config is None:
# Not considering quant for now, temporarily
self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and not self.moe_config.has_bias
not envs.VLLM_USE_MOE_W16A16_TRITON
and params_dtype == moe_in_dtype
and not self.moe_config.has_bias
and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment