Commit 402c8b1e authored by laibao
Browse files

perf(fused-moe): 预打包 Marlin W16A16 MoE 权重,降低 warmup 显存峰值

在 post-load hook 中对 w13/w2 做 per-expert Marlin pack,并将原参数替换为 packed 参数。
Marlin fast path 仅接受 packed 权重;未预打包则 fail fast,避免运行时 packing 带来的显存峰值和不确定性。
更新 Marlin wrapper 的入参与 shape 推导(从 packed layout 计算 K/N)。
parent b949b805
...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop( ...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor, w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor, w2_marlin: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
...@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
): ):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# 当前只支持 bf16 fp16 # 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16] assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype compute_type = hidden_states.dtype
...@@ -243,12 +241,25 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -243,12 +241,25 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
"only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE") "only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2 N = twoN // 2
E2, K_w2, N2_w2 = w2.shape E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
......
...@@ -9,7 +9,7 @@ import math ...@@ -9,7 +9,7 @@ import math
# torch.compile needs typing.List. It will fail torch.library.infer_schema # torch.compile needs typing.List. It will fail torch.library.infer_schema
# otherwise # otherwise
from typing import List # noqa: UP035 from typing import List # noqa: UP035
from typing import Any, Callable, Optional, Union, Dict, Tuple from typing import Any, Callable, Optional, Union, Dict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -59,26 +59,6 @@ logger = init_logger(__name__) ...@@ -59,26 +59,6 @@ logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
...@@ -1968,6 +1948,87 @@ def fused_experts_impl( ...@@ -1968,6 +1948,87 @@ def fused_experts_impl(
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. # Check constraints.
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when
# explicitly requested. When weights are pre-packed in the post-load hook,
# w1/w2 are already in Marlin layout and we can avoid first-run packing
# peaks during KV cache profiling.
if envs.VLLM_USE_MARLIN_W16A16_MOE and not use_nn_moe:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
                             w2: torch.Tensor) -> bool:
    """Report whether w1/w2 have the shapes of Marlin-packed W16A16 weights.

    This is a shape-only check against ``K`` from the enclosing scope:
    ``w1`` must be ``[E, K/16, (2N)*16]`` and ``w2`` must be
    ``[E, N/16, K*16]`` for some integer ``N``.

    Args:
        w1: Candidate packed gate/up projection weights.
        w2: Candidate packed down projection weights.

    Returns:
        True when every shape constraint holds, False otherwise.
    """
    # Rank must be verified first; everything below indexes dimension 2.
    if not (w1.dim() == 3 and w2.dim() == 3):
        return False

    experts_w1, rows_w1, cols_w1 = w1.size(0), w1.size(1), w1.size(2)
    experts_w2, rows_w2, cols_w2 = w2.size(0), w2.size(1), w2.size(2)

    if experts_w1 != experts_w2:
        return False
    # Dimension 1 of w1 holds K/16.
    if rows_w1 * 16 != K:
        return False

    # Dimension 2 of w1 holds (2N)*16: it must divide by 16 and leave an
    # even quotient so that N is an integer.
    two_n, leftover = divmod(cols_w1, 16)
    if leftover != 0 or two_n % 2 != 0:
        return False
    n_hidden = two_n // 2

    # w2 must mirror the same K and N in its own packed layout.
    return cols_w2 == K * 16 and rows_w2 * 16 == n_hidden
is_packed = (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop/lmslim is installed and LMSLIM_USE_LIGHTOP=1."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
)
if use_nn_moe: if use_nn_moe:
E, _, N = w1.size() E, _, N = w1.size()
else: else:
...@@ -1976,61 +2037,18 @@ def fused_experts_impl( ...@@ -1976,61 +2037,18 @@ def fused_experts_impl(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype) cache13 = get_moe_cache(top_k_num,
N,
K if not use_nn_moe else w2.shape[2],
device=hidden_states.device,
dtype=hidden_states.dtype)
else: else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) cache13 = torch.empty(M * top_k_num *
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin max(N, K if not use_nn_moe else w2.shape[2]),
if (envs.VLLM_USE_MARLIN_W16A16_MOE device=hidden_states.device,
and fused_experts_impl_w16a16_marlin is not None): dtype=hidden_states.dtype)
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else:
w2 = w2_cpu
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
)
if use_int8_w8a8 is True: if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states, return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1, w1=w1,
......
...@@ -441,6 +441,86 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -441,6 +441,86 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is enabled, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (envs.VLLM_USE_MARLIN_W16A16_MOE and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight
w2 = layer.w2_weight
if (w1.is_cuda and w2.is_cuda
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)):
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
use_lightop as _use_lightop)
if not _use_lightop:
raise RuntimeError(
"Marlin W16A16 MoE kernel is disabled")
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0):
raise RuntimeError("Unexpected MoE weight shapes")
twoN, K = w1.size(1), w1.size(2)
if w2.size(1) != K:
raise RuntimeError("Unexpected MoE w2 layout")
N = w2.size(2)
if twoN != 2 * N:
raise RuntimeError("Unexpected MoE hidden dims")
if (K % 16 != 0 or K % 32 != 0 or N % 16 != 0
or twoN % 32 != 0):
raise RuntimeError("Marlin packing requires alignment")
from vllm.model_executor.layers.fused_moe.marlin_quant import (
w16a16_marlin_weight)
from torch.nn.parameter import Parameter
def _pack_per_expert(weight: torch.Tensor) -> torch.Tensor:
    """Pack every expert slice of ``weight`` and return them as one tensor.

    Each ``weight[i]`` is passed through ``w16a16_marlin_weight`` and copied
    into a single preallocated output whose leading dimension is the expert
    count. Only one packed temporary exists at a time (presumably to keep
    peak memory low compared with stacking all packed experts at once).

    Args:
        weight: Tensor whose first dimension indexes experts. It must hold
            at least one expert, because ``weight[0]`` is packed first to
            learn the packed per-expert shape.

    Returns:
        A tensor of shape ``(num_experts, *packed_expert_shape)`` with the
        dtype and device of the packed first expert.
    """
    expert_count = weight.shape[0]

    # Pack expert 0 up front: its result fixes the output shape.
    first_packed = w16a16_marlin_weight(weight[0]).contiguous()
    result = first_packed.new_empty((expert_count, *first_packed.shape))
    result[0].copy_(first_packed)
    del first_packed  # release the temporary before packing the rest

    for expert_idx in range(1, expert_count):
        # The unnamed temporary is dropped as soon as the copy completes.
        result[expert_idx].copy_(
            w16a16_marlin_weight(weight[expert_idx]).contiguous())
    return result
with torch.no_grad():
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
return
except Exception:
# If packing dependencies are unavailable, fall back to the
# standard (non-Marlin) layouts.
pass
# Padding the weight for better performance on ROCm # Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment