Unverified Commit d9dd5298 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

enable DeepSeek V3 shared_experts_fusion in sm90 (#5571)

parent 0a0dd34e
......@@ -1427,6 +1427,18 @@ class DeepseekV2ForCausalLM(nn.Module):
assert (
self.n_share_experts_fusion == self.tp_size
), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performace."
elif self.n_share_experts_fusion == 0:
if (
torch.cuda.get_device_capability("cuda") >= (9, 0)
and self.config.architectures[0] == "DeepseekV3ForCausalLM"
and self.config.n_routed_experts == 256
and (not global_server_args_dict["enable_deepep_moe"])
):
self.n_share_experts_fusion = self.tp_size
global_server_args_dict["n_share_experts_fusion"] = self.tp_size
logger.info(
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
)
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment