Unverified Commit e677479c authored by Arthur, committed by GitHub

[`Mamba`] from pretrained issue with `self.embeddings` (#29851)



* nit

* update

* oups

* Update src/transformers/models/mamba/modeling_mamba.py
Co-authored-by: Lysandre Debut <hi@lysand.re>

---------
Co-authored-by: Lysandre Debut <hi@lysand.re>
parent 441de62f
@@ -501,8 +501,15 @@ class MambaModel(MambaPreTrainedModel):
         self.gradient_checkpointing = False
         self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         # Initialize weights and apply final processing
+        self._register_load_state_dict_pre_hook(self.load_hook)
         self.post_init()
 
+    def load_hook(self, state_dict, prefix, *args):
+        for k in state_dict:
+            if "embedding." in k:
+                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
+                break
+
     def get_input_embeddings(self):
         return self.embeddings
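
Why the hook is needed: MambaModel stores its word embeddings under self.embeddings, but checkpoints saved under the earlier singular attribute name self.embedding carry "embedding."-prefixed keys, so from_pretrained failed to match them against the current module. The pre-hook rewrites those legacy keys before load_state_dict assigns the weights.

Below is a minimal, self-contained sketch of the same pattern, not the transformers source; ToyModel and its sizes are hypothetical. The merged code can break after the first match because Mamba has a single embedding parameter; the sketch instead renames every match and copies the key list so the dict can be popped from while iterating:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)  # attribute is now plural
        # torch calls the hook with (state_dict, prefix, ...) just before
        # load_state_dict processes this module's keys
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # list() so keys can be popped while iterating
        for k in list(state_dict):
            if "embedding." in k:  # does not match keys already named "embeddings."
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)

# A checkpoint written under the old singular name still loads cleanly:
legacy = {"embedding.weight": torch.zeros(10, 4)}
model = ToyModel()
model.load_state_dict(legacy)
print(model.embeddings.weight.shape)  # torch.Size([10, 4])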