Commit 3e197b3a authored by guanyu1's avatar guanyu1
Browse files

0151-qwen3_5 aiter moe接入,但是还没有验证精度

parent bcb2ba6c
......@@ -906,6 +906,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(
os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024))
),
# Control whether to use fused MoE activation chunking. Current chunking
# logic is incompatible with torch.compile and causes IMA. See issue
# https://github.com/vllm-project/vllm/issues/19631.
......
......@@ -10,7 +10,7 @@ import math
from collections.abc import Callable
from typing import Any, Callable, Dict, List, Optional
from vllm._aiter_ops import rocm_aiter_ops
import torch
import vllm.envs as envs
......@@ -1767,12 +1767,72 @@ def fused_experts_impl(
return False
return True
w1_aiter_shuffled = getattr(w1, "aiter_moe_shuffled", False)
w2_aiter_shuffled = getattr(w2, "aiter_moe_shuffled", False)
if w1_aiter_shuffled != w2_aiter_shuffled:
raise RuntimeError(
"MoE weights must either both be AITER-shuffled or both be "
"unshuffled."
)
is_aiter_shuffled = w1_aiter_shuffled and w2_aiter_shuffled
if is_aiter_shuffled:
if not (
current_platform.is_rocm()
and rocm_aiter_ops.is_fused_moe_enabled()
):
raise RuntimeError(
"AITER-shuffled MoE weights require ROCm AITER fused MoE "
"to be enabled."
)
try:
from aiter.fused_moe_asm_wna16 import fused_experts_asm_impl
except Exception as e:
raise RuntimeError(
"AITER-shuffled MoE weights were loaded, but the ASM "
"kernel is unavailable. Ensure the `aiter` package is "
"installed and exposes `fused_moe_asm_wna16`."
) from e
if activation != "silu":
raise RuntimeError(
"ASM Marlin W16A16 MoE only supports activation='silu'."
)
if apply_router_weight_on_input:
raise RuntimeError(
"ASM Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
if w1_bias is not None or w2_bias is not None:
raise RuntimeError(
"ASM Marlin W16A16 MoE does not support expert biases."
)
return fused_experts_asm_impl(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
dtype=hidden_states.dtype,
inplace=inplace,
activation=activation,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_shuffle=1,
)
is_packed = (
getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)
)
if is_packed:
if envs.VLLM_USE_MOE_W16A16_TRITON:
raise RuntimeError(
......
......@@ -283,28 +283,119 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
packed[i].copy_(tmp)
del tmp
return packed
def _asm_shuffle_weight_b8(x: torch.Tensor, stage: torch.int32 = 1) -> torch.Tensor:
# Hardcode BLOCK_K and BLOCK_N
assert x.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.int8, torch.float8_e4m3fn
]
if x.dtype == torch.int8 or x.dtype == torch.float8_e4m3fn:
N = 16
K = 16
IK = 64
IN = 64
BK = 256
BN = 128
if stage == 1:
if x.shape[-2] % 128 != 0 and x.shape[-2] % 64 == 0:
BN = 64
if stage == 2:
if x.shape[-1] % 128 == 0:
BK = 128
elif x.shape[-1] % 128 == 96:
BN = 64
BK = 64
assert x.shape[-2] % BN == 0, f"{x.shape[-2]} % {BN} == {x.shape[-2] % BN }"
x_ = x
multiple = x.shape[-1] // BK * BK
part1 = x[:, :, :multiple]
### part1 shuffle
# 0, 1, 2, 3, 4, 5, 6, 7, 8
part1 = part1.view(-1, part1.shape[-2] // BN, BN // IN, IN // N, N, part1.shape[-1] // BK, BK // IK, IK // K, K)
part1 = part1.permute(0, 1, 5, 2, 6, 3, 7, 4, 8).contiguous()
part1 = part1.flatten(start_dim=1)
### part2 shuffle
part2 = x[:, :, multiple:]
IK = 32
BK = 32
# 0, 1, 2, 3, 4, 5, 6, 7, 8
part2 = part2.view(-1, part2.shape[-2] // BN, BN // IN, IN // N, N, part2.shape[-1] // BK, BK // IK, IK // K, K)
part2 = part2.permute(0, 1, 5, 2, 6, 3, 7, 4, 8).contiguous()
part2 = part2.flatten(start_dim=1)
### combine
x_ = torch.cat((part1, part2), dim=1)
x_ = x_.view(*x.shape)
return x_
elif x.dtype == torch.float16 or x.dtype == torch.bfloat16:
N = 16
K = 8
IK = 32
IN = 64
BK = 128
BN = 64
if stage == 2:
BK = 32
else:
assert False, f"not support {x.dtype}"
assert x.shape[-2] % BN == 0, f"{x.shape[-2]} % {BN} == {x.shape[-2] % BN }"
assert x.shape[-1] % BK == 0, f"{x.shape[-1]} % {BK} == {x.shape[-1] % BK }"
x_ = x
# 0, 1, 2, 3, 4, 5, 6, 7, 8
x_ = x_.view(-1, x.shape[-2] // BN, BN // IN, IN // N, N, x.shape[-1] // BK, BK // IK, IK // K, K)
x_ = x_.permute(0, 1, 5, 2, 6, 3, 7, 4, 8)
x_ = x_.contiguous()
x_ = x_.view(*x.shape)
return x_
with torch.no_grad():
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
if current_platform.is_rocm() and rocm_aiter_ops.is_fused_moe_enabled():
replace_parameter(
layer,
"w13_weight",
_asm_shuffle_weight_b8(w1, stage=1),
)
replace_parameter(
layer,
"w2_weight",
_asm_shuffle_weight_b8(w2, stage=2),
)
new_w1 = layer.w13_weight
new_w2 = layer.w2_weight
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "aiter_moe_shuffled", True)
setattr(new_w2, "aiter_moe_shuffled", True)
layer._marlin_w16a16_moe_packed = True
else:
w1_packed = _pack_per_expert(w1)
w2_packed = _pack_per_expert(w2)
new_w1 = Parameter(w1_packed, requires_grad=False)
new_w2 = Parameter(w2_packed, requires_grad=False)
# Preserve any custom weight attributes (e.g. loaders).
if hasattr(w1, "__dict__"):
for k, v in w1.__dict__.items():
setattr(new_w1, k, v)
if hasattr(w2, "__dict__"):
for k, v in w2.__dict__.items():
setattr(new_w2, k, v)
setattr(new_w1, "marlin_w16a16_packed", True)
setattr(new_w2, "marlin_w16a16_packed", True)
layer.w13_weight = new_w1
layer.w2_weight = new_w2
layer._marlin_w16a16_moe_packed = True
return
except Exception:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment