Commit fe2e2705 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev_rms_rope' into 'v0.9.2-dev'

feat(moe/marlin): Marlin W16A16 MoE 自动探测并预打包(去掉手动开关)

See merge request dcutoolkit/deeplearing/vllm!382
parents bb3afd68 de588fab
...@@ -200,7 +200,6 @@ if TYPE_CHECKING: ...@@ -200,7 +200,6 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: bool = False VLLM_USE_FUSED_CACHE_QUANT_BMM_MLA: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
...@@ -1311,10 +1310,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1311,10 +1310,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat # vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT": "VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
......
...@@ -1681,93 +1681,88 @@ def fused_experts_impl( ...@@ -1681,93 +1681,88 @@ def fused_experts_impl(
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when # Optional fast path: use Marlin W16A16 fused MoE implementation when the
# explicitly requested. When weights are pre-packed in the post-load hook, # expert weights are already packed in Marlin layout.
# w1/w2 are already in Marlin layout and we can avoid first-run packing if not use_nn_moe:
# peaks during KV cache profiling. K = hidden_states.size(1)
if envs.VLLM_USE_MARLIN_W16A16_MOE and not use_nn_moe:
try: def _is_marlin_w16a16_packed(w1: torch.Tensor,
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501 w2: torch.Tensor) -> bool:
fused_experts_impl_w16a16_marlin) if w1.dim() != 3 or w2.dim() != 3:
except Exception: return False
fused_experts_impl_w16a16_marlin = None # type: ignore if w1.size(0) != w2.size(0):
return False
if fused_experts_impl_w16a16_marlin is not None: k_div16 = w1.size(1)
K = hidden_states.size(1) if k_div16 * 16 != K:
return False
def _is_marlin_w16a16_packed(w1: torch.Tensor, if w1.size(2) % 16 != 0:
w2: torch.Tensor) -> bool: return False
if w1.dim() != 3 or w2.dim() != 3: twoN = w1.size(2) // 16
return False if twoN % 2 != 0:
if w1.size(0) != w2.size(0): return False
return False N = twoN // 2
k_div16 = w1.size(1) if w2.size(2) != K * 16:
if k_div16 * 16 != K: return False
return False if w2.size(1) * 16 != N:
if w1.size(2) % 16 != 0: return False
return False return True
twoN = w1.size(2) // 16
if twoN % 2 != 0: is_packed = (getattr(w1, "marlin_w16a16_packed", False)
return False or getattr(w2, "marlin_w16a16_packed", False)
N = twoN // 2 or _is_marlin_w16a16_packed(w1, w2))
if w2.size(2) != K * 16: if is_packed:
return False try:
if w2.size(1) * 16 != N: from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
return False fused_experts_impl_w16a16_marlin)
return True except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False) if fused_experts_impl_w16a16_marlin is None:
or _is_marlin_w16a16_packed(w1, w2)): raise RuntimeError(
E = w1.size(0) "Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
if global_num_experts == -1: "Ensure lightop is installed and VLLM_USE_LIGHTOP=1."
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
) )
# No fallback packing: require pre-packed weights when Marlin W16A16 if activation != "silu":
# MoE is enabled. If weights are still in the original layout, fail raise RuntimeError(
# fast to avoid packing-induced peak memory and unpredictable "Marlin W16A16 MoE only supports activation='silu'.")
# warmup/profiling behavior. if apply_router_weight_on_input:
if (w1.dim() == 3 and w2.dim() == 3 and w1.size(0) == w2.size(0) raise RuntimeError(
and w2.size(1) == K): "Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
twoN = w1.size(1) )
N = w2.size(2)
if (twoN == 2 * N and (K % 32 == 0) and (N % 16 == 0) E = w1.size(0)
and (twoN % 32 == 0)): if global_num_experts == -1:
raise RuntimeError( global_num_experts = E
"VLLM_USE_MARLIN_W16A16_MOE is enabled, but MoE weights "
"are not pre-packed in Marlin layout. Pre-pack weights " twoN = w1.size(2) // 16
"during the post-load hook or disable " if envs.VLLM_USE_GLOBAL_CACHE13:
"VLLM_USE_MARLIN_W16A16_MOE." cache13 = get_moe_cache(top_k_num,
) twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
# Non-Marlin paths need the original weight shapes. # Non-Marlin paths need the original weight shapes.
if use_nn_moe: if use_nn_moe:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os import os
import importlib import importlib
...@@ -75,6 +76,66 @@ else: ...@@ -75,6 +76,66 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: tuple[int, ...] = (1, 128)
@functools.lru_cache
def _is_marlin_w16a16_moe_supported(
E: int,
N: int,
K: int,
top_k: int,
dtype: torch.dtype,
) -> bool:
"""Return True if lightop reports Marlin W16A16 MoE is supported.
This is a best-effort probe used to decide whether we can safely pre-pack
weights into Marlin layout (which would otherwise prevent fallback).
"""
if not (current_platform.is_cuda_alike() and torch.cuda.is_available()):
return False
if dtype not in (torch.float16, torch.bfloat16):
return False
if K % 32 != 0 or N % 16 != 0:
return False
if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
return False
if not envs.VLLM_USE_LIGHTOP:
return False
try:
from lightop import get_moe_cuda_marlin_config_w16a16
props = torch.cuda.get_device_properties(torch.cuda.current_device())
arch_name = getattr(props, "gcnArchName", None)
if isinstance(arch_name, str) and arch_name:
arch_name = arch_name.split(":")[0]
else:
arch_name = getattr(props, "name", None)
if not isinstance(arch_name, str) or not arch_name:
return False
arch_cu = props.multi_processor_count
twoN = 2 * N
for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
_, _, status = get_moe_cuda_marlin_config_w16a16(
E,
bs,
twoN,
K,
K,
N,
top_k,
arch_name,
arch_cu,
dtype,
)
if not status:
return False
return True
except Exception:
return False
# Global auxilary stream for running operations in background streams. # Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams # We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane). # for every layer (and make profiling look sane).
...@@ -83,6 +144,7 @@ logger = init_logger(__name__) ...@@ -83,6 +144,7 @@ logger = init_logger(__name__)
# - MoE shared_expert overlap with router # - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None _aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None: def aux_stream() -> torch.cuda.Stream | None:
""" """
Ensures aux_stream is initialized only once Ensures aux_stream is initialized only once
...@@ -406,12 +468,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -406,12 +468,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is enabled, pre-pack weights once during the # If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout. # post-load hook and replace parameters with the packed layout.
# #
# This avoids first-run packing peaks during KV cache profiling and # This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state. # keeps only one copy of weights resident on GPU in steady state.
if (envs.VLLM_USE_MARLIN_W16A16_MOE and current_platform.is_cuda_alike() if (getattr(layer, "_marlin_w16a16_moe_enabled", False)
and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False) and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)): and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight w1 = layer.w13_weight
...@@ -420,12 +483,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -420,12 +483,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
and w1.dtype in (torch.float16, torch.bfloat16) and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)): and w2.dtype in (torch.float16, torch.bfloat16)):
try: try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
use_lightop as _use_lightop)
if not _use_lightop:
raise RuntimeError(
"Marlin W16A16 MoE kernel is disabled")
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size( if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0): 0):
raise RuntimeError("Unexpected MoE weight shapes") raise RuntimeError("Unexpected MoE weight shapes")
...@@ -989,9 +1046,25 @@ class FusedMoE(torch.nn.Module): ...@@ -989,9 +1046,25 @@ class FusedMoE(torch.nn.Module):
if quant_config is None: if quant_config is None:
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1 moe_in_dtype = model_dtype
self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts,
N=self.intermediate_size_per_partition,
K=self.hidden_size,
top_k=self.top_k,
dtype=moe_in_dtype,
))
self.use_nn_moe = int(os.environ.get("MOE_NN", 1)) == 1
# Marlin W16A16 MoE requires the non-NN weight layout.
if self._marlin_w16a16_moe_enabled:
self.use_nn_moe = False
else: else:
self.use_nn_moe = False self.use_nn_moe = False
self._marlin_w16a16_moe_enabled = False
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment