Commit 2b0c9835 authored by laibao's avatar laibao
Browse files

perf(fused-moe): 接入 W16A16 Marlin MoE 并缓存 pack 权重

 - fused_experts_impl 增加 VLLM_USE_MARLIN_W16A16_MOE fast path:首次对 w1/w2 做 Marlin pack 后缓存,避免重复 reorder;并将原始
    权重 offload 到 CPU,降低 GPU 双份驻留
  - envs 补齐环境变量 VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
  - 更新 fuse_moe_w16a16_marlin.py 的报错提示为 VLLM_USE_LIGHTOP=1
parent 8da572a9
......@@ -230,6 +230,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_ZEROS: bool = False
VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
......@@ -1627,6 +1628,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD",
"False").lower() in ("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "False").lower() in
......
......@@ -240,7 +240,7 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype
assert use_lightop, (
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
"only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
......
......@@ -59,6 +59,26 @@ logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
......@@ -1966,7 +1986,51 @@ def fused_experts_impl(
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
# Only pack when shapes match the expected [E, 2N, K] / [E, K, N/2] contract.
# If shapes are unexpected, skip packing and fall back to non-Marlin paths below.
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else:
w2 = w2_cpu
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
)
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment