"vllm/vscode:/vscode.git/clone" did not exist on "2c8b9182b5ced00d83bed15ef8bc0ac6e079b6ee"
Commit 5ad884ee authored by zhuwenwen's avatar zhuwenwen
Browse files

去除多余的w4a8参数

增加fused moe文件中w4a8的相关修改
fix: 修复W8A8读config路径错误,删除int8_utils.py文件
fix: 修复W8A8INT8读config问题
修改W4A8 以及W8A8量化092接口
parent 84dfdb17
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
w8a8_block_int8_matmul) w8a8_block_int8_matmul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_quant_int8) per_token_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
......
...@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json) from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
...@@ -658,6 +659,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -658,6 +659,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None: use_nn_moe: Optional[bool]=False) -> None:
...@@ -1211,7 +1213,8 @@ def get_config_dtype_str( ...@@ -1211,7 +1213,8 @@ def get_config_dtype_str(
use_int8_w8a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False,
use_mxfp4_w4a4: Optional[bool] = False, use_mxfp4_w4a4: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False) -> Optional[str]: use_int8_w8a8: Optional[bool] = False,
use_int4_w4a8: Optional[bool] = False) -> Optional[str]:
if use_fp8_w8a8: if use_fp8_w8a8:
return "fp8_w8a8" return "fp8_w8a8"
elif use_int8_w8a8: elif use_int8_w8a8:
...@@ -1220,6 +1223,8 @@ def get_config_dtype_str( ...@@ -1220,6 +1223,8 @@ def get_config_dtype_str(
return "int8_w8a16" return "int8_w8a16"
elif use_int4_w4a16: elif use_int4_w4a16:
return "int4_w4a16" return "int4_w4a16"
elif use_int4_w4a8:
return "int4_w4a8"
elif use_mxfp4_w4a4: elif use_mxfp4_w4a4:
return "mxfp4_w4a4" return "mxfp4_w4a4"
elif dtype == torch.float: elif dtype == torch.float:
...@@ -1240,6 +1245,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1240,6 +1245,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1254,7 +1260,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1254,7 +1260,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_int4_w4a8,
use_mxfp4_w4a4, per_channel_quant, global_num_experts, use_mxfp4_w4a4, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, use_nn_moe) a2_scale, block_shape, use_nn_moe)
...@@ -1272,6 +1278,7 @@ def inplace_fused_experts_fake( ...@@ -1272,6 +1278,7 @@ def inplace_fused_experts_fake(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1407,6 +1414,7 @@ def outplace_fused_experts( ...@@ -1407,6 +1414,7 @@ def outplace_fused_experts(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1422,7 +1430,7 @@ def outplace_fused_experts( ...@@ -1422,7 +1430,7 @@ def outplace_fused_experts(
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16, use_mxfp4_w4a4, use_int4_w4a16, use_int4_w4a8, use_mxfp4_w4a4,
per_channel_quant, global_num_experts, per_channel_quant, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, use_nn_moe) a1_scale, a2_scale, block_shape, use_nn_moe)
...@@ -1439,6 +1447,7 @@ def outplace_fused_experts_fake( ...@@ -1439,6 +1447,7 @@ def outplace_fused_experts_fake(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1495,6 +1504,7 @@ def fused_experts( ...@@ -1495,6 +1504,7 @@ def fused_experts(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1562,6 +1572,7 @@ def fused_experts( ...@@ -1562,6 +1572,7 @@ def fused_experts(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4, use_mxfp4_w4a4=use_mxfp4_w4a4,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -1589,6 +1600,7 @@ def fused_experts_impl( ...@@ -1589,6 +1600,7 @@ def fused_experts_impl(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1628,6 +1640,33 @@ def fused_experts_impl( ...@@ -1628,6 +1640,33 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False use_nn_moe=False
) )
elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8= False,
use_int8_w8a8= False,
use_int8_w8a16= False,
use_int4_w4a16 = False,
use_int4_w4a8 = True,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe= False
)
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), ( assert hidden_states.size(1) // 2 == w1.size(2), (
"Hidden size mismatch") "Hidden size mismatch")
...@@ -1667,6 +1706,7 @@ def fused_experts_impl( ...@@ -1667,6 +1706,7 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4, use_mxfp4_w4a4=use_mxfp4_w4a4,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
...@@ -1781,6 +1821,7 @@ def fused_experts_impl( ...@@ -1781,6 +1821,7 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1820,6 +1861,7 @@ def fused_experts_impl( ...@@ -1820,6 +1861,7 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
...@@ -1847,6 +1889,7 @@ def fused_moe( ...@@ -1847,6 +1889,7 @@ def fused_moe(
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool =False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
...@@ -1936,6 +1979,7 @@ def fused_moe( ...@@ -1936,6 +1979,7 @@ def fused_moe(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4, use_mxfp4_w4a4=use_mxfp4_w4a4,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -1958,6 +2002,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1958,6 +2002,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
...@@ -1968,6 +2013,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1968,6 +2013,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4, use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
...@@ -1977,6 +2023,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -1977,6 +2023,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16 self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16 self.use_int8_w8a16 = use_int8_w8a16
self.use_int4_w4a8 = use_int4_w4a8
self.use_mxfp4_w4a4 = use_mxfp4_w4a4 self.use_mxfp4_w4a4 = use_mxfp4_w4a4
@property @property
...@@ -2062,6 +2109,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2062,6 +2109,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8=self.use_int4_w4a8,
use_mxfp4_w4a4=self.use_mxfp4_w4a4, use_mxfp4_w4a4=self.use_mxfp4_w4a4,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
...@@ -2117,6 +2165,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2117,6 +2165,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8=self.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
...@@ -2147,6 +2196,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2147,6 +2196,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8= self.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape) block_shape=self.block_shape)
...@@ -2158,6 +2208,7 @@ def modular_triton_fused_moe( ...@@ -2158,6 +2208,7 @@ def modular_triton_fused_moe(
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8:bool,
use_mxfp4_w4a4: bool, use_mxfp4_w4a4: bool,
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
...@@ -2169,6 +2220,7 @@ def modular_triton_fused_moe( ...@@ -2169,6 +2220,7 @@ def modular_triton_fused_moe(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8= use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4, use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4) quant_dequant_mxfp4)
......
...@@ -477,7 +477,7 @@ class BlockInt8MoEMethod: ...@@ -477,7 +477,7 @@ class BlockInt8MoEMethod:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `MoeWNA16Method` yet.") "EPLB not supported for `MoeBlockInt8Method` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
......
...@@ -1018,147 +1018,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -1018,147 +1018,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
) )
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
    """Fused-MoE quantization method for INT8 weights and INT8 activations.

    The constructor accepts only one scheme: channel-wise weight
    quantization together with dynamic per-token activation quantization.
    Anything else raises ``ValueError``.
    """

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        """Read the "Linear" scheme from ``quant_config`` and validate it.

        Raises:
            ValueError: if the weights are not channel-wise quantized, the
                activations are not per-token quantized, or the activation
                scales are static.
        """
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
            "weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        # Reject anything that is not (channel-wise weights, per-token acts).
        if (self.weight_quant.strategy != QuantizationStrategy.CHANNEL
                or self.input_quant.strategy != QuantizationStrategy.TOKEN):
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register the expert weights and their scales.

        Registers ``w13_weight``, ``w2_weight``, ``w13_weight_scale`` and
        ``w2_weight_scale`` on ``layer`` and sets both input scales to
        ``None``. The incoming ``params_dtype`` is overridden with
        ``torch.int8``.
        """
        params_dtype = torch.int8
        w13_rows = 2 * intermediate_size_per_partition

        # WEIGHTS
        w13_data = torch.empty(num_experts,
                               w13_rows,
                               hidden_size,
                               dtype=params_dtype)
        w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_data = torch.empty(num_experts,
                              hidden_size,
                              intermediate_size_per_partition,
                              dtype=params_dtype)
        w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES: one float32 scale per output row, initialised to 1.
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_scale_data = torch.ones(num_experts,
                                    w13_rows,
                                    1,
                                    dtype=torch.float32)
        w13_weight_scale = torch.nn.Parameter(w13_scale_data,
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_scale_data = torch.ones(num_experts,
                                   hidden_size,
                                   1,
                                   dtype=torch.float32)
        w2_weight_scale = torch.nn.Parameter(w2_scale_data,
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        # The update happens after the weights received their attributes, so
        # only the scale parameters carry the "quant_method" entry.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        for scale_param in (w13_weight_scale, w2_weight_scale):
            set_weight_attrs(scale_param, extra_weight_attrs)

        # INPUT_SCALES: activations are quantized dynamically, nothing stored.
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No post-processing is performed for this method."""
        pass

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run ``fused_experts`` in INT8 mode.

        Raises:
            NotImplementedError: if ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for "
                "`CompressedTensorsW8A8Int8MoEMethod` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts

        # Expert selection
        routing_weights, expert_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        # Expert computation; the result is written in place (inplace=True).
        moe_output = fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=routing_weights,
            topk_ids=expert_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale)
        return moe_output
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
...@@ -1419,11 +1278,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): ...@@ -1419,11 +1278,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for " "EPLB not supported for "
"`CompressedTensorsWNA16MarlinMoEMethod` yet.") "`CompressedTensorsW8A8Int8Method` yet.")
assert activation == "silu", ( assert activation == "silu", (
f"{activation} not supported for Marlin MoE.") f"{activation} not supported for Marlin MoE.")
...@@ -1701,6 +1561,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1701,6 +1561,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
f"{self.weight_quant}, {self.input_quant}") f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic self.static_input_scales = not self.input_quant.dynamic
self.tritonsingleton = W8a8GetCacheJSON()
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
...@@ -1756,6 +1617,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1756,6 +1617,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Records this layer's MoE weight shape on `self.tritonsingleton` and merges
# whatever config dict the singleton's cache lookup returns into its
# `triton_moejson_dict`.
# NOTE(review): indentation is lost in this view, so it is unclear which of
# the statements below sit inside the `not in` guard -- confirm against the
# original file before restructuring.
# Shape components read from the two expert weight tensors.
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
# Only append the shape if it has not been recorded before.
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
# Resolve the JSON file name for this shape/top-k, then load its configs.
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
# Skip the merge when the lookup returned nothing (falsy result).
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
# Generate the model config file (currently disabled).
#self.tritonsingleton.gen_model_json(block_size)
return
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/sgl-project/sglang/blob/4cb53ecd0cffceb6dee5c011a58f65997a86f151/python/sglang/srt/layers/quantization/int8_kernel.py
import functools
import json
import logging
import os
from typing import Any, Optional
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
logger = logging.getLogger(__name__)
def apply_w8a8_block_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: list[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a block-quantized INT8 linear layer to ``input``.

    The activation is quantized on the fly with
    ``per_token_group_quant_int8`` using ``block_size[1]`` as the group size,
    multiplied with ``weight`` through ``w8a8_block_int8_matmul``, and the
    optional ``bias`` is added afterwards.

    Args:
        input: Activation tensor; every leading dimension is flattened into
            one before quantization.
        weight: Quantized weight; ``weight.shape[0]`` becomes the size of the
            last output dimension.
        block_size: Quantization block size, forwarded to the matmul.
        weight_scale: Scales belonging to ``weight``.
        input_scale: Must be ``None``; static activation scales are not
            accepted.
        bias: Optional tensor added to the matmul result.

    Returns:
        A tensor with ``input``'s dtype and shape
        ``[*input.shape[:-1], weight.shape[0]]``.
    """
    assert input_scale is None

    # Collapse the leading dimensions so the int8 kernels see a 2D matrix.
    hidden_size = input.shape[-1]
    flat_input = input.view(-1, hidden_size)
    result_shape = [*input.shape[:-1], weight.shape[0]]

    group_size = block_size[1]
    quantized_input, activation_scale = per_token_group_quant_int8(
        flat_input, group_size)

    result = w8a8_block_int8_matmul(quantized_input,
                                    weight,
                                    activation_scale,
                                    weight_scale,
                                    block_size,
                                    output_dtype=input.dtype)
    if bias is not None:
        result = result + bias

    # Restore the caller's dtype and leading dimensions.
    return result.to(dtype=input.dtype).view(*result_shape)
def input_to_int8(
        x: torch.Tensor,
        dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to ``dtype`` using a single tensor-wide scale.

    Returns:
        A pair ``(quantized, inverse_scale)``: the contiguous quantized
        tensor and the float32 reciprocal of the scale applied to ``x``.
    """
    limits = torch.iinfo(dtype)
    lowest, highest = x.aminmax()

    # Largest magnitude in the tensor, floored to avoid dividing by zero.
    peak = torch.maximum(lowest.abs(), highest.abs()).clamp(min=1e-12)
    factor = limits.max / peak

    # Scale, then saturate to the representable integer range.
    saturated = torch.clamp(x * factor, min=limits.min, max=limits.max)
    quantized = saturated.to(dtype).contiguous()
    inverse_factor = factor.float().reciprocal()
    return quantized, inverse_factor
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized 2D tensor.

    Args:
        x_q_block: Quantized tensor of shape ``(n, k)``.
        x_s: Per-block scales; its first two dimensions must equal the
            number of row tiles and column tiles respectively.
        block_size: ``[block_n, block_k]`` tile extents along rows/columns.

    Returns:
        A float32 tensor of shape ``(n, k)`` in which every tile has been
        multiplied by its scale.
    """
    rows_per_tile, cols_per_tile = block_size[0], block_size[1]
    n, k = x_q_block.shape

    # Ceil-divide so a trailing partial tile still gets its own scale.
    row_tiles = (n + rows_per_tile - 1) // rows_per_tile
    col_tiles = (k + cols_per_tile - 1) // cols_per_tile
    assert row_tiles == x_s.shape[0]
    assert col_tiles == x_s.shape[1]

    dequantized = x_q_block.to(torch.float32)
    for row_tile in range(row_tiles):
        row_begin = row_tile * rows_per_tile
        row_end = min(row_begin + rows_per_tile, n)
        for col_tile in range(col_tiles):
            col_begin = col_tile * cols_per_tile
            col_end = min(col_begin + cols_per_tile, k)
            # Tiles are disjoint, so each element is scaled exactly once.
            dequantized[row_begin:row_end,
                        col_begin:col_end] *= x_s[row_tile][col_tile]

    return dequantized
if current_platform.is_rocm():
    from triton.language import core

    # NOTE: This can be removed when hip.libdevice.round() is available.
    @core.extern
    def round_f32(arg0, _builder=None):
        """Element-wise round, dispatched to the ``llvm.round`` symbol."""
        fp32 = core.dtype("fp32")
        fp64 = core.dtype("fp64")
        # Maps the argument dtype tuple to (symbol name, return dtype).
        overloads = {
            (fp32, ): ("llvm.round", fp32),
            (fp64, ): ("llvm.round", fp64),
        }
        return core.extern_elementwise("",
                                       "", [arg0],
                                       overloads,
                                       is_pure=True,
                                       _builder=_builder)

    @triton.jit
    def round_int8(x):
        # Round first, then narrow to int8.
        rounded = round_f32(x)
        return rounded.to(tl.int8)
else:

    @triton.jit
    def round_int8(x):
        # Round with the CUDA libdevice routine, then narrow to int8.
        rounded = tl.extra.cuda.libdevice.round(x)
        return rounded.to(tl.int8)
@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    """Quantize one row of the input to int8 per program instance.

    The program with id ``r`` reads ``N`` values starting at
    ``x_ptr + r * stride_x``, writes the rounded int8 values starting at
    ``xq_ptr + r * stride_xq`` and stores the row scale (``absmax / 127``)
    at ``scale_ptr + r``.
    """
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row = tl.program_id(0)
    lanes = tl.arange(0, BLOCK)
    # Lanes past the row length are masked out of the load and the store.
    in_row = lanes < N

    values = tl.load(x_ptr + row * stride_x + lanes, mask=in_row,
                     other=0.0).to(tl.float32)

    # Floor the row maximum so an all-zero row does not divide by zero.
    row_absmax = tl.maximum(tl.max(tl.abs(values)), 1e-10)
    row_scale = row_absmax / 127
    quantized = round_int8(values * (127 / row_absmax))

    tl.store(xq_ptr + row * stride_xq + lanes, quantized, mask=in_row)
    tl.store(scale_ptr + row, row_scale)
def per_token_quant_int8(x):
    """Quantize ``x`` to int8 with one scale per row of its last dimension.

    Launches ``_per_token_quant_int8`` with one program per row.

    Returns:
        A pair ``(x_q, scales)``: an int8 tensor shaped like ``x`` and a
        float32 tensor of shape ``x.shape[:-1] + (1,)``.
    """
    row_len = x.shape[-1]
    num_rows = x.numel() // row_len

    quantized = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scale_shape = x.shape[:-1] + (1, )
    row_scales = torch.empty(scale_shape,
                             device=x.device,
                             dtype=torch.float32)

    # A single block must span a whole row, so round up to a power of two.
    block = triton.next_power_of_2(row_len)
    # heuristics for number of warps
    warps = min(max(block // 256, 1), 8)

    assert x.is_contiguous()
    launch_grid = (num_rows, )
    _per_token_quant_int8[launch_grid](
        x,
        quantized,
        row_scales,
        stride_x=x.stride(-2),
        stride_xq=quantized.stride(-2),
        N=row_len,
        BLOCK=block,
        num_warps=warps,
        num_stages=1,
    )
    return quantized, row_scales
@triton.jit
def _per_token_group_quant_int8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid to divide zero
    eps,
    # Information for int8
    int8_min,
    int8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``N`` values per program instance.

    The program with id ``g`` reads the group starting at
    ``y_ptr + g * y_stride``, writes the clamped, converted values at the
    same offset from ``y_q_ptr`` and stores the group scale
    (``absmax / int8_max``) at ``y_s_ptr + g``.
    """
    group = tl.program_id(0)
    # Input and output share the same per-group offset.
    group_offset = group * y_stride

    lanes = tl.arange(0, BLOCK)  # N <= BLOCK
    in_group = lanes < N

    values = tl.load(y_ptr + group_offset + lanes, mask=in_group,
                     other=0.0).to(tl.float32)

    # Quant: floor the maximum with eps so the scale is never zero.
    group_absmax = tl.maximum(tl.max(tl.abs(values)), eps)
    group_scale = group_absmax / int8_max
    clipped = tl.clamp(values / group_scale, int8_min, int8_max)
    quantized = clipped.to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + group_offset + lanes, quantized, mask=in_group)
    tl.store(y_s_ptr + group, group_scale)
def per_token_group_quant_int8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.int8,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2. It must be contiguous and its
            last dimension must be a multiple of `group_size`.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing zero.
        dtype: The dtype of output tensor. Note that only `torch.int8`
            is supported for now.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor for quantization.

    Raises:
        AssertionError: If the last dimension of `x` is not divisible by
            `group_size`, or if `x` is not contiguous.
    """
    # The message states the requirement; the previous wording ("cannot be
    # divisible") said the opposite of what the check enforces.
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_max = iinfo.max
    int8_min = iinfo.min

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    # One program per group; groups are laid out back to back because `x`
    # is contiguous, so the group size doubles as the stride between groups.
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size, ),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_int8[(M, )](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        int8_min=int8_min,
        int8_max=int8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s
@triton.jit
def _w8a8_block_int8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton kernel computing one output tile of a block-quantized matmul.

    Each program computes a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of `C`. `A` is
    addressed as an (M, K) matrix through `stride_am`/`stride_ak` and `B` as
    a (K, N) matrix through `stride_bk`/`stride_bn`. For every K tile the
    partial product is multiplied by a per-row scale loaded from `As` and a
    per-column scale loaded from `Bs` before being added to a float32
    accumulator. The result is cast to bfloat16 or float16 when `C` has that
    element type, otherwise to float32, and stored with a bounds mask.
    """
    # Map the 1-D program id to a (pid_m, pid_n) tile. Programs are taken in
    # bands of GROUP_SIZE_M row tiles; inside a band pid_m varies fastest.
    # NOTE(review): this matches the grouped ordering of the Triton matmul
    # tutorial, presumably for cache reuse -- not verifiable from this file.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last band may contain fewer than GROUP_SIZE_M row tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column offsets are wrapped modulo M/N so the loads below need no
    # mask along those axes; out-of-range rows/columns are dropped by the
    # store mask at the end.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one entry per row of A, and one entry per group of
    # `group_n` output columns for B. The K-group offset is added per tile.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Elements past the end of K are loaded as zero so they do not
        # contribute to the dot product.
        a = tl.load(a_ptrs,
                    mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)

        # A single scale-group index is derived from the tile's first K
        # element and used for the whole tile.
        # NOTE(review): this is only exact if a K tile never straddles a
        # `group_k` boundary -- confirm the launch configs guarantee it.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        # Dequantize the partial product: row scale times column scale.
        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:,
                                                         None] * b_s[None, :]
        # Advance both operands to the next K tile.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the accumulator according to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Unwrapped offsets are used for the store so the mask can discard the
    # rows/columns that lie outside the (M, N) output.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
                                block_k: int) -> Optional[dict[int, Any]]:
    """Load tuned launch configurations for the W8A8 block INT8 kernel.

    A JSON file named after the weight shape, the current device name and
    the quantization block shape is looked up in the `configs` directory
    next to this module. Results are memoized per argument tuple.

    Args:
        N: The output dimension of the weight.
        K: The reduction dimension of the weight.
        block_n: The quantization block size along N.
        block_k: The quantization block size along K.

    Returns:
        A dictionary mapping batch sizes (the JSON keys converted to `int`)
        to kernel configurations, or `None` when no tuned file exists. For a
        given batch size the caller should pick the entry whose key is
        closest to it.
    """
    device_name = current_platform.get_device_name().replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json"  # noqa: E501
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(module_dir, "configs", json_file_name)

    # No tuned file for this shape/device: warn and let the caller fall back
    # to its default configuration.
    if not os.path.exists(config_file_path):
        logger.warning(
            ("Using default W8A8 Block INT8 kernel config. Performance might "
             "be sub-optimal! Config file not found at %s"),
            config_file_path,
        )
        return None

    with open(config_file_path) as f:
        logger.info(
            "Using configuration from %s for W8A8 Block INT8 kernel.",
            config_file_path,
        )
        raw_configs = json.load(f)
    # JSON object keys are strings; expose them as integer batch sizes.
    return {int(batch_size): cfg for batch_size, cfg in raw_configs.items()}
def w8a8_block_int8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Multiply block-quantized int8 tensors and return a dequantized result.

    Validates the operand and scale shapes, picks a launch configuration
    (a tuned one when available, otherwise a default derived from
    `block_size`) and launches the `_w8a8_block_int8_matmul` Triton kernel.

    Args:
        A: The input tensor, e.g., activation. Must be contiguous; its last
            dimension is the reduction dimension K.
        B: The input tensor, e.g., weight. Must be 2-dim, contiguous and of
            shape (N, K).
        As: The per-token-group quantization scale for `A`, with shape
            `A.shape[:-1] + (cdiv(K, block_k),)`.
        Bs: The per-block quantization scale for `B`, with shape
            `(cdiv(N, block_n), cdiv(K, block_k))`.
        block_size: The block size for per-block quantization as
            `[block_n, block_k]`, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of the matmul, with shape
        `A.shape[:-1] + (N,)` and dtype `output_dtype`.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Activation-side checks.
    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]
    assert A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    # All leading dimensions of A are flattened into the row count.
    M = A.numel() // A.shape[-1]

    # Weight-side checks.
    assert B.ndim == 2
    assert B.is_contiguous()
    assert Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C = A.new_empty(A.shape[:-1] + (N, ), dtype=output_dtype)

    tuned = get_w8a8_block_int8_configs(N, K, block_n, block_k)
    if not tuned:
        # Default config: the N and K tile sizes follow the quantization
        # block shape.
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }
    else:
        # Use the tuned entry whose batch size is closest to M.
        nearest_batch = min(tuned, key=lambda batch: abs(batch - M))
        config = tuned[nearest_batch]

    def launch_grid(meta):
        # 1-D launch: one program per (row tile, column tile) pair.
        row_tiles = triton.cdiv(M, meta["BLOCK_SIZE_M"])
        col_tiles = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (row_tiles * col_tiles, )

    # B and Bs are stored N-major, so their strides are passed swapped to
    # match the kernel's (K, N) view of the weight.
    _w8a8_block_int8_matmul[launch_grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )
    return C
...@@ -264,7 +264,7 @@ class W8A8Int8MoEMethod: ...@@ -264,7 +264,7 @@ class W8A8Int8MoEMethod:
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
), ),
requires_grad=False, requires_grad=False,
) )
...@@ -272,7 +272,7 @@ class W8A8Int8MoEMethod: ...@@ -272,7 +272,7 @@ class W8A8Int8MoEMethod:
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter( w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
requires_grad=False, requires_grad=False,
) )
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
...@@ -306,13 +306,13 @@ class W8A8Int8MoEMethod: ...@@ -306,13 +306,13 @@ class W8A8Int8MoEMethod:
E=layer.w13_weight.shape[0] E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1] N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1] N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2] K=N1//2
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK) json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup #warmup
...@@ -345,12 +345,17 @@ class W8A8Int8MoEMethod: ...@@ -345,12 +345,17 @@ class W8A8Int8MoEMethod:
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `W8A8Int8MoeMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -374,7 +379,7 @@ class W8A8Int8MoEMethod: ...@@ -374,7 +379,7 @@ class W8A8Int8MoEMethod:
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_int8_w8a8=True, use_int4_w4a8=True,
per_channel_quant=True, per_channel_quant=True,
activation=activation, activation=activation,
expert_map=expert_map, expert_map=expert_map,
......
...@@ -2184,7 +2184,7 @@ class W8a8GetCacheJSON: ...@@ -2184,7 +2184,7 @@ class W8a8GetCacheJSON:
def _initialize(self): def _initialize(self):
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_folder_path = os.path.dirname(os.path.abspath(__file__)) current_folder_path = os.path.dirname(os.path.abspath(__file__))
json_folder_path=current_folder_path+'/../lmslim/configs/w8a8' json_folder_path = current_folder_path+'/../../lmslim/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path)) self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict={} self.triton_json_dict={}
...@@ -2295,12 +2295,19 @@ class W8a8GetCacheJSON: ...@@ -2295,12 +2295,19 @@ class W8a8GetCacheJSON:
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json" return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK, def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None): block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None: if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else: else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK): def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
cache_json_file=file_path cache_json_file=file_path
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment