Unverified Commit 05c094fc authored by Daniël de Kok, committed by GitHub

Consistently take `prefix` in model constructors (#2191)

* Consistently take `prefix` in model constructors

* Release test check fix

* Misc refactor-related fixes
parent 67ef0649
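
The refactor follows one pattern throughout: model constructors no longer hard-code the weight-name root ("model" for OPT, "transformer" for Phi) but accept a `prefix` and thread it down to every child module, so the same code can load a standalone checkpoint or be embedded under an outer prefix. A minimal, self-contained sketch of the mechanics; `DummyWeights`, `ToyBlock` and `ToyForCausalLM` are made-up names for illustration, not TGI code:

```python
from types import SimpleNamespace

import torch
from torch import nn


class DummyWeights:
    """Stand-in for a weights loader that resolves tensors by dotted name."""

    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name):
        return self.tensors[name]


class ToyBlock(nn.Module):
    # A child module only appends its own suffix to the prefix it is given.
    def __init__(self, prefix: str, weights):
        super().__init__()
        self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.linear.weight"))


class ToyForCausalLM(nn.Module):
    # The top-level class normalizes the prefix the same way the diff does:
    # an empty prefix becomes the historical root, otherwise "<outer>.<root>".
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        prefix = "model" if not prefix else f"{prefix}.model"
        self.blocks = nn.ModuleList(
            [ToyBlock(f"{prefix}.layers.{i}", weights) for i in range(config.n_layer)]
        )


config = SimpleNamespace(n_layer=1)
standalone = DummyWeights({"model.layers.0.linear.weight": torch.zeros(2, 2)})
nested = DummyWeights({"vlm.model.layers.0.linear.weight": torch.zeros(2, 2)})

ToyForCausalLM("", config, standalone)  # text-only checkpoint, old weight names
ToyForCausalLM("vlm", config, nested)   # same code mounted under an outer prefix
```
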
@@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
     This module learns positional embeddings up to a fixed maximum size.
     """

-    def __init__(self, weights):
+    def __init__(self, prefix: str, weights):
         super().__init__()
         self.offset = 2
         self.weight = nn.Parameter(
-            weights.get_tensor("model.decoder.embed_positions.weight")
+            weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
         )

     def forward(
@@ -311,11 +311,11 @@ class OPTAttention(nn.Module):

 class OPTDecoderLayer(nn.Module):
-    def __init__(self, layer_id: int, config: OPTConfig, weights):
+    def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
         super().__init__()
         self.process_group = weights.process_group
         self.hidden_size = config.hidden_size
-        prefix = f"model.decoder.layers.{layer_id}"
+        prefix = f"{prefix}.decoder.layers.{layer_id}"
         self.self_attn = OPTAttention(
             config,
             prefix=f"{prefix}.self_attn",
@@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):

 class OPTDecoder(OPTPreTrainedModel):
-    def __init__(self, config: OPTConfig, weights):
+    def __init__(self, prefix: str, config: OPTConfig, weights):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.layerdrop
@@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
         self.vocab_size = config.vocab_size

         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.decoder.embed_tokens", weights=weights
+            prefix=f"{prefix}.decoder.embed_tokens", weights=weights
         )
-        self.embed_positions = OPTLearnedPositionalEmbedding(weights)
+        self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)

         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = FastLinear.load(
-                config, prefix="model.decoder.project_out", weights=weights, bias=False
+                config,
+                prefix=f"{prefix}.decoder.project_out",
+                weights=weights,
+                bias=False,
             )
         else:
             self.project_out = None

         if config.word_embed_proj_dim != config.hidden_size:
             self.project_in = FastLinear.load(
-                config, prefix="model.decoder.project_in", weights=weights, bias=False
+                config,
+                prefix=f"{prefix}.decoder.project_in",
+                weights=weights,
+                bias=False,
             )
         else:
             self.project_in = None
@@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
         # see https://github.com/facebookresearch/metaseq/pull/164
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm.load(
-                prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
+                prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
             )
         else:
             self.final_layer_norm = None

         self.layers = nn.ModuleList(
             [
-                OPTDecoderLayer(layer_id, config, weights)
+                OPTDecoderLayer(layer_id, prefix, config, weights)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):

 class OPTModel(OPTPreTrainedModel):
-    def __init__(self, config: OPTConfig, weights):
+    def __init__(self, prefix: str, config: OPTConfig, weights):
         super().__init__(config)
-        self.decoder = OPTDecoder(config, weights)
+        self.decoder = OPTDecoder(prefix, config, weights)
         # Initialize weights and apply final processing

     def forward(
@@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):

 class OPTForCausalLM(OPTPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__(config)

-        self.model = OPTModel(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.model = OPTModel(prefix, config, weights)

         self.lm_head = SpeculativeHead.load(
-            config, prefix="model.decoder.embed_tokens", weights=weights
+            config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
         )

     def forward(
......
@@ -248,16 +248,16 @@ class PhiBlock(nn.Module):

 # PhiModel implements the embedding layer and the transformer blocks.
 class PhiModel(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.tp_rank = weights.process_group.rank()
         self.tp_world_size = weights.process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="transformer.embd.wte", weights=weights
+            prefix=f"{prefix}.embd.wte", weights=weights
         )
         self.blocks = nn.ModuleList(
             [
-                PhiBlock(f"transformer.h.{layer_id}", config, weights)
+                PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
                 for layer_id in range(config.n_layer)
             ]
         )
@@ -289,9 +289,15 @@ class PhiModel(nn.Module):

 # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
 class PhiForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
-        self.model = PhiModel(config, weights)
+
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+
+        self.model = PhiModel(prefix, config, weights)
         self.lm_head = PhiCausalLMHead(config, weights)

     def forward(
......
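
Both model files keep their historical weight names when no prefix is given, because the empty-prefix branch restores the old hard-coded root ("model" for OPT, "transformer" for Phi). A hedged check of the resulting embedding keys; the helper functions and the outer prefix "lm" are illustrative, while the key strings come from the diffs above:

```python
def opt_embed_tokens_key(prefix: str) -> str:
    # Mirrors OPTForCausalLM.__init__ above: "" -> "model", else "<prefix>.model".
    root = "model" if not prefix else f"{prefix}.model"
    return f"{root}.decoder.embed_tokens"


def phi_wte_key(prefix: str) -> str:
    # Mirrors PhiForCausalLM.__init__ above: "" -> "transformer", else "<prefix>.transformer".
    root = "transformer" if not prefix else f"{prefix}.transformer"
    return f"{root}.embd.wte"


# Standalone checkpoints resolve to the same names as before the refactor ...
assert opt_embed_tokens_key("") == "model.decoder.embed_tokens"
assert phi_wte_key("") == "transformer.embd.wte"

# ... while a wrapper can mount either model under its own namespace.
assert opt_embed_tokens_key("lm") == "lm.model.decoder.embed_tokens"
assert phi_wte_key("lm") == "lm.transformer.embd.wte"
```
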
@@ -878,10 +878,6 @@ class FlashCausalLM(Model):
         )
         config.quantize = quantize
         config.speculator = speculator
-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None

         torch.distributed.barrier(group=self.process_group)
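
The sliding-window block removed here reappears in the next hunk, after the `text_config` unwrap, so the value is read from the inner config when a model nests its text settings under `config.text_config`. A minimal sketch of why the order matters; the `SimpleNamespace` configs are made up for illustration:

```python
from types import SimpleNamespace

# A hypothetical multimodal-style config: sliding_window lives on text_config.
outer = SimpleNamespace(text_config=SimpleNamespace(sliding_window=4096))

# Old order: read before unwrapping -> the attribute is missing on the outer config.
assert getattr(outer, "sliding_window", None) is None

# New order: unwrap text_config first, then read the attribute.
config = outer.text_config if getattr(outer, "text_config", None) is not None else outer
assert getattr(config, "sliding_window", None) == 4096
```
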
@@ -900,13 +896,22 @@ class FlashCausalLM(Model):
         text_config = getattr(config, "text_config", None)
         if text_config is not None:
             config = text_config

+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
         if num_kv_heads is None:
-            num_kv_heads = getattr(config, "num_key_value_heads", None)
-            if num_kv_heads is None:
-                # Final overide for GPT2
-                num_kv_heads = config.n_head
+            # Order is important here.
+            for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
+                num_kv_heads = getattr(config, attr, None)
+                if num_kv_heads is not None:
+                    break
+            if num_kv_heads is None:
+                raise ValueError("Cannot get the number of key/value heads")
         self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
......
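
The last hunk replaces the GPT-2-only fallback with an ordered attribute lookup. A standalone sketch of that lookup with dummy configs; `SimpleNamespace` stands in for the HF config object, and the attribute list follows the reconstruction above:

```python
from types import SimpleNamespace


def resolve_num_kv_heads(config) -> int:
    # Try the most specific attribute first: GQA models expose num_key_value_heads,
    # plain MHA models only num_attention_heads, and GPT-2-style configs only n_head.
    for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
        num_kv_heads = getattr(config, attr, None)
        if num_kv_heads is not None:
            return num_kv_heads
    raise ValueError("Cannot get the number of key/value heads")


assert resolve_num_kv_heads(SimpleNamespace(num_key_value_heads=8, num_attention_heads=32)) == 8
assert resolve_num_kv_heads(SimpleNamespace(num_attention_heads=32)) == 32
assert resolve_num_kv_heads(SimpleNamespace(n_head=12)) == 12
```
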