Unverified commit 343041c4, authored by Sky Lee, committed by GitHub
Browse files

[model] Reduce medusa weight (#10454)


Signed-off-by: skylee-01 <497627264@qq.com>
parent ed701ca9
......@@ -61,6 +61,17 @@ class Medusa(nn.Module):
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
if getattr(config, "original_lm_head", False):
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.lm_heads = [
self.lm_head for _ in range(self.config.num_heads)
]
else:
self.lm_heads = nn.ModuleList([
ParallelLMHead(
self.unpadded_vocab_size,
......@@ -172,6 +183,9 @@ class Medusa(nn.Module):
requires_grad=False)
elif name in params_dict:
weights_map[name] = loaded_weight
elif (getattr(self.config, "original_lm_head", False)
and name == "lm_heads.0.weight"):
weights_map["lm_head.weight"] = loaded_weight
for name, loaded_weight in weights_map.items():
if "lm_head" in name and self.token_map is not None and\
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment