Unverified commit c7867b67, authored by Mohammad Miadh Angkad and committed by GitHub
Browse files

[Fix] Add per_channel_quant parameter to MoE config functions (#11201)

parent 516738b0
......@@ -16,14 +16,19 @@ _is_hip = is_hip()
def get_config_file_name(
E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[int] = None,
per_channel_quant: bool = False,
) -> str:
device_name = get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = (
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
)
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json"
@functools.lru_cache
......@@ -33,6 +38,7 @@ def get_moe_configs(
dtype: Optional[str],
block_n: Optional[int] = 0,
block_k: Optional[int] = 0,
per_channel_quant: bool = False,
) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
......@@ -47,7 +53,9 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])
json_file_name = get_config_file_name(
E, N, dtype, [block_n, block_k], per_channel_quant
)
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment