"docs/vscode:/vscode.git/clone" did not exist on "f0ef37233ea0ba5251edaea7362984110411e7eb"
Unverified Commit 108a2728 authored by danisereb's avatar danisereb Committed by GitHub
Browse files

Add get_expert_mapping to NemotronHModel (for LoRA support) (#31539)


Signed-off-by: default avatarDaniel Serebrenik <daserebrenik@nvidia.com>
parent 578c8f51
......@@ -632,14 +632,7 @@ class NemotronHModel(nn.Module):
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if self.has_moe:
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
......@@ -653,8 +646,19 @@ class NemotronHModel(nn.Module):
num_experts=self.config.n_routed_experts,
num_redundant_experts=getattr(self, "num_redundant_experts", 0),
)
else:
expert_params_mapping = []
return expert_params_mapping
return []
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment