Unverified Commit c7867b67 authored by Mohammad Miadh Angkad's avatar Mohammad Miadh Angkad Committed by GitHub
Browse files

[Fix] Add per_channel_quant parameter to MoE config functions (#11201)

parent 516738b0
...@@ -16,14 +16,19 @@ _is_hip = is_hip() ...@@ -16,14 +16,19 @@ _is_hip = is_hip()
def get_config_file_name( def get_config_file_name(
E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[int] = None,
per_channel_quant: bool = False,
) -> str: ) -> str:
device_name = get_device_name().replace(" ", "_") device_name = get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}" dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ( block_shape_selector = (
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
) )
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json"
@functools.lru_cache @functools.lru_cache
...@@ -33,6 +38,7 @@ def get_moe_configs( ...@@ -33,6 +38,7 @@ def get_moe_configs(
dtype: Optional[str], dtype: Optional[str],
block_n: Optional[int] = 0, block_n: Optional[int] = 0,
block_k: Optional[int] = 0, block_k: Optional[int] = 0,
per_channel_quant: bool = False,
) -> Optional[Dict[int, Any]]: ) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -47,7 +53,9 @@ def get_moe_configs( ...@@ -47,7 +53,9 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
# directory # directory
json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k]) json_file_name = get_config_file_name(
E, N, dtype, [block_n, block_k], per_channel_quant
)
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains, # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance. # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment