Unverified Commit 28c94770 authored by roikoren755's avatar roikoren755 Committed by GitHub
Browse files

[NemotronH] Use ReplicatedLinear for fc1_latent_proj (#31807)


Signed-off-by: default avatarRoi Koren <roik@nvidia.com>
parent af8fd730
......@@ -210,16 +210,12 @@ class NemotronHMoE(nn.Module):
)
if self.use_latent_moe:
# TODO: check if using ReplicatedLinear is better than
# ColumnParallelLinear + all_gather
self.fc1_latent_proj = ColumnParallelLinear(
self.fc1_latent_proj = ReplicatedLinear(
input_size=config.hidden_size,
output_size=self.moe_hidden_size,
bias=config.mlp_bias,
quant_config=quant_config,
disable_tp=self.is_sequence_parallel,
# We need to gather the output to prepare input for moe
gather_output=True,
prefix=f"{prefix}.fc1_latent_proj",
)
self.fc2_latent_proj = ReplicatedLinear(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment