Unverified commit 82cfcd3b, authored by Xinyuan Tong and committed by GitHub

[Refactor] tuning_fused_moe for MLLM and small refactor (#11224)


Co-authored-by: Cursor Agent <cursoragent@cursor.com>
parent 6c1a3f0c
```diff
@@ -419,54 +419,73 @@ def get_filename(
 def main(args: argparse.Namespace):
     print(args)
 
+    def _calculate_shard_intermediate_size(intermediate_size: int) -> int:
+        # In EP mode, use original intermediate_size; otherwise apply TP sharding
+        return (
+            intermediate_size
+            if args.ep_size > 1
+            else 2 * intermediate_size // args.tp_size
+        )
+
+    # Check EP mode constraint: tp_size must be 1 when ep_size > 1
+    if args.ep_size > 1 and args.tp_size != 1:
+        raise ValueError(
+            f"When using Expert Parallelism (ep_size={args.ep_size}), "
+            f"tp_size must be set to 1, but got tp_size={args.tp_size}. "
+            f"Please set --tp-size 1 when using --ep-size > 1."
+        )
+
     config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
-    if config.architectures[0] == "DbrxForCausalLM":
-        E = config.ffn_config.moe_num_experts // args.ep_size
+
+    # Determine block shape for quantization
+    block_shape = None
+    if (
+        hasattr(config, "quantization_config")
+        and "weight_block_size" in config.quantization_config
+    ):
+        block_shape = config.quantization_config["weight_block_size"]
+        assert len(block_shape) == 2
+
+    architecture = config.architectures[0]
+    # Replace config with text_config for multimodal models after getting block_shape and architecture
+    if hasattr(config, "text_config"):
+        config = config.get_text_config()
+
+    if architecture == "DbrxForCausalLM":
+        E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] == "JambaForCausalLM":
-        E = config.num_experts // args.ep_size
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture == "JambaForCausalLM":
+        E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in [
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in [
         "Qwen2MoeForCausalLM",
         "Qwen3MoeForCausalLM",
         "Qwen3NextForCausalLM",
+        "Qwen3VLMoeForConditionalGeneration",
     ]:
         E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = (config.n_routed_experts // args.ep_size) + (
-            0
-            if args.disable_shared_experts_fusion
-            or config.architectures[0] not in ["DeepseekV3ForCausalLM"]
-            else 1
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
+        E = (
+            config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
+            if architecture == "DeepseekV3ForCausalLM"
+            else config.n_routed_experts
         )
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] == "Llama4ForConditionalGeneration":
-        E = config.text_config.num_local_experts // args.ep_size + (
-            0 if args.disable_shared_experts_fusion else 1
-        )
-        topk = config.text_config.num_experts_per_tok
-        intermediate_size = config.text_config.intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in [
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture == "Llama4ForConditionalGeneration":
+        E = config.num_local_experts + (0 if args.disable_shared_experts_fusion else 1)
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
         "Grok1AForCausalLM",
@@ -474,10 +493,8 @@ def main(args: argparse.Namespace):
         E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in [
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in [
         "BailingMoEForCausalLM",
         "BailingMoeForCausalLM",
         "BailingMoeV2ForCausalLM",
@@ -485,38 +502,25 @@ def main(args: argparse.Namespace):
         E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
-    elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
-        E = config.n_routed_experts // args.ep_size
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
+    elif architecture in ["Glm4MoeForCausalLM"]:
+        E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
     else:
         # Default: Mixtral
         E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = (
-            2 * intermediate_size // (args.tp_size // args.ep_size)
-        )
+        shard_intermediate_size = _calculate_shard_intermediate_size(intermediate_size)
 
-    hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
+    hidden_size = config.hidden_size
     dtype = config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
     per_channel_quant = args.per_channel_quant
-    block_shape = None
-    if (
-        hasattr(config, "quantization_config")
-        and "weight_block_size" in config.quantization_config
-    ):
-        block_shape = config.quantization_config["weight_block_size"]
-        assert len(block_shape) == 2
+
     if args.batch_size is None:
         batch_sizes = [
```
...
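For reference, a minimal standalone sketch of the sharding rule that the new `_calculate_shard_intermediate_size` helper centralizes, with the EP/TP constraint check inlined. The `Args` dataclass and the example sizes are illustrative stand-ins, not part of this commit:

```python
from dataclasses import dataclass


@dataclass
class Args:
    # Hypothetical stand-in for the script's argparse namespace.
    tp_size: int = 1
    ep_size: int = 1


def calculate_shard_intermediate_size(args: Args, intermediate_size: int) -> int:
    """Mirror of the refactored helper, plus the ep_size/tp_size constraint."""
    if args.ep_size > 1 and args.tp_size != 1:
        raise ValueError("tp_size must be 1 when ep_size > 1")
    # EP keeps each expert whole; TP splits the fused gate/up projections
    # (2 * intermediate_size) across tp_size ranks.
    return (
        intermediate_size
        if args.ep_size > 1
        else 2 * intermediate_size // args.tp_size
    )


# TP-only sharding: 2 * 1408 // 4 == 704 per rank (1408 is an illustrative MoE size).
print(calculate_shard_intermediate_size(Args(tp_size=4), 1408))
# EP mode: experts are distributed whole, so the full 1408 is used.
print(calculate_shard_intermediate_size(Args(tp_size=1, ep_size=8), 1408))
```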
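Similarly, a hedged sketch of the MLLM config handling this refactor adds: the architecture string and any `weight_block_size` are read from the top-level composite config first, and only then is `config` swapped for its text sub-config, where the MoE fields live on multimodal models. The checkpoint name below is a placeholder, and `get_text_config()` is assumed to be available in the installed transformers version:

```python
from transformers import AutoConfig

# Placeholder checkpoint; any MoE model whose config carries a text_config works here.
model_id = "some-org/some-vl-moe-model"

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

# Architecture and quantization block shape come from the top-level (composite) config.
architecture = config.architectures[0]
block_shape = None
if (
    hasattr(config, "quantization_config")
    and "weight_block_size" in config.quantization_config
):
    block_shape = config.quantization_config["weight_block_size"]

# MoE hyperparameters (num_experts, moe_intermediate_size, ...) live on the text sub-config.
if hasattr(config, "text_config"):
    config = config.get_text_config()

print(architecture, getattr(config, "num_experts", None), block_shape)
```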