"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "7ad7adb67f1350b6e9f7cfdd7aacf38eed093bb1"
Unverified commit 1c1bb0bb, authored by Divakar Verma, committed by GitHub
Browse files

[Misc][MoE] add Deepseek-V3 moe tuning support (#12558)


Signed-off-by: Divakar Verma <divakar.verma@amd.com>
parent e0cc5f25
...@@ -450,7 +450,8 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, ...@@ -450,7 +450,8 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
config = AutoConfig.from_pretrained(args.model) config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM": if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
...@@ -461,6 +462,11 @@ def main(args: argparse.Namespace): ...@@ -461,6 +462,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "DeepseekV3ForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else: else:
# Default: Mixtral. # Default: Mixtral.
E = config.num_local_experts E = config.num_local_experts
...@@ -538,6 +544,7 @@ if __name__ == "__main__": ...@@ -538,6 +544,7 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment