Commit 7017f30c authored by gaoqiong's avatar gaoqiong
Browse files

修改W4A8 以及W8A8量化量化092接口

parent 98958aed
...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json) from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
...@@ -653,6 +653,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -653,6 +653,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
use_nn_moe: Optional[bool]=False) -> None: use_nn_moe: Optional[bool]=False) -> None:
...@@ -1214,6 +1215,8 @@ def get_config_dtype_str( ...@@ -1214,6 +1215,8 @@ def get_config_dtype_str(
return "int8_w8a16" return "int8_w8a16"
elif use_int4_w4a16: elif use_int4_w4a16:
return "int4_w4a16" return "int4_w4a16"
elif use_int4_w4a16:
return "int4_w4a8"
elif dtype == torch.float: elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE # avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs # use fp16/bfloat16 configs
...@@ -1232,6 +1235,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1232,6 +1235,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
...@@ -1245,7 +1249,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1245,7 +1249,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map, per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe)
...@@ -1263,6 +1267,7 @@ def inplace_fused_experts_fake( ...@@ -1263,6 +1267,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
...@@ -1298,6 +1303,7 @@ def outplace_fused_experts( ...@@ -1298,6 +1303,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
...@@ -1312,7 +1318,7 @@ def outplace_fused_experts( ...@@ -1312,7 +1318,7 @@ def outplace_fused_experts(
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, per_channel_quant, use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe)
...@@ -1329,6 +1335,7 @@ def outplace_fused_experts_fake( ...@@ -1329,6 +1335,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
...@@ -1383,6 +1390,7 @@ def fused_experts( ...@@ -1383,6 +1390,7 @@ def fused_experts(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
...@@ -1442,6 +1450,7 @@ def fused_experts( ...@@ -1442,6 +1450,7 @@ def fused_experts(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -1468,6 +1477,7 @@ def fused_experts_impl( ...@@ -1468,6 +1477,7 @@ def fused_experts_impl(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
...@@ -1506,6 +1516,34 @@ def fused_experts_impl( ...@@ -1506,6 +1516,34 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False use_nn_moe=False
) )
elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8= False,
use_int8_w8a8= False,
use_int8_w8a16= False,
use_int4_w4a16 = False,
use_int4_w4a8 = True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe= False
)
#
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), ( assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch") "Hidden size mismatch")
...@@ -1542,12 +1580,14 @@ def fused_experts_impl( ...@@ -1542,12 +1580,14 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16) use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8)
get_config_func = functools.partial( get_config_func = functools.partial(
try_get_optimal_moe_config, try_get_optimal_moe_config,
...@@ -1648,6 +1688,7 @@ def fused_experts_impl( ...@@ -1648,6 +1688,7 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1687,6 +1728,7 @@ def fused_experts_impl( ...@@ -1687,6 +1728,7 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1714,6 +1756,7 @@ def fused_moe( ...@@ -1714,6 +1756,7 @@ def fused_moe(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
...@@ -1799,6 +1842,7 @@ def fused_moe( ...@@ -1799,6 +1842,7 @@ def fused_moe(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -1820,6 +1864,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1820,6 +1864,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
): ):
...@@ -1829,6 +1874,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1829,6 +1874,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
)) ))
...@@ -1837,6 +1883,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1837,6 +1883,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16 self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16 self.use_int8_w8a16 = use_int8_w8a16
self.use_int4_w4a8= use_int4_w4a8
@property @property
def activation_formats( def activation_formats(
...@@ -1966,6 +2013,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1966,6 +2013,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8= self.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
...@@ -1996,6 +2044,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1996,6 +2044,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8= self.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
...@@ -2005,6 +2054,7 @@ def modular_triton_fused_moe( ...@@ -2005,6 +2054,7 @@ def modular_triton_fused_moe(
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8:bool,
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
) -> mk.FusedMoEModularKernel: ) -> mk.FusedMoEModularKernel:
...@@ -2015,6 +2065,7 @@ def modular_triton_fused_moe( ...@@ -2015,6 +2065,7 @@ def modular_triton_fused_moe(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8= use_int4_w4a8,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
), ),
......
...@@ -477,7 +477,7 @@ class BlockInt8MoEMethod: ...@@ -477,7 +477,7 @@ class BlockInt8MoEMethod:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `MoeWNA16Method` yet.") "EPLB not supported for `MoeBlockInt8Method` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
......
...@@ -974,147 +974,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -974,147 +974,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
) )
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.int8
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
hidden_size,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
...@@ -1729,12 +1588,17 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1729,12 +1588,17 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW8A8Int8Method` yet.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -264,7 +264,7 @@ class W8A8Int8MoEMethod: ...@@ -264,7 +264,7 @@ class W8A8Int8MoEMethod:
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
), ),
requires_grad=False, requires_grad=False,
) )
...@@ -272,7 +272,7 @@ class W8A8Int8MoEMethod: ...@@ -272,7 +272,7 @@ class W8A8Int8MoEMethod:
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
...@@ -306,13 +306,13 @@ class W8A8Int8MoEMethod: ...@@ -306,13 +306,13 @@ class W8A8Int8MoEMethod:
E=layer.w13_weight.shape[0] E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1] N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1] N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2] K=N1//2
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK) json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup #warmup
...@@ -345,12 +345,16 @@ class W8A8Int8MoEMethod: ...@@ -345,12 +345,16 @@ class W8A8Int8MoEMethod:
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `W8A8Int8MoeMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -374,7 +378,7 @@ class W8A8Int8MoEMethod: ...@@ -374,7 +378,7 @@ class W8A8Int8MoEMethod:
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_int8_w8a8=True, use_int4_w4a8=True,
per_channel_quant=True, per_channel_quant=True,
activation=activation, activation=activation,
expert_map=expert_map, expert_map=expert_map,
......
...@@ -2060,7 +2060,13 @@ class W8a8GetCacheJSON: ...@@ -2060,7 +2060,13 @@ class W8a8GetCacheJSON:
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json" return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK, def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None): block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None: if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment