Unverified commit ef8ec07b authored by fzyzcjy, committed by GitHub

Support tuning moe for llama 4 model (#6042)

parent f24fc5b8
@@ -408,6 +408,12 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] == "Llama4ForConditionalGeneration":
+        n_share_fusion_experts = args.n_share_experts_fusion
+        E = config.text_config.num_local_experts + n_share_fusion_experts
+        topk = config.text_config.num_experts_per_tok
+        intermediate_size = config.text_config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in [
         "Grok1ForCausalLM",
         "Grok1ImgGen",
@@ -424,7 +430,7 @@ def main(args: argparse.Namespace):
         intermediate_size = config.intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    hidden_size = config.hidden_size
+    hidden_size = getattr(config, "hidden_size", None) or config.text_config.hidden_size
     dtype = config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a8 = args.dtype == "int8_w8a8"
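For context, the new branch exists because Llama 4 checkpoints expose a multimodal `Llama4ForConditionalGeneration` config whose language-model fields (expert count, top-k, intermediate size, hidden size) are nested under `text_config` rather than at the top level, which is also why `hidden_size` now falls back to `config.text_config.hidden_size`. Below is a minimal sketch, outside the patch, of how these MoE shapes can be read from a Hugging Face config; the model id, tensor-parallel size, and the shared-expert-fusion default are illustrative assumptions (the real script takes them from argparse), and loading the gated checkpoint's config requires Hugging Face access.

```python
# Minimal sketch (not part of the patch): derive the MoE tuning shapes from a
# Llama 4 config the same way the updated benchmark script does.
from transformers import AutoConfig

# Illustrative model id; any Llama 4 checkpoint with a Llama4ForConditionalGeneration
# architecture should expose the same nested text_config fields.
config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")

if config.architectures[0] == "Llama4ForConditionalGeneration":
    text_cfg = config.text_config  # language-model settings live under text_config

    n_share_experts_fusion = 0  # assumed default when --n-share-experts-fusion is not given
    tp_size = 8                 # illustrative tensor-parallel degree (args.tp_size in the script)

    E = text_cfg.num_local_experts + n_share_experts_fusion
    topk = text_cfg.num_experts_per_tok
    intermediate_size = text_cfg.intermediate_size
    shard_intermediate_size = 2 * intermediate_size // tp_size

    # The top-level multimodal config has no hidden_size, hence the fallback.
    hidden_size = getattr(config, "hidden_size", None) or text_cfg.hidden_size

    print(E, topk, shard_intermediate_size, hidden_size)
```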