Commit e7dee10f authored by yangql's avatar yangql Committed by zhangzbb
Browse files

[Feature]: 新增w4a8的moe-aiter算子的支持,采用VLLM_ROCM_USE_AITER_MOE环境变量控制,默认关闭,打开后会走到aiter的w4a8 moe的算子。

parent d49bafc5
...@@ -733,7 +733,8 @@ class FusedMoE(CustomOp): ...@@ -733,7 +733,8 @@ class FusedMoE(CustomOp):
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod", if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod",
"SlimQuantW4A8Int8MoEMethod", "SlimQuantW4A8Int8MoEMethod",
"SlimQuantW4A8Int8MarlinMoEMethod")): "SlimQuantW4A8Int8MarlinMoEMethod",
"SlimQuantW4A8Int8AiterMoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
......
...@@ -25,6 +25,21 @@ import os ...@@ -25,6 +25,21 @@ import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from aiter.ops.shuffle import w4a8_moe_layout_shuffle_gemm1,w4a8_moe_layout_shuffle_gemm2
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter import dtypes, ActivationType
except ImportError as e:
print("Import error msg: import aiter")
W8A8_TRITONJSON=W8a8GetCacheJSON() W8A8_TRITONJSON=W8a8GetCacheJSON()
def baseline_scaled_mm(a: torch.Tensor, def baseline_scaled_mm(a: torch.Tensor,
...@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig): ...@@ -82,7 +97,10 @@ class SlimQuantW4A8Int8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MoEMethod(self, layer.moe_config) if envs.VLLM_ROCM_USE_AITER_MOE:
return SlimQuantW4A8Int8AiterMoEMethod(self, layer.moe_config)
else:
return SlimQuantW4A8Int8MoEMethod(self, layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -328,4 +346,209 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -328,4 +346,209 @@ class SlimQuantW4A8Int8MoEMethod:
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
)
class SlimQuantW4A8Int8AiterMoEMethod:
"""MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config, moe):
self.moe = moe
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.moe_mk: Optional[FusedMoEModularKernel] = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]:
self.moe_quant_config = FusedMoEQuantConfig.make(
torch.int8,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=None,
weight_dtype='int4'
)
return self.moe_quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
tp_size = get_tensor_model_parallel_world_size()
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
) )
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def repack_and_shuffle_w4a8(self, weight_data, E):
"""
逐 expert 处理 [n, k_half]
处理完直接写回 weight_data[i]
"""
# 原始 shape: [E, n, k_half]
for i in range(E):
# 1. 取当前 expert [n, k_half]
expert = weight_data[i]
n, k_half = expert.shape
# 2. repack 逻辑(连续 → blocked)
w_u8 = expert.to(torch.uint8)
# 解包 1byte → 2个4bit
w_unpacked = torch.stack([
(w_u8 >> 4) & 0x0F,
w_u8 & 0x0F
], dim=-1).view(n, -1)
# 8个4bit分块重排
blocks = w_unpacked.view(n, -1, 8)
w_low = blocks[..., :4]
w_high = blocks[..., 4:]
packed = (w_low << 4) | w_high
packed = packed.view(n, k_half)
# 3. shuffle
w_marlin_in = w4a8_moe_layout_shuffle_gemm2(packed)
w_marlin_in = w_marlin_in.reshape(n, k_half)
# 4. 直接写回
weight_data[i] = w_marlin_in
return weight_data
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
layer.w13_weight = Parameter(self.repack_and_shuffle_w4a8(layer.w13_weight.data, E), requires_grad=False)
layer.w2_weight = Parameter(self.repack_and_shuffle_w4a8(layer.w2_weight.data, E), requires_grad=False)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
shared_output: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
E = layer.w13_weight.size(0)
K = x.size(-1)
N1 = layer.w13_weight.size(1)
if x.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
M = x.size(0)
else:
assert x.dim() == 3
assert x.size(0) == E, f"{x.size(0)} == {E}"
M = x.size(1)
topk = topk_ids.size(1)
status, moe_cfg = get_aiter_moe_config(
M=M,
E=E,
N1=N1,
N2=N1//2,
K=K,
top_k=topk,
block_size=None,
dtype=dtypes.bf16,
quant_type=MoeQuantType.W4A8,
)
if not status:
assert moe_cfg.solution_type is None
assert moe_cfg.config is None
logger.info(f"[get_config_w4a8] {M=}, no solution found")
return aiter_moe(
x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_config=moe_cfg,
activation="silu",
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
global_num_experts=E,
expert_map=None,
)
\ No newline at end of file
...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, ...@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod, SlimQuantW4A8Int8AiterMoEMethod
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
try: try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
...@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): ...@@ -111,7 +111,10 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MarlinMoEMethod(self, layer.moe_config) if envs.VLLM_ROCM_USE_AITER_MOE:
return SlimQuantW4A8Int8AiterMoEMethod(self, layer.moe_config)
else:
return SlimQuantW4A8Int8MarlinMoEMethod(self, layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment