Commit 598bd74c (unverified), authored by nathan, committed by GitHub
Browse files

Fix weights loading for Apertus (#24100)


Signed-off-by: Nathan Ranchin <nranchin@student.ethz.ch>
parent 24177984
...@@ -415,6 +415,12 @@ class ApertusModel(nn.Module): ...@@ -415,6 +415,12 @@ class ApertusModel(nn.Module):
(".qkv_proj", ".v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
# we need to load the buffers for beta and eps (XIELU)
for name, buffer in self.named_buffers():
if name.endswith(".beta") or name.endswith(".eps"):
params_dict[name] = buffer
loaded_params: set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment