Unverified commit a379d553, authored by drbh, committed by GitHub
Browse files

Fix the prefix for OPT model in opt_modelling.py #2370 (CI RUN) (#2371)



* Fix the bug

* fix: run lints

* fix: small syntax tweak

---------
Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
parent 21267f3c
import pytest
@pytest.fixture(scope="module")
def opt_sharded_handle(launcher):
    """Yield a launcher handle for ``facebook/opt-6.7b`` split over two shards.

    The fixture is module-scoped, so a single launch is shared by every test
    in this module. The handle is yielded from inside the ``with`` block, so
    the launcher's context-manager exit runs when the fixture is finalized.
    """
    model_id = "facebook/opt-6.7b"
    with launcher(model_id, num_shard=2) as server:
        yield server
@pytest.fixture(scope="module")
async def opt_sharded(opt_sharded_handle):
    """Return the client of the sharded OPT launch once its health check passes.

    The health check is awaited before the client is read, so a launch that
    never becomes healthy fails here rather than inside a test body.
    """
    server = opt_sharded_handle
    # NOTE(review): 300 is presumably a timeout in seconds -- confirm against
    # the handle's ``health`` signature.
    await server.health(300)
    return server.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_opt(opt_sharded):
    """Check that the sharded OPT launch comes up; no generation is asserted.

    The body is intentionally empty: requesting the ``opt_sharded`` fixture is
    what triggers the launch and the awaited health check, so the test fails
    if either of those fails.
    """
......@@ -98,7 +98,9 @@ class OPTLearnedPositionalEmbedding(nn.Module):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
weights.get_tensor(
f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
)
)
def forward(
......@@ -315,7 +317,7 @@ class OPTDecoderLayer(nn.Module):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"{prefix}.decoder.layers.{layer_id}"
prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
......@@ -437,15 +439,17 @@ class OPTDecoder(OPTPreTrainedModel):
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
prefix = prefix + "." if prefix else ""
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.decoder.embed_tokens", weights=weights
prefix=f"{prefix}decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config,
prefix=f"{prefix}.decoder.project_out",
prefix=f"{prefix}decoder.project_out",
weights=weights,
bias=False,
)
......@@ -455,7 +459,7 @@ class OPTDecoder(OPTPreTrainedModel):
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config,
prefix=f"{prefix}.decoder.project_in",
prefix=f"{prefix}decoder.project_in",
weights=weights,
bias=False,
)
......@@ -467,7 +471,7 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
......@@ -752,15 +756,12 @@ class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, prefix, config, weights):
super().__init__(config)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
config,
prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
weights=weights,
)
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment