Commit f5a7f12c authored by gaoqiong
Browse files

增加fused moe文件中w4a8的相关修改

parent 7e5fb6fe
......@@ -1206,7 +1206,8 @@ def get_config_dtype_str(
use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False) -> Optional[str]:
use_int8_w8a8: Optional[bool] = False,
use_int4_w4a8: Optional[bool] = False) -> Optional[str]:
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a8:
......@@ -1215,7 +1216,7 @@ def get_config_dtype_str(
return "int8_w8a16"
elif use_int4_w4a16:
return "int4_w4a16"
elif use_int4_w4a16:
elif use_int4_w4a8:
return "int4_w4a8"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
......@@ -1961,6 +1962,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8=self.use_int4_w4a8,
dtype=hidden_states.dtype)
config = try_get_optimal_moe_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment