Unverified Commit 521d0d99 authored by icyboy™'s avatar icyboy™ Committed by GitHub
Browse files

fix dbrx & opt model prefix bug (#2201)

* Update idefics_causal_lm.py

Fix syntax issues

* fix dbrx & opt model prefix bug
parent 05c094fc
...@@ -711,7 +711,7 @@ class FlashDbrxForCausalLM(torch.nn.Module): ...@@ -711,7 +711,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
else: else:
prefix = f"{prefix}.transformer" prefix = f"{prefix}.transformer"
self.model = DbrxModel(config, weights) self.model = DbrxModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
......
...@@ -757,7 +757,7 @@ class OPTForCausalLM(OPTPreTrainedModel): ...@@ -757,7 +757,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
else: else:
prefix = f"{prefix}.model" prefix = f"{prefix}.model"
self.model = OPTModel(config, weights) self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment