Commit d49bafc5 authored by lixh6's avatar lixh6
Browse files

[FEATURE] 接入Aiter MoE W8A8-Int8 量化模型支持

parent 753b29c0
...@@ -436,28 +436,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -436,28 +436,44 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def shuffle_w8a8_gemm1(self, weight_data):
w_i8 = weight_data.to(torch.int8)
return moe_layout_shuffle_gemm1(w_i8)
def shuffle_w8a8_gemm2(self, weight_data):
w_i8 = weight_data.to(torch.int8)
return moe_layout_shuffle_gemm2(w_i8)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] if envs.VLLM_USE_AITER_MOE_W8A8==True:
for ii in range(layer.w13_weight.shape[0]): layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, requires_grad=False)
if not self.use_deepep: layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, requires_grad=False)
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) shuffled_w13 = self.shuffle_w8a8_gemm1(layer.w13_weight)
else: layer.w13_weight = Parameter(shuffled_w13.view(*layer.w13_weight.shape), requires_grad=False)
w1_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w13_weight[ii]) shuffled_w2 = self.shuffle_w8a8_gemm2(layer.w2_weight)
w1_marlin_list.append(w1_marlin_in) layer.w2_weight = Parameter(shuffled_w2.view(*layer.w2_weight.shape), requires_grad=False)
w1_marlin = torch.stack(w1_marlin_list, dim=0) else:
w1_marlin_list = []
del w1_marlin_list for ii in range(layer.w13_weight.shape[0]):
w2_marlin_list = [] if not self.use_deepep:
for ii in range(layer.w2_weight.shape[0]): w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
if not self.use_deepep: else:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
else: w1_marlin_list.append(w1_marlin_in)
w2_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w2_weight[ii]) w1_marlin = torch.stack(w1_marlin_list, dim=0)
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) del w1_marlin_list
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def apply( def apply(
self, self,
...@@ -473,30 +489,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -473,30 +489,70 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts_impl_int8_marlin( if envs.VLLM_USE_AITER_MOE_W8A8==True:
hidden_states=x, m_flat = x.view(-1, x.shape[-1])
w1=layer.w13_weight, M = m_flat.shape[0]
w2=layer.w2_weight, E = layer.w13_weight.size(0)
topk_weights=topk_weights, K = x.size(-1)
topk_ids=topk_ids, N1 = layer.w13_weight.size(1)
inplace=True, topk = topk_ids.size(1)
activation=layer.activation, w1_input = layer.w13_weight.view(E, N1, K)
apply_router_weight_on_input=layer.apply_router_weight_on_input, w2_input = layer.w2_weight.view(E, K, N1 // 2)
use_int8_w8a8=True,
per_channel_quant=True, _, moe_cfg = get_aiter_moe_config(
global_num_experts=layer.global_num_experts, M=M,
expert_map=layer.expert_map, E=E,
quant_config=self.moe_quant_config, N1=N1,
w1_scale=layer.w13_weight_scale, N2=N1 // 2,
w2_scale=layer.w2_weight_scale, K=K,
a1_scale=layer.w13_input_scale, top_k=topk,
a2_scale=layer.w2_input_scale, block_size=0,
use_nn_moe=False, dtype=x.dtype,
i_q=i_q, quant_type=MoeQuantType.W8A8,
i_s=i_s, )
shared_output=shared_output, output = aiter_moe(
routed_scaling_factor=routed_scaling_factor, hidden_states=x,
) w1=w1_input,
w2=w2_input,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_config=moe_cfg,
inplace=False,
activation=getattr(layer, "activation", "silu"),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=getattr(layer, "w13_input_scale", None),
a2_scale=getattr(layer, "w2_input_scale", None),
global_num_experts=E,
expert_map=getattr(layer, "expert_map", None),
routed_scaling_factor=routed_scaling_factor,
)
return output
else:
return fused_experts_impl_int8_marlin(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
i_q=i_q,
i_s=i_s,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def select_gemm_impl( def select_gemm_impl(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment