Commit 3f5983bf authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-tc_opt' into 'v0.9.2-dev'

fix: 修复 expanded sampling metadata 对 numpy/array-like 输入不兼容导致崩溃   perf(fused-moe): 预打包 Marlin W16A16 MoE 权重,降低 warmup 显存峰值

See merge request dcutoolkit/deeplearing/vllm!357
parents 62f14ebf bfaac804
...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop( ...@@ -214,8 +214,6 @@ def moe_align_block_size_lightop(
def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_marlin: torch.Tensor, w1_marlin: torch.Tensor,
w2_marlin: torch.Tensor, w2_marlin: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
...@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -234,8 +232,8 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
): ):
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1_marlin.is_contiguous(), "Packed weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2_marlin.is_contiguous(), "Packed weights2 must be contiguous"
# 当前只支持 bf16 fp16 # 当前只支持 bf16 fp16
assert hidden_states.dtype in [torch.bfloat16,torch.float16] assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype compute_type = hidden_states.dtype
...@@ -243,12 +241,25 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor, ...@@ -243,12 +241,25 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
"only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE") "only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
# Packed weights store the same number of elements as the original layout,
# but reshaped/reordered for Marlin kernels:
# - w1_marlin: [E, K/16, (2N)*16]
# - w2_marlin: [E, N/16, K*16]
E, k_div16, twoN_times16 = w1_marlin.shape
K_w1 = k_div16 * 16
assert K_w1 == K, f"w1_marlin K mismatch: {K_w1} vs {K}"
assert twoN_times16 % 16 == 0
twoN = twoN_times16 // 16
assert twoN % 2 == 0
N = twoN // 2 N = twoN // 2
E2, K_w2, N2_w2 = w2.shape E2, n_div16, k_times16 = w2_marlin.shape
assert E2 == E, f"w2_marlin E mismatch: {E2} vs {E}"
K_w2 = k_times16 // 16
assert K_w2 == K, f"w2_marlin K mismatch: {K_w2} vs {K}"
assert n_div16 * 16 == N, f"w2_marlin N mismatch: {n_div16 * 16} vs {N}"
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
......
...@@ -5,7 +5,7 @@ import functools ...@@ -5,7 +5,7 @@ import functools
import json import json
import os import os
import math import math
from typing import Any, Callable, Dict, Optional, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -47,26 +47,6 @@ from vllm.utils import direct_register_custom_op ...@@ -47,26 +47,6 @@ from vllm.utils import direct_register_custom_op
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
...@@ -1694,6 +1674,102 @@ def fused_experts_impl( ...@@ -1694,6 +1674,102 @@ def fused_experts_impl(
i_s: Optional[torch.Tensor] = None, **_ i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when
# explicitly requested. When weights are pre-packed in the post-load hook,
# w1/w2 are already in Marlin layout and we can avoid first-run packing
# peaks during KV cache profiling.
if envs.VLLM_USE_MARLIN_W16A16_MOE and not use_nn_moe:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is not None:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
if (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)):
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
# No fallback packing: require pre-packed weights when Marlin W16A16
# MoE is enabled. If weights are still in the original layout, fail
# fast to avoid packing-induced peak memory and unpredictable
# warmup/profiling behavior.
if (w1.dim() == 3 and w2.dim() == 3 and w1.size(0) == w2.size(0)
and w2.size(1) == K):
twoN = w1.size(1)
N = w2.size(2)
if (twoN == 2 * N and (K % 32 == 0) and (N % 16 == 0)
and (twoN % 32 == 0)):
raise RuntimeError(
"VLLM_USE_MARLIN_W16A16_MOE is enabled, but MoE weights "
"are not pre-packed in Marlin layout. Pre-pack weights "
"during the post-load hook or disable "
"VLLM_USE_MARLIN_W16A16_MOE."
)
# Non-Marlin paths need the original weight shapes.
if use_nn_moe: if use_nn_moe:
E, _, N = w1.size() E, _, N = w1.size()
else: else:
...@@ -1702,67 +1778,18 @@ def fused_experts_impl( ...@@ -1702,67 +1778,18 @@ def fused_experts_impl(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype) cache13 = get_moe_cache(top_k_num,
N,
K if not use_nn_moe else w2.shape[2],
device=hidden_states.device,
dtype=hidden_states.dtype)
else: else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) cache13 = torch.empty(
M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
# Optional fast path: use lmslim's Marlin W16A16 fused MoE implementation device=hidden_states.device,
# when explicitly requested. This reuses the same cache13 buffer as other dtype=hidden_states.dtype)
# fused paths for consistency.
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else:
w2 = w2_cpu
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output
)
if use_int8_w8a8 is True: if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states, return fused_experts_impl_int8(hidden_states=hidden_states,
......
...@@ -406,6 +406,86 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -406,6 +406,86 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is enabled, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (envs.VLLM_USE_MARLIN_W16A16_MOE and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight
w2 = layer.w2_weight
if (w1.is_cuda and w2.is_cuda
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)):
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
use_lightop as _use_lightop)
if not _use_lightop:
raise RuntimeError(
"Marlin W16A16 MoE kernel is disabled")
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0):
raise RuntimeError("Unexpected MoE weight shapes")
twoN, K = w1.size(1), w1.size(2)
if w2.size(1) != K:
raise RuntimeError("Unexpected MoE w2 layout")
N = w2.size(2)
if twoN != 2 * N:
raise RuntimeError("Unexpected MoE hidden dims")
if (K % 16 != 0 or K % 32 != 0 or N % 16 != 0
or twoN % 32 != 0):
raise RuntimeError("Marlin packing requires alignment")
from vllm.model_executor.layers.fused_moe.marlin_quant import (
w16a16_marlin_weight)
from torch.nn.parameter import Parameter
def _pack_per_expert(weight: torch.Tensor) -> torch.Tensor:
num_experts = weight.shape[0]
packed0 = w16a16_marlin_weight(
weight[0]).contiguous()
packed = packed0.new_empty((num_experts, ) +
packed0.shape)
packed[0].copy_(packed0)
del packed0
for i in range(1, num_experts):
tmp = w16a16_marlin_weight(
weight[i]).contiguous()
packed[i].copy_(tmp)
del tmp
return packed
with torch.no_grad():
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
return
except Exception:
# If packing dependencies are unavailable, fall back to the
# standard (non-Marlin) layouts.
pass
# Padding the weight for better performance on ROCm # Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
...@@ -1905,4 +1985,4 @@ direct_register_custom_op( ...@@ -1905,4 +1985,4 @@ direct_register_custom_op(
mutates_args=["hidden_states", "hidden_states_copy"], mutates_args=["hidden_states", "hidden_states_copy"],
fake_impl=moe_forward_shared_fake, fake_impl=moe_forward_shared_fake,
tags=(torch.Tag.needs_fixed_stride_order,), tags=(torch.Tag.needs_fixed_stride_order,),
) )
\ No newline at end of file
...@@ -735,22 +735,35 @@ class InputBatch: ...@@ -735,22 +735,35 @@ class InputBatch:
self, repeat_counts: torch.Tensor self, repeat_counts: torch.Tensor
) -> SamplingMetadata: ) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
repeat_counts_cpu = repeat_counts # `repeat_counts` is expected to be a CPU torch tensor, but some
# call sites may pass a NumPy array (or other array-likes). Normalize
# to a CPU tensor to keep downstream ops (e.g. repeat_interleave)
# consistent and avoid hard crashes.
if isinstance(repeat_counts, torch.Tensor):
repeat_counts_cpu = repeat_counts.to(device="cpu")
else:
repeat_counts_cpu = torch.as_tensor(repeat_counts, device="cpu")
all_greedy = self.all_greedy all_greedy = self.all_greedy
all_random = self.all_random all_random = self.all_random
# For reject-sampling optimization, force greedy sampling to keep # For reject-sampling optimization, force greedy sampling to keep
# rejection sampler assumptions (per-request shapes) intact. # rejection sampler assumptions (per-request shapes) intact.
def _expand_cpu_to_gpu( def _expand_cpu_to_gpu(
t: Optional[torch.Tensor], t: Optional[object],
*, *,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if t is None: if t is None:
return None return None
base = t[:num_reqs] # `t` should be a CPU torch tensor, but can be a NumPy array view
if repeat_counts_cpu is not None: # (e.g. created via `tensor.numpy()`). Convert if needed.
base = base.repeat_interleave(repeat_counts_cpu, dim=0) if isinstance(t, torch.Tensor):
base = t[:num_reqs]
elif isinstance(t, np.ndarray):
base = torch.from_numpy(t[:num_reqs])
else:
base = torch.as_tensor(t, device="cpu")[:num_reqs]
base = base.repeat_interleave(repeat_counts_cpu, dim=0)
return base.to(device=self.device, return base.to(device=self.device,
dtype=dtype if dtype is not None else None, dtype=dtype if dtype is not None else None,
non_blocking=True) non_blocking=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment