Commit 2444e959 authored by lixh6
Browse files

[FEATURE] 接入Aiter MoE W8A8-FP8 量化模型支持

parent aef3c487
...@@ -167,6 +167,7 @@ if TYPE_CHECKING: ...@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_MOE_USE_DEEP_GEMM: bool = True
VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True
VLLM_USE_AITER_MOE_W8A8: bool = True
VLLM_DEEP_GEMM_WARMUP: Literal[ VLLM_DEEP_GEMM_WARMUP: Literal[
"skip", "skip",
"full", "full",
...@@ -1290,6 +1291,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1290,6 +1291,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool(
int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1"))
), ),
"VLLM_USE_AITER_MOE_W8A8": lambda: bool(
int(os.getenv("VLLM_USE_AITER_MOE_W8A8", "1"))
),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no # JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine # JIT'ing in the hot-path. However, this warmup increases the engine
......
...@@ -6,7 +6,9 @@ import functools ...@@ -6,7 +6,9 @@ import functools
import json import json
import os import os
import math import math
import sys
import aiter
from aiter.moe import get_aiter_moe_config, aiter_moe, MoeQuantType
from collections.abc import Callable from collections.abc import Callable
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
...@@ -1858,35 +1860,74 @@ def fused_experts_impl( ...@@ -1858,35 +1860,74 @@ def fused_experts_impl(
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
if use_int8_w8a8 or use_fp8_w8a8: if use_int8_w8a8 or use_fp8_w8a8:
return fused_experts_impl_int8(hidden_states=hidden_states, if envs.VLLM_USE_AITER_MOE_W8A8==True:
w1=w1, K_input = hidden_states.size(1)
w2=w2, actual_N2 = N // 2
topk_weights=topk_weights, quant_type = MoeQuantType.W8A8
topk_ids=topk_ids, status, moe_config = get_aiter_moe_config(
cache13=cache13, M=num_tokens,
inplace=inplace, E=global_num_experts,
activation=activation, N1=N,
apply_router_weight_on_input=apply_router_weight_on_input, N2=actual_N2,
use_fp8_w8a8=use_fp8_w8a8, K=K_input,
use_int8_w8a8=use_int8_w8a8, top_k=top_k_num,
use_int8_w8a16=False, block_size=0,
use_int4_w4a16=False, dtype=hidden_states.dtype,
per_channel_quant=per_channel_quant, quant_type=quant_type,
global_num_experts=global_num_experts, )
expert_map=expert_map,
w1_scale=w1_scale, output = aiter_moe(
w2_scale=w2_scale, hidden_states=hidden_states,
w1_zp=w1_zp, w1=w1,
w2_zp=w2_zp, w2=w2,
a1_scale=a1_scale, topk_weights=topk_weights,
a2_scale=a2_scale, topk_ids=topk_ids,
block_shape=block_shape, moe_config=moe_config,
use_nn_moe=False, inplace=inplace,
routed_scaling_factor=routed_scaling_factor, activation=activation,
shared_output=shared_output, w1_scale=w1_scale,
i_q=i_q, w2_scale=w2_scale,
i_s=i_s w1_zp=w1_zp,
) w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=None,
global_num_experts=global_num_experts,
expert_map=expert_map,
routed_scaling_factor=routed_scaling_factor,
)
return output
else:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
i_q=i_q,
i_s=i_s
)
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
w1=w1, w1=w1,
......
...@@ -26,6 +26,12 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -26,6 +26,12 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEPrepareAndFinalize, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
) )
import aiter
from aiter.test_common import checkAllclose, perftest
from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2
from aiter.fused_moe import fused_topk, torch_moe
from aiter import dtypes, ActivationType
from aiter.moe import get_aiter_moe_config, aiter_moe, MoeSolutionType, MoeQuantType
try: try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin
...@@ -169,23 +175,45 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): ...@@ -169,23 +175,45 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def shuffle_w8a8_gemm1(self, weight_data):
    """Shuffle weights with ``moe_layout_shuffle_gemm1`` and return them as int8.

    The input is first converted to ``torch.float8_e4m3fn``, then handed to
    ``moe_layout_shuffle_gemm1``; the shuffled tensor is returned through a
    ``torch.int8`` dtype view (a reinterpretation of the stored bytes, not a
    numeric conversion).

    NOTE(review): the expected shape of ``weight_data`` is dictated by
    ``moe_layout_shuffle_gemm1`` and is not visible here -- confirm against
    the caller.
    """
    fp8_weights = weight_data.to(torch.float8_e4m3fn)
    return moe_layout_shuffle_gemm1(fp8_weights).view(torch.int8)
def shuffle_w8a8_gemm2(self, weight_data):
    """Shuffle weights with ``moe_layout_shuffle_gemm2`` and return them as int8.

    The input is first converted to ``torch.float8_e4m3fn``, then handed to
    ``moe_layout_shuffle_gemm2``; the shuffled tensor is returned through a
    ``torch.int8`` dtype view (a reinterpretation of the stored bytes, not a
    numeric conversion).

    NOTE(review): the expected shape of ``weight_data`` is dictated by
    ``moe_layout_shuffle_gemm2`` and is not visible here -- confirm against
    the caller.
    """
    fp8_weights = weight_data.to(torch.float8_e4m3fn)
    return moe_layout_shuffle_gemm2(fp8_weights).view(torch.int8)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] if envs.VLLM_USE_AITER_MOE_W8A8==True:
for ii in range(layer.w13_weight.shape[0]): layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, requires_grad=False)
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, requires_grad=False)
w1_marlin_list.append(w1_marlin_in.float() if w1_marlin_in.dtype == torch.float8_e4m3fn else w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w1_marlin = fp32_to_fp8_e4m3fn(w1_marlin)
del w1_marlin_list shuffled_w13 = self.shuffle_w8a8_gemm1(layer.w13_weight)
w2_marlin_list = [] w13_data = shuffled_w13.view(*layer.w13_weight.shape).view(torch.int8)
for ii in range(layer.w2_weight.shape[0]): layer.w13_weight = Parameter(w13_data, requires_grad=False)
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in.float() if w2_marlin_in.dtype == torch.float8_e4m3fn else w2_marlin_in) shuffled_w2 = self.shuffle_w8a8_gemm2(layer.w2_weight)
w2_marlin = torch.stack(w2_marlin_list, dim=0) w2_data = shuffled_w2.view(*layer.w2_weight.shape).view(torch.int8)
w2_marlin = fp32_to_fp8_e4m3fn(w2_marlin) layer.w2_weight = Parameter(w2_data, requires_grad=False)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) else:
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in.float() if w1_marlin_in.dtype == torch.float8_e4m3fn else w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w1_marlin = fp32_to_fp8_e4m3fn(w1_marlin)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in.float() if w2_marlin_in.dtype == torch.float8_e4m3fn else w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
w2_marlin = fp32_to_fp8_e4m3fn(w2_marlin)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def fused_moe_forward( def fused_moe_forward(
self, self,
...@@ -200,27 +228,66 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): ...@@ -200,27 +228,66 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
): ):
if envs.VLLM_USE_AITER_MOE_W8A8==True:
return fused_experts_impl_fp8_marlin( m_flat = x.view(-1, x.shape[-1])
hidden_states=x, M = m_flat.shape[0]
w1=layer.w13_weight, E = layer.w13_weight.size(0)
w2=layer.w2_weight, K = x.size(-1)
topk_weights=topk_weights, N1 = layer.w13_weight.size(1)
topk_ids=topk_ids, topk = topk_ids.size(1)
inplace=True, w1_input = layer.w13_weight.view(E, N1, K)
activation=activation, w2_input = layer.w2_weight.view(E, K, N1 // 2)
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True, _, moe_cfg = get_aiter_moe_config(
per_channel_quant=True, M=M,
global_num_experts=global_num_experts, E=E,
expert_map=expert_map, N1=N1,
w1_scale=layer.w13_weight_scale, N2=N1 // 2,
w2_scale=layer.w2_weight_scale, K=K,
a1_scale=layer.w13_input_scale, top_k=topk,
a2_scale=layer.w2_input_scale, block_size=0,
use_nn_moe=False, dtype=x.dtype,
shared_output=shared_output, quant_type=MoeQuantType.W8A8,
routed_scaling_factor=routed_scaling_factor) )
output = aiter_moe(
hidden_states=x,
w1=w1_input,
w2=w2_input,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_config=moe_cfg,
inplace=False,
activation=getattr(layer, "activation", "silu"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=getattr(layer, "w13_input_scale", None),
a2_scale=getattr(layer, "w2_input_scale", None),
global_num_experts=E,
expert_map=getattr(layer, "expert_map", None),
routed_scaling_factor=routed_scaling_factor,
)
return output
else:
return fused_experts_impl_fp8_marlin(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def apply( def apply(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment