Unverified commit 28c94770, authored by roikoren755 and committed by GitHub
Browse files

[NemotronH] Use ReplicatedLinear for fc1_latent_proj (#31807)


Signed-off-by: Roi Koren <roik@nvidia.com>
parent af8fd730
...@@ -210,16 +210,12 @@ class NemotronHMoE(nn.Module): ...@@ -210,16 +210,12 @@ class NemotronHMoE(nn.Module):
) )
if self.use_latent_moe: if self.use_latent_moe:
# TODO: check if using ReplicatedLinear is better than self.fc1_latent_proj = ReplicatedLinear(
# ColumnParallelLinear + all_gather
self.fc1_latent_proj = ColumnParallelLinear(
input_size=config.hidden_size, input_size=config.hidden_size,
output_size=self.moe_hidden_size, output_size=self.moe_hidden_size,
bias=config.mlp_bias, bias=config.mlp_bias,
quant_config=quant_config, quant_config=quant_config,
disable_tp=self.is_sequence_parallel, disable_tp=self.is_sequence_parallel,
# We need to gather the output to prepare input for moe
gather_output=True,
prefix=f"{prefix}.fc1_latent_proj", prefix=f"{prefix}.fc1_latent_proj",
) )
self.fc2_latent_proj = ReplicatedLinear( self.fc2_latent_proj = ReplicatedLinear(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment