Unverified commit 77f8001f, authored by tomeras91, committed by GitHub
Browse files

[Model][Bugfix] fix pipeline parallelism support for NemotronH (#27968)


Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
parent 300a2659
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
import typing import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice
import torch import torch
from torch import nn from torch import nn
...@@ -549,7 +550,7 @@ class NemotronHModel(nn.Module): ...@@ -549,7 +550,7 @@ class NemotronHModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers" len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
) )
self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory( self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size ["hidden_states", "residual"], config.hidden_size
) )
...@@ -564,7 +565,7 @@ class NemotronHModel(nn.Module): ...@@ -564,7 +565,7 @@ class NemotronHModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -576,8 +577,7 @@ class NemotronHModel(nn.Module): ...@@ -576,8 +577,7 @@ class NemotronHModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
residual = None for layer in islice(self.layers, self.start_layer, self.end_layer):
for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -633,6 +633,9 @@ class NemotronHModel(nn.Module): ...@@ -633,6 +633,9 @@ class NemotronHModel(nn.Module):
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -678,6 +681,9 @@ class NemotronHModel(nn.Module): ...@@ -678,6 +681,9 @@ class NemotronHModel(nn.Module):
if is_expert_weight: if is_expert_weight:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr( weight_loader = getattr(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
...@@ -792,7 +798,9 @@ class NemotronHForCausalLM( ...@@ -792,7 +798,9 @@ class NemotronHForCausalLM(
self.unpadded_vocab_size, config.vocab_size self.unpadded_vocab_size, config.vocab_size
) )
self.make_empty_intmd_tensors = self.model.make_empty_intmd_tensors self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
# Set MoE hyperparameters # Set MoE hyperparameters
if self.model.has_moe: if self.model.has_moe:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment