Commit 1e2fe58f authored by laibao's avatar laibao
Browse files

feat(moe): 补齐 v0.15 中 Marlin W16A16 MoE 端到端接入

  参考并移植 011/vllm 的关键提交逻辑
  新增 VLLM_USE_MOE_W16A16_TRITON 开关,并接入基于 lightop 的运行时能力探测与启用结果缓存。
  在权重加载后对 w13 与 w2 执行 W16A16 Marlin 预打包。
  W16A16 Marlin 启用时保留 monolithic 执行路径,并在 fused_experts_impl 中增加 packed 权重 fast-path。
  保持 Marlin 或 lightop 不可用时的回退行为不变。
parent 8cdc3a30
...@@ -292,6 +292,9 @@ if TYPE_CHECKING: ...@@ -292,6 +292,9 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_W8A8_BACKEND: int = 3 VLLM_W8A8_BACKEND: int = 3
VLLM_REJECT_SAMPLE_OPT: bool = False VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1842,6 +1845,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1842,6 +1845,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_REJECT_SAMPLE_OPT": "VLLM_REJECT_SAMPLE_OPT":
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'True').lower() in lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'True').lower() in
("true", "1")), ("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in
("true", "1")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
} }
......
...@@ -1705,6 +1705,112 @@ def fused_experts_impl( ...@@ -1705,6 +1705,112 @@ def fused_experts_impl(
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. # Check constraints.
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor, w2: torch.Tensor) -> bool:
    """Guess from shapes alone whether ``w1``/``w2`` are in the packed layout.

    The checks accept ``w1`` shaped ``(E, K / 16, 2N * 16)`` together with
    ``w2`` shaped ``(E, N / 16, K * 16)``, where ``K`` is taken from the
    enclosing scope (``hidden_states.size(1)``).

    Returns:
        True only if every shape relation above holds, False otherwise.
    """
    # Both tensors must be rank-3 and agree on the leading (expert) dim.
    if not (w1.dim() == 3 and w2.dim() == 3):
        return False
    if w1.size(0) != w2.size(0):
        return False

    # Second dim of w1 must be exactly K / 16.
    if w1.size(1) * 16 != K:
        return False

    # Last dim of w1 must be 16 * (an even number), i.e. 16 * 2N.
    w1_cols = w1.size(2)
    if w1_cols % 16 != 0 or (w1_cols // 16) % 2 != 0:
        return False
    inter_size = w1_cols // 32

    # w2 must be (E, N / 16, K * 16).
    return w2.size(2) == K * 16 and w2.size(1) * 16 == inter_size
is_packed = (
getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)
)
if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
"VLLM_USE_MOE_W16A16_TRITON=1 forces Triton MoE, but the MoE weights are "
"packed in Marlin W16A16 layout. Please load unpacked weights or set "
"VLLM_USE_MOE_W16A16_TRITON=0."
)
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import (
fused_experts_impl_w16a16_marlin,
)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop/lmslim is installed and LMSLIM_USE_LIGHTOP=1."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'."
)
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
if w1_bias is not None or w2_bias is not None:
raise RuntimeError(
"Marlin W16A16 MoE does not support expert biases."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(
top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
else:
cache13 = torch.empty(
M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
)
if use_nn_moe: if use_nn_moe:
E, _, N = w1.size() E, _, N = w1.size()
else: else:
...@@ -1713,12 +1819,6 @@ def fused_experts_impl( ...@@ -1713,12 +1819,6 @@ def fused_experts_impl(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype) cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else: else:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os import os
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from contextlib import nullcontext from contextlib import nullcontext
...@@ -72,6 +73,64 @@ from vllm.model_executor.layers.fused_moe.fused_moe import is_power_of_two ...@@ -72,6 +73,64 @@ from vllm.model_executor.layers.fused_moe.fused_moe import is_power_of_two
logger = init_logger(__name__) logger = init_logger(__name__)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: tuple[int, ...] = (1, 128)
@functools.lru_cache
def _is_marlin_w16a16_moe_supported(
    E: int,
    N: int,
    K: int,
    top_k: int,
    dtype: torch.dtype,
) -> bool:
    """Return True if lightop reports Marlin W16A16 MoE is supported.

    This is a best-effort probe used to decide whether we can safely pre-pack
    weights into Marlin layout (which would otherwise prevent fallback).

    Args:
        E: Expert count forwarded to the lightop config probe.
        N: Intermediate size forwarded to the probe (``2 * N`` is also passed).
        K: Hidden size forwarded to the probe.
        top_k: Number of experts selected per token.
        dtype: Activation/weight dtype; only float16 and bfloat16 are accepted.

    Returns:
        True only if the platform is CUDA-like, the arguments pass the
        alignment and positivity checks, and the lightop probe returns a
        truthy status for every batch size in
        ``_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES``. Any failure yields False.

    NOTE: the result is memoized on the arguments only. The current CUDA
    device is read inside the function but is not part of the cache key.
    """
    if not (current_platform.is_cuda_alike() and torch.cuda.is_available()):
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
        return False
    if K % 32 != 0 or N % 16 != 0:
        return False

    try:
        from lightop import get_moe_cuda_marlin_config_w16a16

        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        # Prefer gcnArchName (stripping any ":feature" suffix); fall back to
        # the plain device name when it is absent or empty.
        arch_name = getattr(props, "gcnArchName", None)
        if isinstance(arch_name, str) and arch_name:
            arch_name = arch_name.split(":")[0]
        else:
            arch_name = getattr(props, "name", None)
        if not isinstance(arch_name, str) or not arch_name:
            return False

        arch_cu = props.multi_processor_count
        twoN = 2 * N
        for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
            _, _, status = get_moe_cuda_marlin_config_w16a16(
                E,
                bs,
                twoN,
                K,
                K,
                N,
                top_k,
                arch_name,
                arch_cu,
                dtype,
            )
            if not status:
                return False
        return True
    except ImportError:
        # lightop not being installed is an expected configuration; stay
        # quiet but leave a trace for debugging.
        logger.debug(
            "Marlin W16A16 MoE probe skipped: lightop is unavailable.",
            exc_info=True,
        )
        return False
    except Exception:
        # Still best-effort (never raise), but do not hide the reason: the
        # False result is cached and would otherwise be undiagnosable.
        logger.warning(
            "Marlin W16A16 MoE probe failed for E=%d, N=%d, K=%d, top_k=%d, "
            "dtype=%s; falling back to the non-Marlin path.",
            E,
            N,
            K,
            top_k,
            dtype,
            exc_info=True,
        )
        return False
class FusedMoeWeightScaleSupported(Enum): class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor" TENSOR = "tensor"
...@@ -632,9 +691,28 @@ class FusedMoE(CustomOp): ...@@ -632,9 +691,28 @@ class FusedMoE(CustomOp):
if quant_config is None: if quant_config is None:
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
self._marlin_w16a16_moe_enabled = (
not envs.VLLM_USE_MOE_W16A16_TRITON
and params_dtype == moe_in_dtype
and not self.moe_config.has_bias
and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts,
N=self.intermediate_size_per_partition,
K=self.hidden_size,
top_k=self.top_k,
dtype=moe_in_dtype,
)
)
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1 self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
# Marlin W16A16 MoE requires the non-NN weight layout.
if self._marlin_w16a16_moe_enabled:
self.use_nn_moe = False
else: else:
self.use_nn_moe = False self.use_nn_moe = False
self._marlin_w16a16_moe_enabled = False
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
...@@ -671,6 +749,12 @@ class FusedMoE(CustomOp): ...@@ -671,6 +749,12 @@ class FusedMoE(CustomOp):
# should be safe to swap out the quant_method. # should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None: def maybe_init_modular_kernel(self) -> None:
# If this layer is configured for Marlin W16A16 path, we intentionally
# keep the monolithic execution route so runtime can dispatch to
# fused_experts_impl_w16a16_marlin when weights are packed.
if getattr(self, "_marlin_w16a16_moe_enabled", False):
return
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with # routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend. # DeepEP all2all backend.
......
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel, make_unquantized_moe_kernel,
select_unquantized_moe_backend, select_unquantized_moe_backend,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
...@@ -230,6 +231,87 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -230,6 +231,87 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (
getattr(layer, "_marlin_w16a16_moe_enabled", False)
and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)
):
w1 = layer.w13_weight
w2 = layer.w2_weight
if (
w1.is_cuda
and w2.is_cuda
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)
):
try:
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(0):
raise RuntimeError("Unexpected MoE weight shapes")
twoN, K = w1.size(1), w1.size(2)
if w2.size(1) != K:
raise RuntimeError("Unexpected MoE w2 layout")
N = w2.size(2)
if twoN != 2 * N:
raise RuntimeError("Unexpected MoE hidden dims")
if K % 32 != 0 or N % 16 != 0:
raise RuntimeError("Marlin packing requires alignment")
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe.marlin_quant import (
w16a16_marlin_weight,
)
def _pack_per_expert(weight: torch.Tensor) -> torch.Tensor:
    """Pack each expert slice of ``weight`` with ``w16a16_marlin_weight``.

    Experts are converted one at a time and copied into a single
    preallocated output, so only one per-expert temporary exists at once.

    Returns:
        A tensor of shape ``(num_experts, *packed_expert_shape)``.
    """
    expert_count = weight.shape[0]

    # Pack expert 0 first: its result defines the output shape, dtype and
    # device for the preallocated buffer.
    first = w16a16_marlin_weight(weight[0]).contiguous()
    out = first.new_empty((expert_count, *first.shape))
    out[0].copy_(first)
    del first

    for idx in range(1, expert_count):
        # The temporary is dropped as soon as the copy completes.
        out[idx].copy_(w16a16_marlin_weight(weight[idx]).contiguous())
    return out
with torch.no_grad():
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
return
except Exception:
# If packing dependencies are unavailable, fall back to the
# standard (non-Marlin) layouts.
pass
# Padding the weight for better performance on ROCm # Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
...@@ -315,6 +397,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -315,6 +397,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
getattr(layer, "_marlin_w16a16_moe_enabled", False)
and getattr(layer, "_marlin_w16a16_moe_packed", False)
):
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.get_fused_moe_quant_config(layer),
use_nn_moe=use_nn_moe,
)
assert self.kernel is not None assert self.kernel is not None
return self.kernel( return self.kernel(
hidden_states=x, hidden_states=x,
......
...@@ -341,13 +341,13 @@ class Qwen3MoeAttention(nn.Module): ...@@ -341,13 +341,13 @@ class Qwen3MoeAttention(nn.Module):
def rms_rotary_embedding_fuse( def rms_rotary_embedding_fuse(
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor | None,
head_size: int, head_size: int,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
is_neox_style: bool, is_neox_style: bool,
q_weight: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, k_weight: torch.Tensor,
epsilon: float, epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None, q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None, k_bias: torch.Tensor | None = None,
) -> None: ) -> None:
...@@ -371,13 +371,13 @@ class Qwen3MoeAttention(nn.Module): ...@@ -371,13 +371,13 @@ class Qwen3MoeAttention(nn.Module):
# k_out:torch.Tensor, # k_out:torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor | None,
head_size: int, head_size: int,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
is_neox_style: bool, is_neox_style: bool,
q_weight: torch.Tensor, q_weight: torch.Tensor,
k_weight: torch.Tensor, k_weight: torch.Tensor,
epsilon: float, epsilon: float,
key: torch.Tensor | None = None,
q_bias: torch.Tensor | None = None, q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None, k_bias: torch.Tensor | None = None,
) -> None: ) -> None:
...@@ -485,9 +485,9 @@ class Qwen3MoeAttention(nn.Module): ...@@ -485,9 +485,9 @@ class Qwen3MoeAttention(nn.Module):
self.rotary_emb.is_neox_style, self.rotary_emb.is_neox_style,
self.q_norm.weight, self.q_norm.weight,
self.k_norm.weight, self.k_norm.weight,
self.q_norm.variance_epsilon,
None, None,
None, None,
self.q_norm.variance_epsilon,
) )
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr( elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None: self.rotary_emb, "mrope_section", None) is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment