"vscode:/vscode.git/clone" did not exist on "43b2f0981cd4fe597369909a007c315990c072e8"
Unverified commit 05c094fc authored by Daniël de Kok, committed by GitHub

Consistently take `prefix` in model constructors (#2191)

* Consistently take `prefix` in model constructors

* Release test check fix

* Misc refactor-related fixes
parent 67ef0649
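The pattern applied throughout the diff: every model constructor takes a `prefix` argument and derives its weight names from it, while the top-level `*ForCausalLM` class maps an empty prefix to the model's historical root name. A minimal sketch of the convention, with hypothetical `ToyModel`/`ToyForCausalLM` classes standing in for the real ones (illustrative only, not code from this PR):

    import torch.nn as nn


    class ToyModel(nn.Module):
        # Hypothetical inner model: all weight names hang off the prefix it receives.
        def __init__(self, prefix: str, config, weights):
            super().__init__()
            self.embed_tokens_name = f"{prefix}.decoder.embed_tokens.weight"


    class ToyForCausalLM(nn.Module):
        # Hypothetical wrapper: an empty prefix falls back to the default root,
        # a non-empty prefix (e.g. when nested in a larger model) is extended.
        def __init__(self, prefix: str, config, weights):
            super().__init__()
            if not prefix:
                prefix = "model"
            else:
                prefix = f"{prefix}.model"
            self.model = ToyModel(prefix, config, weights)


    # Usage: ToyForCausalLM("", config, weights).model.embed_tokens_name
    # -> "model.decoder.embed_tokens.weight"

Loaded standalone, the weight names therefore stay unchanged (`model.decoder...` for OPT, `transformer...` for Phi); a wrapping model can pass its own prefix to namespace the same weights.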
@@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
     This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, weights):
+    def __init__(self, prefix: str, weights):
         super().__init__()
         self.offset = 2
         self.weight = nn.Parameter(
-            weights.get_tensor("model.decoder.embed_positions.weight")
+            weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
         )
 
     def forward(
@@ -311,11 +311,11 @@ class OPTAttention(nn.Module):
 
 class OPTDecoderLayer(nn.Module):
-    def __init__(self, layer_id: int, config: OPTConfig, weights):
+    def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
         super().__init__()
         self.process_group = weights.process_group
         self.hidden_size = config.hidden_size
-        prefix = f"model.decoder.layers.{layer_id}"
+        prefix = f"{prefix}.decoder.layers.{layer_id}"
         self.self_attn = OPTAttention(
             config,
             prefix=f"{prefix}.self_attn",
@@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):
 
 class OPTDecoder(OPTPreTrainedModel):
-    def __init__(self, config: OPTConfig, weights):
+    def __init__(self, prefix: str, config: OPTConfig, weights):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.layerdrop
@@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.decoder.embed_tokens", weights=weights
+            prefix=f"{prefix}.decoder.embed_tokens", weights=weights
         )
-        self.embed_positions = OPTLearnedPositionalEmbedding(weights)
+        self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
 
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = FastLinear.load(
-                config, prefix="model.decoder.project_out", weights=weights, bias=False
+                config,
+                prefix=f"{prefix}.decoder.project_out",
+                weights=weights,
+                bias=False,
             )
         else:
             self.project_out = None
 
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_in = FastLinear.load(
-                config, prefix="model.decoder.project_in", weights=weights, bias=False
+                config,
+                prefix=f"{prefix}.decoder.project_in",
+                weights=weights,
+                bias=False,
             )
         else:
             self.project_in = None
@@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
         # see https://github.com/facebookresearch/metaseq/pull/164
         if config.do_layer_norm_before and not config._remove_final_layer_norm:
             self.final_layer_norm = nn.LayerNorm.load(
-                prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
+                prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
             )
         else:
             self.final_layer_norm = None
 
         self.layers = nn.ModuleList(
             [
-                OPTDecoderLayer(layer_id, config, weights)
+                OPTDecoderLayer(layer_id, prefix, config, weights)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):
 
 class OPTModel(OPTPreTrainedModel):
-    def __init__(self, config: OPTConfig, weights):
+    def __init__(self, prefix: str, config: OPTConfig, weights):
         super().__init__(config)
-        self.decoder = OPTDecoder(config, weights)
+        self.decoder = OPTDecoder(prefix, config, weights)
         # Initialize weights and apply final processing
 
     def forward(
@@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):
 
 class OPTForCausalLM(OPTPreTrainedModel):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__(config)
 
-        self.model = OPTModel(config, weights)
+        if not prefix:
+            prefix = "model"
+        else:
+            prefix = f"{prefix}.model"
+
+        self.model = OPTModel(prefix, config, weights)
 
         self.lm_head = SpeculativeHead.load(
-            config, prefix="model.decoder.embed_tokens", weights=weights
+            config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
         )
 
     def forward(
...
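For the OPT hunks above, what actually changes at load time is only how the weight-name root is resolved. A small standalone sketch of that resolution (the helper name and the example outer prefix are made up for illustration):

    def resolve_opt_root(prefix: str) -> str:
        # Mirrors the branch added to OPTForCausalLM.__init__ in the hunk above.
        return "model" if not prefix else f"{prefix}.model"


    for outer in ["", "text_model"]:  # example prefixes only
        root = resolve_opt_root(outer)
        print(f"{root}.decoder.embed_tokens.weight")
        print(f"{root}.decoder.final_layer_norm.weight")

    # ""           keeps the legacy names: model.decoder.embed_tokens.weight
    # "text_model" nests them:             text_model.model.decoder.embed_tokens.weight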
@@ -248,16 +248,16 @@ class PhiBlock(nn.Module):
 
 # PhiModel implements the embedding layer and the transformer blocks.
 class PhiModel(nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
         self.tp_rank = weights.process_group.rank()
         self.tp_world_size = weights.process_group.size()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="transformer.embd.wte", weights=weights
+            prefix=f"{prefix}.embd.wte", weights=weights
         )
         self.blocks = nn.ModuleList(
             [
-                PhiBlock(f"transformer.h.{layer_id}", config, weights)
+                PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
                 for layer_id in range(config.n_layer)
             ]
         )
@@ -289,9 +289,15 @@ class PhiModel(nn.Module):
 
 # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
 class PhiForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix: str, config, weights):
         super().__init__()
-        self.model = PhiModel(config, weights)
+
+        if not prefix:
+            prefix = "transformer"
+        else:
+            prefix = f"{prefix}.transformer"
+
+        self.model = PhiModel(prefix, config, weights)
         self.lm_head = PhiCausalLMHead(config, weights)
 
     def forward(
...
@@ -878,10 +878,6 @@ class FlashCausalLM(Model):
         )
         config.quantize = quantize
         config.speculator = speculator
-        if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(config.sliding_window)
-        else:
-            config.sliding_window = None
 
         torch.distributed.barrier(group=self.process_group)
@@ -900,13 +896,22 @@ class FlashCausalLM(Model):
         text_config = getattr(config, "text_config", None)
         if text_config is not None:
             config = text_config
 
+        if getattr(config, "sliding_window", None) is not None:
+            set_sliding_window(config.sliding_window)
+        else:
+            config.sliding_window = None
+
         self.num_layers = config.num_hidden_layers
         # Validation is done in the model itself
         if num_kv_heads is None:
-            num_kv_heads = getattr(config, "num_key_value_heads", None)
-        if num_kv_heads is None:
-            # Final overide for GPT2
-            num_kv_heads = config.n_head
+            # Order is important here.
+            for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
+                num_kv_heads = getattr(config, attr, None)
+                if num_kv_heads is not None:
+                    break
+        if num_kv_heads is None:
+            raise ValueError("Cannot get the number of key/value heads")
         self.num_kv_heads = num_kv_heads // self.process_group.size()
         self.head_size = config.hidden_size // config.num_attention_heads
...
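The key/value-head lookup in the last hunk is an ordered attribute fallback over the model config. A self-contained sketch of that logic (`lookup_num_kv_heads` and the `SimpleNamespace` configs are illustrative; the attribute order follows the hunk as reconstructed above):

    from types import SimpleNamespace


    def lookup_num_kv_heads(config, num_kv_heads=None):
        if num_kv_heads is None:
            # Order matters: prefer grouped-query heads, then plain attention heads,
            # then the GPT-2 style `n_head` attribute.
            for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
                num_kv_heads = getattr(config, attr, None)
                if num_kv_heads is not None:
                    break
        if num_kv_heads is None:
            raise ValueError("Cannot get the number of key/value heads")
        return num_kv_heads


    print(lookup_num_kv_heads(SimpleNamespace(num_key_value_heads=8, num_attention_heads=32)))  # 8
    print(lookup_num_kv_heads(SimpleNamespace(n_head=12)))  # 12 (GPT-2 style config)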