Unverified commit 5b45422b authored by Thomas Wang, committed by GitHub

Remove n_ctx from configs (#14165)

* Remove n_ctx from configs

* Fix GPTJ and OpenAIGPT; both are acceptable breaking changes, as there are no existing configs that this breaks

* Remove unnecessary n_positions from TFOpenAIGPT
parent be236361
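To make the effect of the change concrete, here is a minimal, illustrative sketch (the keyword arguments shown are a subset picked for the example): after this commit the maximum sequence length is carried by n_positions alone, and the causal masks are sized from it.

from transformers import GPT2Config

# Sketch only: build a config without n_ctx; n_positions is now the single
# source of truth for the context length and the causal-mask size.
config = GPT2Config(
    vocab_size=50257,
    n_positions=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
)
print(config.n_positions)  # 1024
# config.n_ctx is no longer set; downstream code should read config.n_positions.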
 {
   "initializer_range": 0.02,
   "layer_norm_epsilon": 0.00001,
-  "n_ctx": 1024,
   "n_embd": 768,
   "n_head": 12,
   "n_layer": 6,
...
@@ -653,7 +653,7 @@ class CLIPTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
         # take features from the eot embedding (eot_token is the highest number in each sequence)
         pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
...
@@ -521,7 +521,7 @@ class FlaxCLIPTextTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.final_layer_norm(last_hidden_state)
-        # text_embeds.shape = [batch_size, n_ctx, transformer.width]
+        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
         # take features from the EOS embedding (eos_token_id is the highest number in each sequence)
         pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
...
@@ -41,8 +41,6 @@ class CTRLConfig(PretrainedConfig):
         n_positions (:obj:`int`, `optional`, defaults to 256):
             The maximum sequence length that this model might ever be used with. Typically set this to something large
             just in case (e.g., 512 or 1024 or 2048).
-        n_ctx (:obj:`int`, `optional`, defaults to 256):
-            Dimensionality of the causal mask (usually same as n_positions).
         n_embd (:obj:`int`, `optional`, defaults to 1280):
             Dimensionality of the embeddings and hidden states.
         dff (:obj:`int`, `optional`, defaults to 8192):
@@ -92,7 +90,6 @@ class CTRLConfig(PretrainedConfig):
         self,
         vocab_size=246534,
         n_positions=256,
-        n_ctx=256,
         n_embd=1280,
         dff=8192,
         n_layer=48,
@@ -111,7 +108,6 @@ class CTRLConfig(PretrainedConfig):
         **kwargs
     ):
         self.vocab_size = vocab_size
-        self.n_ctx = n_ctx
         self.n_positions = n_positions
         self.n_embd = n_embd
         self.n_layer = n_layer
...
@@ -54,8 +54,6 @@ class GPT2Config(PretrainedConfig):
         n_positions (:obj:`int`, `optional`, defaults to 1024):
             The maximum sequence length that this model might ever be used with. Typically set this to something large
             just in case (e.g., 512 or 1024 or 2048).
-        n_ctx (:obj:`int`, `optional`, defaults to 1024):
-            Dimensionality of the causal mask (usually same as n_positions).
         n_embd (:obj:`int`, `optional`, defaults to 768):
             Dimensionality of the embeddings and hidden states.
         n_layer (:obj:`int`, `optional`, defaults to 12):
@@ -144,7 +142,6 @@ class GPT2Config(PretrainedConfig):
         self,
         vocab_size=50257,
         n_positions=1024,
-        n_ctx=1024,
         n_embd=768,
         n_layer=12,
         n_head=12,
@@ -169,7 +166,6 @@ class GPT2Config(PretrainedConfig):
         **kwargs,
     ):
         self.vocab_size = vocab_size
-        self.n_ctx = n_ctx
         self.n_positions = n_positions
         self.n_embd = n_embd
         self.n_layer = n_layer
...
@@ -66,13 +66,12 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
 class TFAttention(tf.keras.layers.Layer):
-    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
+    def __init__(self, nx, config, scale=False, **kwargs):
         super().__init__(**kwargs)
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
         assert n_state % config.n_head == 0
-        self.n_ctx = n_ctx
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
@@ -185,12 +184,12 @@ class TFMLP(tf.keras.layers.Layer):
 class TFBlock(tf.keras.layers.Layer):
-    def __init__(self, n_ctx, config, scale=False, **kwargs):
+    def __init__(self, config, scale=False, **kwargs):
         super().__init__(**kwargs)
         nx = config.n_embd
         inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
         self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
-        self.attn = TFAttention(nx, n_ctx, config, scale, name="attn")
+        self.attn = TFAttention(nx, config, scale, name="attn")
         self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
         self.mlp = TFMLP(inner_dim, config, name="mlp")
@@ -233,7 +232,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
             config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
         )
         self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
-        self.h = [TFBlock(config.n_ctx, config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
+        self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
     def build(self, input_shape):
...
@@ -33,7 +33,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
         num_layers=config_json["n_layer"],
         num_heads=config_json["n_head"],
         attention_types=config_json["attention_types"],
-        max_position_embeddings=config_json["n_ctx"],
+        max_position_embeddings=config_json["n_positions"],
         resid_dropout=config_json["res_dropout"],
         embed_dropout=config_json["embed_dropout"],
         attention_dropout=config_json["attn_dropout"],
...
@@ -42,8 +42,6 @@ class GPTJConfig(PretrainedConfig):
         n_positions (:obj:`int`, `optional`, defaults to 2048):
             The maximum sequence length that this model might ever be used with. Typically set this to something large
             just in case (e.g., 512 or 1024 or 2048).
-        n_ctx (:obj:`int`, `optional`, defaults to 2048):
-            Dimensionality of the causal mask (usually same as n_positions).
         n_embd (:obj:`int`, `optional`, defaults to 4096):
             Dimensionality of the embeddings and hidden states.
         n_layer (:obj:`int`, `optional`, defaults to 28):
@@ -96,7 +94,6 @@ class GPTJConfig(PretrainedConfig):
         self,
         vocab_size=50400,
         n_positions=2048,
-        n_ctx=2048,
         n_embd=4096,
         n_layer=28,
         n_head=16,
@@ -115,7 +112,6 @@ class GPTJConfig(PretrainedConfig):
         **kwargs
     ):
         self.vocab_size = vocab_size
-        self.n_ctx = n_ctx
         self.n_positions = n_positions
         self.n_embd = n_embd
         self.n_layer = n_layer
...
@@ -99,7 +99,7 @@ class GPTJAttention(nn.Module):
     def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
         """
-        Splits n_ctx dim into attn_head_size and num_attention_heads
+        Splits hidden dim into attn_head_size and num_attention_heads
         """
         new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
         tensor = tensor.view(*new_shape)
@@ -114,7 +114,7 @@ class GPTJAttention(nn.Module):
     def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
         """
-        Merges attn_head_size dim and num_attn_heads dim into n_ctx
+        Merges attn_head_size dim and num_attn_heads dim into hidden dim
         """
         if len(tensor.shape) == 5:
             tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
@@ -377,7 +377,7 @@ GPTJ_INPUTS_DOCSTRING = r"""
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, n_ctx)`, `optional`):
+        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_dim)`, `optional`):
             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
             than the model's internal embedding lookup matrix.
@@ -444,7 +444,6 @@ class GPTJModel(GPTJPreTrainedModel):
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
-        self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
         self.init_weights()
         # Model parallel
@@ -854,7 +853,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.transformer = GPTJModel(config)
-        self.score = nn.Linear(config.n_ctx, self.num_labels, bias=False)
+        self.score = nn.Linear(config.n_positions, self.num_labels, bias=False)
         self.init_weights()
...
@@ -88,7 +88,6 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
         config.vocab_size = ds_args.padded_vocab_size
         config.n_positions = ds_args.max_position_embeddings
-        config.n_ctx = ds_args.seq_length
         config.n_embd = ds_args.hidden_size
         config.n_layer = ds_args.num_layers
         config.n_head = ds_args.num_attention_heads
@@ -121,10 +120,10 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
     # Read the causal mask dimension (seqlen). [max_sequence_length, hidden_size]
-    n_ctx = pos_embeddings.size(0)
+    n_positions = pos_embeddings.size(0)
     assert (
-        n_ctx == config.n_ctx
-    ), f"pos_embeddings.max_sequence_length={n_ctx} and config.n_ctx={config.n_ctx} don't match"
+        n_positions == config.n_positions
+    ), f"pos_embeddings.max_sequence_length={n_positions} and config.n_positions={config.n_positions} don't match"
     # Store the position embeddings.
     output_state_dict["transformer.wpe.weight"] = pos_embeddings
@@ -173,7 +172,9 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
         ) and weight_or_bias == "weight":
             # Insert a tensor of 1x1xDxD bias.
-            causal_mask = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.float16)).view(1, 1, n_ctx, n_ctx)
+            causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.float16)).view(
+                1, 1, n_positions, n_positions
+            )
             output_state_dict[layer_name + ".attn.bias"] = causal_mask
             # Insert a "dummy" tensor for masked_bias.
@@ -274,7 +275,6 @@ def main():
     config = GPT2Config(
         vocab_size=50257,
         n_positions=1024,
-        n_ctx=1024,
         n_embd=1024,
         n_layer=24,
         n_head=16,
...
@@ -42,8 +42,6 @@ class OpenAIGPTConfig(PretrainedConfig):
         n_positions (:obj:`int`, `optional`, defaults to 512):
             The maximum sequence length that this model might ever be used with. Typically set this to something large
             just in case (e.g., 512 or 1024 or 2048).
-        n_ctx (:obj:`int`, `optional`, defaults to 512):
-            Dimensionality of the causal mask (usually same as n_positions).
         n_embd (:obj:`int`, `optional`, defaults to 768):
             Dimensionality of the embeddings and hidden states.
         n_layer (:obj:`int`, `optional`, defaults to 12):
@@ -126,7 +124,6 @@ class OpenAIGPTConfig(PretrainedConfig):
         self,
         vocab_size=40478,
         n_positions=512,
-        n_ctx=512,
         n_embd=768,
         n_layer=12,
         n_head=12,
@@ -145,7 +142,6 @@ class OpenAIGPTConfig(PretrainedConfig):
         **kwargs
     ):
         self.vocab_size = vocab_size
-        self.n_ctx = n_ctx
         self.n_positions = n_positions
         self.n_embd = n_embd
         self.n_layer = n_layer
...
@@ -143,12 +143,14 @@ ACT_FNS = {"relu": nn.ReLU, "silu": silu, "gelu": gelu_new, "swish": silu}
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_positions, config, scale=False):
         super().__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
         assert n_state % config.n_head == 0
-        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.register_buffer(
+            "bias", torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
+        )
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
@@ -246,10 +248,10 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+    def __init__(self, n_positions, config, scale=False):
         super().__init__()
         nx = config.n_embd
-        self.attn = Attention(nx, n_ctx, config, scale)
+        self.attn = Attention(nx, n_positions, config, scale)
         self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
         self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
@@ -413,7 +415,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
         self.register_buffer("position_ids", torch.arange(config.n_positions))
         self.init_weights()
...
@@ -58,7 +58,7 @@ TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 class TFAttention(tf.keras.layers.Layer):
-    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
+    def __init__(self, nx, config, scale=False, **kwargs):
         super().__init__(**kwargs)
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
@@ -66,7 +66,6 @@ class TFAttention(tf.keras.layers.Layer):
         assert (
             n_state % config.n_head == 0
         ), f"Hidden dimension {n_state} not dividable by number of heads {config.n_head}"
-        self.n_ctx = n_ctx
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
@@ -169,10 +168,10 @@ class TFMLP(tf.keras.layers.Layer):
 class TFBlock(tf.keras.layers.Layer):
-    def __init__(self, n_ctx, config, scale=False, **kwargs):
+    def __init__(self, config, scale=False, **kwargs):
         super().__init__(**kwargs)
         nx = config.n_embd
-        self.attn = TFAttention(nx, n_ctx, config, scale, name="attn")
+        self.attn = TFAttention(nx, config, scale, name="attn")
         self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_1")
         self.mlp = TFMLP(4 * nx, config, name="mlp")
         self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
@@ -210,7 +209,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
             config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
         )
         self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
-        self.h = [TFBlock(config.n_ctx, config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
+        self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)]
     def build(self, input_shape):
         with tf.name_scope("positions_embed"):
...
@@ -114,7 +114,6 @@ class CTRLModelTester:
             # hidden_dropout_prob=self.hidden_dropout_prob,
             # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             n_positions=self.max_position_embeddings,
-            n_ctx=self.max_position_embeddings,
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range,
             pad_token_id=self.pad_token_id,
...
@@ -95,7 +95,6 @@ class FlaxGPT2ModelTester:
             n_layer=self.num_hidden_layers,
             n_head=self.num_attention_heads,
             n_positions=self.max_position_embeddings,
-            n_ctx=self.max_position_embeddings,
             use_cache=False,
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
...
@@ -155,7 +155,6 @@ class GPT2ModelTester:
             resid_pdrop=self.hidden_dropout_prob,
             attn_pdrop=self.attention_probs_dropout_prob,
             n_positions=self.max_position_embeddings,
-            n_ctx=self.max_position_embeddings,
             type_vocab_size=self.type_vocab_size,
             initializer_range=self.initializer_range,
             use_cache=True,
...
@@ -142,7 +142,6 @@ class GPTJModelTester:
             hidden_dropout_prob=self.hidden_dropout_prob,
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             n_positions=self.max_position_embeddings,
-            n_ctx=self.max_position_embeddings,
             type_vocab_size=self.type_vocab_size,
             initializer_range=self.initializer_range,
             use_cache=True,
...
@@ -90,7 +90,6 @@ class OpenAIGPTModelTester:
             # hidden_dropout_prob=self.hidden_dropout_prob,
             # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             n_positions=self.max_position_embeddings,
-            n_ctx=self.max_position_embeddings,
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range
             pad_token_id=self.pad_token_id,
...
@@ -97,7 +97,6 @@ class TFCTRLModelTester(object):
            # hidden_dropout_prob=self.hidden_dropout_prob,
            # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
-           n_ctx=self.max_position_embeddings,
            # type_vocab_size=self.type_vocab_size,
            # initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
...
@@ -100,7 +100,6 @@ class TFGPT2ModelTester:
            # hidden_dropout_prob=self.hidden_dropout_prob,
            # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
-           n_ctx=self.max_position_embeddings,
            # type_vocab_size=self.type_vocab_size,
            # initializer_range=self.initializer_range
            bos_token_id=self.bos_token_id,
...
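For downstream code that still reads n_ctx, for example when working with an older serialized config that carried the key as an extra attribute, a small accessor is one way to stay compatible. This helper is a hypothetical shim for illustration, not part of the library:

def max_context_length(config):
    # Prefer the canonical attribute; fall back to a legacy n_ctx value if an
    # older config object still carries it as an extra attribute.
    if getattr(config, "n_positions", None) is not None:
        return config.n_positions
    return getattr(config, "n_ctx", None)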