Unverified Commit 3e073e66 authored by sohamparikh's avatar sohamparikh Committed by GitHub
Browse files

[Bugfix] load fc bias from config for eagle (#8790)

parent c2395367
......@@ -44,7 +44,7 @@ class EAGLE(nn.Module):
self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=False)
bias=getattr(self.config, "bias", False))
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
......@@ -136,10 +136,18 @@ class EAGLE(nn.Module):
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name.startswith("fc."):
elif name.startswith("fc.weight"):
weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("fc.bias"):
if self.fc.bias is not None:
weight_loader = getattr(self.fc.bias, "weight_loader",
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
raise ValueError("Found bias in the loaded weights "
"but the model config doesn't have bias")
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment