Unverified commit 108a2728, authored by danisereb and committed by GitHub
Browse files

Add get_expert_mapping to NemotronHModel (for LoRA support) (#31539)


Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
parent 578c8f51
...@@ -632,14 +632,7 @@ class NemotronHModel(nn.Module): ...@@ -632,14 +632,7 @@ class NemotronHModel(nn.Module):
hidden_states, _ = self.norm_f(hidden_states, residual) hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
if self.has_moe: if self.has_moe:
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
...@@ -653,8 +646,19 @@ class NemotronHModel(nn.Module): ...@@ -653,8 +646,19 @@ class NemotronHModel(nn.Module):
num_experts=self.config.n_routed_experts, num_experts=self.config.n_routed_experts,
num_redundant_experts=getattr(self, "num_redundant_experts", 0), num_redundant_experts=getattr(self, "num_redundant_experts", 0),
) )
else: return expert_params_mapping
expert_params_mapping = []
return []
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment