Unverified Commit 83fb2d09 authored by danielafrimi's avatar danielafrimi Committed by GitHub
Browse files

Support heterogeneous NemotronHPuzzle model (#32549)



Signed-off-by: <dafrimi@nvidia.com>
Signed-off-by: Daniel Afrimi <dafrimi@nvidia.com>
Signed-off-by: root <dafrimi@nvidia.com>
parent f3a5ee70
...@@ -405,6 +405,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -405,6 +405,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"NemotronHForCausalLM": _HfExamplesInfo( "NemotronHForCausalLM": _HfExamplesInfo(
"nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
), ),
"NemotronHPuzzleForCausalLM": _HfExamplesInfo(
"",
trust_remote_code=True,
is_available_online=False,
),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
"Olmo3ForCausalLM": _HfExamplesInfo("allenai/Olmo-3-7B-Instruct"), "Olmo3ForCausalLM": _HfExamplesInfo("allenai/Olmo-3-7B-Instruct"),
......
...@@ -603,4 +603,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { ...@@ -603,4 +603,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"FalconMambaForCausalLM": MambaModelConfig, "FalconMambaForCausalLM": MambaModelConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM, "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"NemotronHForCausalLM": NemotronHForCausalLMConfig, "NemotronHForCausalLM": NemotronHForCausalLMConfig,
"NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
} }
...@@ -354,8 +354,12 @@ class NemotronHMoEDecoderLayer(nn.Module): ...@@ -354,8 +354,12 @@ class NemotronHMoEDecoderLayer(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
# Get the per-layer config for heterogeneous models, if one exists
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
layer_config = get_layer_config(layer_idx) if get_layer_config else config
self.mixer = NemotronHMoE( self.mixer = NemotronHMoE(
config, layer_config,
quant_config=quant_config, quant_config=quant_config,
parallel_config=parallel_config, parallel_config=parallel_config,
prefix=f"{prefix}.mixer", prefix=f"{prefix}.mixer",
...@@ -479,6 +483,9 @@ class NemotronHAttention(nn.Module): ...@@ -479,6 +483,9 @@ class NemotronHAttention(nn.Module):
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
# Get per-layer sliding window from config (for heterogeneous models)
sliding_window = getattr(config, "sliding_window", None)
self.attn = Attention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
...@@ -487,6 +494,7 @@ class NemotronHAttention(nn.Module): ...@@ -487,6 +494,7 @@ class NemotronHAttention(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
per_layer_sliding_window=sliding_window,
) )
def forward( def forward(
...@@ -514,8 +522,12 @@ class NemotronHAttentionDecoderLayer(nn.Module): ...@@ -514,8 +522,12 @@ class NemotronHAttentionDecoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
# Get the per-layer config for heterogeneous models, if one exists
get_layer_config = getattr(config, "get_nemotron_h_config_for_layer", None)
layer_config = get_layer_config(layer_idx) if get_layer_config else config
self.mixer = NemotronHAttention( self.mixer = NemotronHAttention(
config, layer_config,
layer_idx, layer_idx,
model_config, model_config,
cache_config, cache_config,
...@@ -631,6 +643,34 @@ class NemotronHModel(nn.Module): ...@@ -631,6 +643,34 @@ class NemotronHModel(nn.Module):
hidden_states, _ = self.norm_f(hidden_states, residual) hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states return hidden_states
def _get_max_n_routed_experts(self) -> int:
    """Return the routed-expert count used when building the expert mapping.

    A top-level ``n_routed_experts`` on ``self.config`` is returned as-is
    when it is set (not ``None``). Otherwise the value is the maximum
    ``n_routed_experts`` over every entry of ``self.config.block_configs``
    whose ``block_type`` is ``"moe"``; taking the maximum means the layer
    with the most experts is still covered when layers differ.

    Returns:
        The expert count, or 0 when neither the top-level attribute nor any
        MoE block provides one.
    """
    n_routed_experts = getattr(self.config, "n_routed_experts", None)
    if n_routed_experts is not None:
        return n_routed_experts

    def _field(block, name, default):
        # Block entries may be plain dicts or attribute-style objects;
        # read both through one accessor.
        if isinstance(block, dict):
            return block.get(name, default)
        return getattr(block, name, default)

    block_configs = getattr(self.config, "block_configs", None) or ()
    counts = [
        # ``or 0`` guards against an explicit None, which would otherwise
        # make max() raise TypeError.
        _field(block, "n_routed_experts", 0) or 0
        for block in block_configs
        if _field(block, "block_type", None) == "moe"
    ]
    # Seed with 0 so an empty / MoE-free block list yields 0.
    return max([0, *counts])
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if self.has_moe: if self.has_moe:
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
...@@ -643,7 +683,7 @@ class NemotronHModel(nn.Module): ...@@ -643,7 +683,7 @@ class NemotronHModel(nn.Module):
ckpt_gate_proj_name="up_proj", ckpt_gate_proj_name="up_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="", ckpt_up_proj_name="",
num_experts=self.config.n_routed_experts, num_experts=self._get_max_n_routed_experts(),
num_redundant_experts=getattr(self, "num_redundant_experts", 0), num_redundant_experts=getattr(self, "num_redundant_experts", 0),
) )
return expert_params_mapping return expert_params_mapping
......
...@@ -163,6 +163,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -163,6 +163,7 @@ _TEXT_GENERATION_MODELS = {
"MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"), "MiMoV2FlashForCausalLM": ("mimo_v2_flash", "MiMoV2FlashForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"), "NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
"NemotronHPuzzleForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
......
...@@ -81,6 +81,26 @@ class ModelArchConfigConvertorBase: ...@@ -81,6 +81,26 @@ class ModelArchConfigConvertorBase:
self.hf_text_config, attributes, default_factory=default_factory self.hf_text_config, attributes, default_factory=default_factory
) )
def get_num_experts_from_block_configs(self) -> int:
    """Return the largest per-block routed-expert count in ``block_configs``.

    Scans ``self.hf_text_config.block_configs`` (if present and non-empty)
    and considers only entries whose ``block_type`` is ``"moe"``. Entries
    may be plain dicts or objects exposing the same fields as attributes.
    The maximum is used so that, when layers differ in expert count, the
    largest layer is still covered.

    Returns:
        The maximum ``n_routed_experts`` among MoE blocks, or 0 when there
        are no block configs, no MoE blocks, or no usable counts.
    """
    max_experts = 0
    block_configs = getattr(self.hf_text_config, "block_configs", None)
    if not block_configs:
        return max_experts

    for block in block_configs:
        # Pick the accessor once per entry instead of duplicating the
        # dict / attribute logic in two branches.
        if isinstance(block, dict):
            read = block.get
        else:
            read = lambda name, default, _b=block: getattr(_b, name, default)

        if read("block_type", "") != "moe":
            continue
        # An explicit None would make max() raise TypeError; treat it as 0.
        max_experts = max(max_experts, read("n_routed_experts", 0) or 0)
    return max_experts
def get_num_experts(self) -> int: def get_num_experts(self) -> int:
"""Returns the number of experts in the model.""" """Returns the number of experts in the model."""
num_expert_names = [ num_expert_names = [
...@@ -89,13 +109,16 @@ class ModelArchConfigConvertorBase: ...@@ -89,13 +109,16 @@ class ModelArchConfigConvertorBase:
"n_routed_experts", # DeepSeek "n_routed_experts", # DeepSeek
"num_local_experts", # Mixtral "num_local_experts", # Mixtral
] ]
num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0) num_experts = getattr_iter(self.hf_text_config, num_expert_names, 0)
if isinstance(num_experts, list): if isinstance(num_experts, list):
# Ernie VL's remote code uses list[int]... # Ernie VL's remote code uses list[int]...
# The values are always the same so we just take the first one. # The values are always the same so we just take the first one.
return num_experts[0] return num_experts[0]
# Coerce to 0 if explicitly set to None
return num_experts or 0 if not num_experts:
num_experts = self.get_num_experts_from_block_configs()
return num_experts
@final @final
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment