Commit fe2e2705 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev_rms_rope' into 'v0.9.2-dev'

feat(moe/marlin): Marlin W16A16 MoE 自动探测并预打包(去掉手动开关)

See merge request dcutoolkit/deeplearing/vllm!382
parents bb3afd68 de588fab
......@@ -200,7 +200,6 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
......@@ -1311,10 +1310,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
......
......@@ -1681,93 +1681,88 @@ def fused_experts_impl(
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when
# explicitly requested. When weights are pre-packed in the post-load hook,
# w1/w2 are already in Marlin layout and we can avoid first-run packing
# peaks during KV cache profiling.
if envs.VLLM_USE_MARLIN_W16A16_MOE and not use_nn_moe:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is not None:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
if (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)):
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
is_packed = (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop is installed and VLLM_USE_LIGHTOP=1."
)
# No fallback packing: require pre-packed weights when Marlin W16A16
# MoE is enabled. If weights are still in the original layout, fail
# fast to avoid packing-induced peak memory and unpredictable
# warmup/profiling behavior.
if (w1.dim() == 3 and w2.dim() == 3 and w1.size(0) == w2.size(0)
and w2.size(1) == K):
twoN = w1.size(1)
N = w2.size(2)
if (twoN == 2 * N and (K % 32 == 0) and (N % 16 == 0)
and (twoN % 32 == 0)):
raise RuntimeError(
"VLLM_USE_MARLIN_W16A16_MOE is enabled, but MoE weights "
"are not pre-packed in Marlin layout. Pre-pack weights "
"during the post-load hook or disable "
"VLLM_USE_MARLIN_W16A16_MOE."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'.")
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
# Non-Marlin paths need the original weight shapes.
if use_nn_moe:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os
import importlib
......@@ -75,6 +76,66 @@ else:
logger = init_logger(__name__)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: tuple[int, ...] = (1, 128)
@functools.lru_cache
def _is_marlin_w16a16_moe_supported(
E: int,
N: int,
K: int,
top_k: int,
dtype: torch.dtype,
) -> bool:
"""Return True if lightop reports Marlin W16A16 MoE is supported.
This is a best-effort probe used to decide whether we can safely pre-pack
weights into Marlin layout (which would otherwise prevent fallback).
"""
if not (current_platform.is_cuda_alike() and torch.cuda.is_available()):
return False
if dtype not in (torch.float16, torch.bfloat16):
return False
if K % 32 != 0 or N % 16 != 0:
return False
if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
return False
if not envs.VLLM_USE_LIGHTOP:
return False
try:
from lightop import get_moe_cuda_marlin_config_w16a16
props = torch.cuda.get_device_properties(torch.cuda.current_device())
arch_name = getattr(props, "gcnArchName", None)
if isinstance(arch_name, str) and arch_name:
arch_name = arch_name.split(":")[0]
else:
arch_name = getattr(props, "name", None)
if not isinstance(arch_name, str) or not arch_name:
return False
arch_cu = props.multi_processor_count
twoN = 2 * N
for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
_, _, status = get_moe_cuda_marlin_config_w16a16(
E,
bs,
twoN,
K,
K,
N,
top_k,
arch_name,
arch_cu,
dtype,
)
if not status:
return False
return True
except Exception:
return False
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
......@@ -83,6 +144,7 @@ logger = init_logger(__name__)
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
......@@ -406,12 +468,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is enabled, pre-pack weights once during the
# If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (envs.VLLM_USE_MARLIN_W16A16_MOE and current_platform.is_cuda_alike()
if (getattr(layer, "_marlin_w16a16_moe_enabled", False)
and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight
......@@ -420,12 +483,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)):
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
use_lightop as _use_lightop)
if not _use_lightop:
raise RuntimeError(
"Marlin W16A16 MoE kernel is disabled")
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0):
raise RuntimeError("Unexpected MoE weight shapes")
......@@ -989,9 +1046,25 @@ class FusedMoE(torch.nn.Module):
if quant_config is None:
# Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
moe_in_dtype = model_dtype
self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts,
N=self.intermediate_size_per_partition,
K=self.hidden_size,
top_k=self.top_k,
dtype=moe_in_dtype,
))
self.use_nn_moe = int(os.environ.get("MOE_NN", 1)) == 1
# Marlin W16A16 MoE requires the non-NN weight layout.
if self._marlin_w16a16_moe_enabled:
self.use_nn_moe = False
else:
self.use_nn_moe = False
self._marlin_w16a16_moe_enabled = False
moe_quant_params = {
"num_experts": self.local_num_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment