"""Helpers for tuning fused MoE Triton kernel configs in SGLang."""

import json
from typing import Dict, List, Optional, TypedDict

import torch
from transformers import AutoConfig

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import get_config_dtype_str
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
    get_config_file_name,
)
from sglang.srt.utils import is_hip


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def calculate_shard_intermediate_size(
    intermediate_size: int, tp_size: int, ep_size: int = 1
) -> int:
    # The factor of 2 accounts for the fused gate and up projections (w13).
    assert tp_size % ep_size == 0
    moe_tp_size = tp_size // ep_size
    assert intermediate_size % moe_tp_size == 0
    return 2 * intermediate_size // moe_tp_size


def get_model_config(
    model_name: str,
    tp_size: int,
    ep_size: int = 1,
    disable_shared_experts_fusion: bool = False,
    topk_ids_dir: Optional[str] = None,
) -> Dict:
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    block_shape = None
    if (
        hasattr(config, "quantization_config")
        and "weight_block_size" in config.quantization_config
    ):
        block_shape = config.quantization_config["weight_block_size"]
        assert len(block_shape) == 2
    architecture = config.architectures[0]
    # Replace config with text_config for multimodal models (e.g. Llama4,
    # Qwen3-VL) after getting block_shape and architecture.
    if hasattr(config, "text_config"):
        config = config.get_text_config()
    if architecture == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts // ep_size
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif architecture == "JambaForCausalLM":
        E = config.num_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif architecture in [
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
        "Qwen3VLMoeForConditionalGeneration",
    ]:
        E = config.num_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif architecture in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
        E = (config.n_routed_experts // ep_size) + (
            0
            if disable_shared_experts_fusion
            or architecture not in ["DeepseekV3ForCausalLM"]
            else 1
        )
        topk = config.num_experts_per_tok + (
            0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
        )
        intermediate_size = config.moe_intermediate_size
    elif architecture == "Llama4ForConditionalGeneration":
        E = config.num_local_experts // ep_size + (
            0 if disable_shared_experts_fusion else 1
        )
        topk = config.num_experts_per_tok + (
            0 if disable_shared_experts_fusion or topk_ids_dir is None else 1
        )
        intermediate_size = config.intermediate_size
    elif architecture in [
        "Grok1ForCausalLM",
        "Grok1ImgGen",
        "Grok1AForCausalLM",
    ]:
        E = config.num_local_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif architecture in [
        "BailingMoEForCausalLM",
        "BailingMoeForCausalLM",
        "BailingMoeV2ForCausalLM",
    ]:
        E = config.num_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif architecture in ["Glm4MoeForCausalLM"]:
        E = config.n_routed_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts // ep_size
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size

    shard_intermediate_size = calculate_shard_intermediate_size(
        intermediate_size, tp_size, ep_size
    )
    return {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": block_shape,
        "architecture": architecture,
    }


def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
    # Tuning search space for ROCm (HIP) devices.
    configs: List[BenchmarkConfig] = []
    waves_per_eu_range = 0
    for num_stages in [2]:
        for block_m in [32, 64, 128, 256]:
            for block_k in [32, 64, 128, 256]:
                for block_n in [16, 32, 64, 128, 256]:
                    for num_warps in [1, 2, 4, 8]:
                        for group_size in [1, 4, 8, 16, 32]:
                            configs.append(
                                {
                                    "BLOCK_SIZE_M": block_m,
                                    "BLOCK_SIZE_N": block_n,
                                    "BLOCK_SIZE_K": block_k,
                                    "GROUP_SIZE_M": group_size,
                                    "num_warps": num_warps,
                                    "num_stages": num_stages,
                                    "waves_per_eu": waves_per_eu_range,
                                }
                            )
    return configs


def get_configs_compute_bound() -> List[Dict[str, int]]:
    # Tuning search space; delegates to the ROCm-specific space on HIP devices.
    configs: List[BenchmarkConfig] = []
    if is_hip():
        configs = get_rocm_configs_compute_bound()
    else:
        for num_stages in [2, 3, 4, 5]:
            for block_m in [16, 32, 64, 128, 256]:
                for block_k in [64, 128, 256]:
                    for block_n in [32, 64, 128, 256]:
                        for num_warps in [4, 8]:
                            for group_size in [1, 16, 32, 64]:
                                configs.append(
                                    {
                                        "BLOCK_SIZE_M": block_m,
                                        "BLOCK_SIZE_N": block_n,
                                        "BLOCK_SIZE_K": block_k,
                                        "GROUP_SIZE_M": group_size,
                                        "num_warps": num_warps,
                                        "num_stages": num_stages,
                                    }
                                )
    return configs


def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    # Return the config with its keys in a canonical order, preserving optional
    # keys ("waves_per_eu", "USE_TMA") only when present.
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
        **(
            {"waves_per_eu": config["waves_per_eu"]}
            if "waves_per_eu" in config
            else {}
        ),
        **({"USE_TMA": config["USE_TMA"]} if "USE_TMA" in config else {}),
    }


def save_configs(
    configs: Dict[int, BenchmarkConfig],
    filename: str,
) -> None:
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def get_config_filename(
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    per_channel_quant: bool,
    block_shape: List[int],
) -> str:
    dtype_str = get_config_dtype_str(
        dtype,
        use_int8_w8a16=use_int8_w8a16,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
    )
    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts,
        shard_intermediate_size // 2,
        dtype_str,
        block_shape,
        per_channel_quant,
    )
    return filename


def get_default_batch_sizes() -> List[int]:
    return [
        1,
        2,
        4,
        8,
        16,
        24,
        32,
        48,
        64,
        96,
        128,
        256,
        512,
        1024,
        1536,
        2048,
        3072,
        4096,
    ]
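

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original benchmark flow; the numbers
    # below are hypothetical and only illustrate how the helpers compose.
    # A Mixtral-style intermediate size of 14336 sharded over tp_size=8 with
    # ep_size=1 gives 2 * 14336 // 8 = 3584 (gate and up projections fused).
    print(calculate_shard_intermediate_size(14336, tp_size=8))
    # The size of the tuning search space depends on the platform (CUDA vs. HIP).
    print(f"{len(get_configs_compute_bound())} candidate configs")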