Commit bc60d70d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.15.1-dev-marlin_w16a16' into 'v0.15.1-dev'

feat(moe): 补齐 v0.15 中 Marlin W16A16 MoE 端到端接入

See merge request dcutoolkit/deeplearing/vllm!431
parents 7e4ee060 1e2fe58f
......@@ -292,6 +292,9 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_W8A8_BACKEND: int = 3
VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
def get_default_cache_root():
......@@ -1842,6 +1845,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_REJECT_SAMPLE_OPT":
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'True').lower() in
("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in
("true", "1")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1706,6 +1706,112 @@ def fused_experts_impl(
) -> torch.Tensor:
# Check constraints.
num_tokens = hidden_states.size(0)
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor, w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
is_packed = (
getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)
)
if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0."
)
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import (
fused_experts_impl_w16a16_marlin,
)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop/lmslim is installed and LMSLIM_USE_LIGHTOP=1."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'."
)
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
if w1_bias is not None or w2_bias is not None:
raise RuntimeError(
"Marlin W16A16 MoE does not support expert biases."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(
top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
else:
cache13 = torch.empty(
M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
)
if use_nn_moe:
E, _, N = w1.size()
else:
......@@ -1714,12 +1820,6 @@ def fused_experts_impl(
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os
from collections.abc import Callable, Iterable
from contextlib import nullcontext
......@@ -72,6 +73,64 @@ from vllm.model_executor.layers.fused_moe.fused_moe import is_power_of_two
logger = init_logger(__name__)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: tuple[int, ...] = (1, 128)
@functools.lru_cache
def _is_marlin_w16a16_moe_supported(
E: int,
N: int,
K: int,
top_k: int,
dtype: torch.dtype,
) -> bool:
"""Return True if lightop reports Marlin W16A16 MoE is supported.
This is a best-effort probe used to decide whether we can safely pre-pack
weights into Marlin layout (which would otherwise prevent fallback).
"""
if not (current_platform.is_cuda_alike() and torch.cuda.is_available()):
return False
if dtype not in (torch.float16, torch.bfloat16):
return False
if K % 32 != 0 or N % 16 != 0:
return False
if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
return False
try:
from lightop import get_moe_cuda_marlin_config_w16a16
props = torch.cuda.get_device_properties(torch.cuda.current_device())
arch_name = getattr(props, "gcnArchName", None)
if isinstance(arch_name, str) and arch_name:
arch_name = arch_name.split(":")[0]
else:
arch_name = getattr(props, "name", None)
if not isinstance(arch_name, str) or not arch_name:
return False
arch_cu = props.multi_processor_count
twoN = 2 * N
for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
_, _, status = get_moe_cuda_marlin_config_w16a16(
E,
bs,
twoN,
K,
K,
N,
top_k,
arch_name,
arch_cu,
dtype,
)
if not status:
return False
return True
except Exception:
return False
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
......@@ -632,9 +691,28 @@ class FusedMoE(CustomOp):
if quant_config is None:
# Not considering quant for now, temporarily
self._marlin_w16a16_moe_enabled = (
not envs.VLLM_USE_MOE_W16A16_TRITON
and params_dtype == moe_in_dtype
and not self.moe_config.has_bias
and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts,
N=self.intermediate_size_per_partition,
K=self.hidden_size,
top_k=self.top_k,
dtype=moe_in_dtype,
)
)
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
# Marlin W16A16 MoE requires the non-NN weight layout.
if self._marlin_w16a16_moe_enabled:
self.use_nn_moe = False
else:
self.use_nn_moe = False
self._marlin_w16a16_moe_enabled = False
moe_quant_params = {
"num_experts": self.local_num_experts,
......@@ -671,6 +749,12 @@ class FusedMoE(CustomOp):
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
# If this layer is configured for Marlin W16A16 path, we intentionally
# keep the monolithic execution route so runtime can dispatch to
# fused_experts_impl_w16a16_marlin when weights are packed.
if getattr(self, "_marlin_w16a16_moe_enabled", False):
return
self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
......
......@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel,
select_unquantized_moe_backend,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
......@@ -230,6 +231,87 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (
getattr(layer, "_marlin_w16a16_moe_enabled", False)
and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)
):
w1 = layer.w13_weight
w2 = layer.w2_weight
if (
w1.is_cuda
and w2.is_cuda
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)
):
try:
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(0):
raise RuntimeError("Unexpected MoE weight shapes")
twoN, K = w1.size(1), w1.size(2)
if w2.size(1) != K:
raise RuntimeError("Unexpected MoE w2 layout")
N = w2.size(2)
if twoN != 2 * N:
raise RuntimeError("Unexpected MoE hidden dims")
if K % 32 != 0 or N % 16 != 0:
raise RuntimeError("Marlin packing requires alignment")
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe.marlin_quant import (
w16a16_marlin_weight,
)
def _pack_per_expert(weight: torch.Tensor) -> torch.Tensor:
num_experts = weight.shape[0]
packed0 = w16a16_marlin_weight(weight[0]).contiguous()
packed = packed0.new_empty((num_experts,) + packed0.shape)
packed[0].copy_(packed0)
del packed0
for i in range(1, num_experts):
tmp = w16a16_marlin_weight(weight[i]).contiguous()
packed[i].copy_(tmp)
del tmp
return packed
with torch.no_grad():
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
return
except Exception:
# If packing dependencies are unavailable, fall back to the
# standard (non-Marlin) layouts.
pass
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
......@@ -315,6 +397,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
getattr(layer, "_marlin_w16a16_moe_enabled", False)
and getattr(layer, "_marlin_w16a16_moe_packed", False)
):
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.get_fused_moe_quant_config(layer),
use_nn_moe=use_nn_moe,
)
assert self.kernel is not None
return self.kernel(
hidden_states=x,
......
......@@ -341,13 +341,13 @@ class Qwen3MoeAttention(nn.Module):
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> None:
......@@ -371,13 +371,13 @@ class Qwen3MoeAttention(nn.Module):
# k_out:torch.Tensor,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> None:
......@@ -485,9 +485,9 @@ class Qwen3MoeAttention(nn.Module):
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
self.q_norm.variance_epsilon,
)
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment