Unverified commit a979daac, authored by Baizhou Zhang, committed by GitHub
Browse files

Fallback to lower triton version for unfound fused moe configs (#7013)

parent f1569876
...@@ -983,6 +983,8 @@ def get_moe_configs( ...@@ -983,6 +983,8 @@ def get_moe_configs(
kernel on a given batch size bs, the closest batch size in the grid should kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel. be picked and the associated configuration chosen to invoke the kernel.
""" """
# Supported Triton versions; must be sorted from newest to oldest, since the fallback search tries them in this order
supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"]
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
# directory # directory
...@@ -1005,12 +1007,28 @@ def get_moe_configs( ...@@ -1005,12 +1007,28 @@ def get_moe_configs(
# For example, updating the Triton version might cause all old configs to become suboptimal. # For example, updating the Triton version might cause all old configs to become suboptimal.
# To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment. # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment.
# For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
log_info_on_rank0( logger.info(f"Using MoE kernel config from {config_file_path}.")
logger, f"Using MoE kernel config from {config_file_path}."
)
# If a configuration has been found, return it # If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()} return {int(key): val for key, val in json.load(f).items()}
# Search for other Triton versions that support the same config
for try_triton_version in supported_triton_versions:
if try_triton_version == triton_version:
continue
try_config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"configs",
f"triton_{try_triton_version.replace('.', '_')}",
json_file_name,
)
if os.path.exists(try_config_file_path):
with open(try_config_file_path) as f:
logger.warning(
f"Config file not found at {config_file_path}. Fallback to triton version {try_triton_version} and use MoE kernel config from {try_config_file_path}. Performance might be sub-optimal!",
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default # If no optimized configuration is available, we will use the default
# configuration # configuration
logger.warning( logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment