Unverified Commit 4b49c50f authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Support tied embeddings in 0.5B and 1.5B Qwen2 models (#2313)

parent 3905f854
@@ -262,6 +262,9 @@ class Qwen2Layer(nn.Module):
 class Qwen2Model(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
+        prefix = f"{prefix}.model" if prefix else "model"
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
@@ -335,15 +338,16 @@ class Qwen2ForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
-        if not prefix:
-            prefix = "model"
+        self.model = Qwen2Model(prefix, config, weights)
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
         else:
-            prefix = f"{prefix}.model"
-        self.model = Qwen2Model(prefix, config, weights)
+            suffix = "lm_head"
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
         self.max_past = config.sliding_window
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment