Unverified commit 1ed1abfd authored by Chen1022, committed by GitHub

feat: add EP support in tuning (#12012)

parent ecb9fa14
@@ -421,69 +421,88 @@ def main(args: argparse.Namespace):
 
     config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
     if config.architectures[0] == "DbrxForCausalLM":
-        E = config.ffn_config.moe_num_experts
+        E = config.ffn_config.moe_num_experts // args.ep_size
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] == "JambaForCausalLM":
-        E = config.num_experts
+        E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in [
         "Qwen2MoeForCausalLM",
         "Qwen3MoeForCausalLM",
         "Qwen3NextForCausalLM",
     ]:
-        E = config.num_experts
+        E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = (
-            config.n_routed_experts + (0 if args.disable_shared_experts_fusion else 1)
-            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
-            else config.n_routed_experts
+        E = (config.n_routed_experts // args.ep_size) + (
+            0
+            if args.disable_shared_experts_fusion
+            or config.architectures[0] not in ["DeepseekV3ForCausalLM"]
+            else 1
         )
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] == "Llama4ForConditionalGeneration":
-        E = config.text_config.num_local_experts + (
+        E = config.text_config.num_local_experts // args.ep_size + (
             0 if args.disable_shared_experts_fusion else 1
         )
         topk = config.text_config.num_experts_per_tok
         intermediate_size = config.text_config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
         "Grok1AForCausalLM",
     ]:
-        E = config.num_local_experts
+        E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
    elif config.architectures[0] in [
         "BailingMoEForCausalLM",
         "BailingMoeForCausalLM",
         "BailingMoeV2ForCausalLM",
     ]:
-        E = config.num_experts
+        E = config.num_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     elif config.architectures[0] in ["Glm4MoeForCausalLM"]:
-        E = config.n_routed_experts
+        E = config.n_routed_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     else:
         # Default: Mixtral
-        E = config.num_local_experts
+        E = config.num_local_experts // args.ep_size
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        shard_intermediate_size = (
+            2 * intermediate_size // (args.tp_size // args.ep_size)
+        )
     hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
     dtype = config.torch_dtype
 
@@ -626,6 +645,7 @@ if __name__ == "__main__":
         "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
     )
     parser.add_argument("--tp-size", "--tp", type=int, default=2)
+    parser.add_argument("--ep-size", "--ep", type=int, default=1)
     parser.add_argument(
         "--dtype",
         type=str,
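Note on the sharding math: --ep-size splits the expert count across ranks, while each expert's FFN width is still sharded by the residual tensor-parallel degree tp_size // ep_size. The standalone sketch below is not part of the commit; the helper name moe_shard_shapes and the example numbers are illustrative only, and the factor of 2 is carried over from the existing script, where the tuned width appears to cover the fused gate/up projection.

# A standalone sketch (not part of the commit) of the EP/TP sharding math.
# The helper name and the example numbers are illustrative only.

def moe_shard_shapes(
    num_experts: int, intermediate_size: int, tp_size: int, ep_size: int
) -> tuple[int, int]:
    """Return (experts_per_rank, shard_intermediate_size) as the script computes them."""
    # ep_size must evenly divide both tp_size and the expert count for the
    # integer divisions below to be exact.
    assert tp_size % ep_size == 0 and num_experts % ep_size == 0
    experts_per_rank = num_experts // ep_size
    # Factor of 2 carried over from the original script (fused gate/up width).
    shard_intermediate_size = 2 * intermediate_size // (tp_size // ep_size)
    return experts_per_rank, shard_intermediate_size

# Mixtral-8x7B-style values: 8 experts, intermediate_size = 14336.
print(moe_shard_shapes(8, 14336, tp_size=2, ep_size=1))  # (8, 14336)
print(moe_shard_shapes(8, 14336, tp_size=2, ep_size=2))  # (4, 28672)

With the new flag, a run such as --tp-size 2 --ep-size 2 therefore tunes 4 experts per rank at twice the per-rank width, rather than 8 experts at the narrower width.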