Commit 4b956b2a authored by thomwolf

add layer_norm_epsilon configuration for transformer xl

parent b97af8cc
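For context, a quick illustrative sketch (not part of the commit) of what the new option does once this change is in: the configured epsilon is threaded into every LayerNorm of the model. The import path below is an assumption and may need to be adjusted to the installed package name.

# Illustrative only: exercising the layer_norm_epsilon option added in this commit.
# Assumption: the classes are importable as below; the package name may differ
# depending on the installed version.
from transformers import TransfoXLConfig, TransfoXLModel

config = TransfoXLConfig(layer_norm_epsilon=1e-6)   # the default introduced here is 1e-5
model = TransfoXLModel(config)                      # builds full-size, randomly initialized weights

# The value reaches every LayerNorm, e.g. nn.LayerNorm(d_model, eps=config.layer_norm_epsilon)
# inside PositionwiseFF and RelPartialLearnableMultiHeadAttn.
print(config.layer_norm_epsilon)                    # 1e-06
print(model.layers[0].pos_ff.layer_norm.eps)        # 1e-06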
@@ -95,10 +95,43 @@ class TransfoXLConfig(PretrainedConfig):
                  init_range=0.01,
                  proj_init_std=0.01,
                  init_std=0.02,
+                 layer_norm_epsilon=1e-5,
                  **kwargs):
         """Constructs TransfoXLConfig.
         """
         super(TransfoXLConfig, self).__init__(**kwargs)
+        self.n_token = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
+        self.cutoffs = []
+        self.cutoffs.extend(cutoffs)
+        self.tie_weight = tie_weight
+        if proj_share_all_but_first:
+            self.tie_projs = [False] + [True] * len(self.cutoffs)
+        else:
+            self.tie_projs = [False] + [False] * len(self.cutoffs)
+        self.d_model = d_model
+        self.d_embed = d_embed
+        self.d_head = d_head
+        self.d_inner = d_inner
+        self.div_val = div_val
+        self.pre_lnorm = pre_lnorm
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.tgt_len = tgt_len
+        self.ext_len = ext_len
+        self.mem_len = mem_len
+        self.same_length = same_length
+        self.attn_type = attn_type
+        self.clamp_len = clamp_len
+        self.sample_softmax = sample_softmax
+        self.adaptive = adaptive
+        self.dropout = dropout
+        self.dropatt = dropatt
+        self.untie_r = untie_r
+        self.init = init
+        self.init_range = init_range
+        self.proj_init_std = proj_init_std
+        self.init_std = init_std
+        self.layer_norm_epsilon = layer_norm_epsilon
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -106,39 +139,7 @@ class TransfoXLConfig(PretrainedConfig):
                 json_config = json.loads(reader.read())
             for key, value in json_config.items():
                 self.__dict__[key] = value
-        elif isinstance(vocab_size_or_config_json_file, int):
-            self.n_token = vocab_size_or_config_json_file
-            self.cutoffs = []
-            self.cutoffs.extend(cutoffs)
-            self.tie_weight = tie_weight
-            if proj_share_all_but_first:
-                self.tie_projs = [False] + [True] * len(self.cutoffs)
-            else:
-                self.tie_projs = [False] + [False] * len(self.cutoffs)
-            self.d_model = d_model
-            self.d_embed = d_embed
-            self.d_head = d_head
-            self.d_inner = d_inner
-            self.div_val = div_val
-            self.pre_lnorm = pre_lnorm
-            self.n_layer = n_layer
-            self.n_head = n_head
-            self.tgt_len = tgt_len
-            self.ext_len = ext_len
-            self.mem_len = mem_len
-            self.same_length = same_length
-            self.attn_type = attn_type
-            self.clamp_len = clamp_len
-            self.sample_softmax = sample_softmax
-            self.adaptive = adaptive
-            self.dropout = dropout
-            self.dropatt = dropatt
-            self.untie_r = untie_r
-            self.init = init
-            self.init_range = init_range
-            self.proj_init_std = proj_init_std
-            self.init_std = init_std
-        else:
+        elif not isinstance(vocab_size_or_config_json_file, int):
             raise ValueError("First argument must be either a vocabulary size (int)"
                              " or the path to a pretrained model config file (str)")
...
@@ -84,8 +84,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
         tfo = tf_model(tf_inputs, training=False)  # build the network
         pt_model = pt_model_class.from_pretrained(None,
                                                   config=config,
                                                   state_dict=torch.load(pytorch_checkpoint_path,
                                                                         map_location='cpu'))
         pt_inputs = torch.tensor(inputs_list)
         with torch.no_grad():
@@ -124,7 +124,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
         print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
         print("-" * 100)
         if 'finetuned' in shortcut_name:
-            print(" Skipping fintenued checkpoint ")
+            print(" Skipping finetuned checkpoint ")
             continue
         config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
         model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
...
@@ -91,8 +91,10 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
         name = name.split('/')  # Convert from TF2.0 '/' separators to PyTorch '.' separators
         name = name[1:]  # Remove level zero
+        # When should we transpose the weights
+        transpose = bool(name[-1] == 'kernel' or 'emb_projs' in name or 'out_projs' in name)
         # Convert standard TF2.0 names in PyTorch names
-        transpose = bool(name[-1] == 'kernel')
         if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
             name[-1] = 'weight'
         if name[-1] == 'beta':
...
@@ -66,7 +66,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
 class TFPositionwiseFF(tf.keras.layers.Layer):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, **kwargs):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5, **kwargs):
         super(TFPositionwiseFF, self).__init__(**kwargs)
         self.d_model = d_model
@@ -75,10 +75,10 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
         self.layer_1 = tf.keras.layers.Dense(d_inner, activation=tf.nn.relu, name='CoreNet_._0')
         self.drop_1 = tf.keras.layers.Dropout(dropout)
-        self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._2')
+        self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._3')
         self.drop_2 = tf.keras.layers.Dropout(dropout)
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
+        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')
         self.pre_lnorm = pre_lnorm
@@ -109,7 +109,8 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
 class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 r_r_bias=None, r_w_bias=None, output_attentions=False, **kwargs):
+                 r_r_bias=None, r_w_bias=None, output_attentions=False,
+                 layer_norm_epsilon=1e-5, **kwargs):
         super(TFRelPartialLearnableMultiHeadAttn, self).__init__(**kwargs)
         self.output_attentions = output_attentions
@@ -124,7 +125,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         self.dropatt = tf.keras.layers.Dropout(dropatt)
         self.o_net = tf.keras.layers.Dense(d_model, use_bias=False, name='o_net')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
+        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')
         self.scale = 1 / (d_head ** 0.5)
@@ -247,6 +248,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
                  r_w_bias=None,
                  r_r_bias=None,
                  output_attentions=False,
+                 layer_norm_epsilon=1e-5,
                  **kwargs):
         super(TFRelPartialLearnableDecoderLayer, self).__init__(**kwargs)
@@ -254,9 +256,12 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
                                                             d_head, dropout, tgt_len=tgt_len, ext_len=ext_len,
                                                             mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm,
                                                             r_w_bias=r_w_bias, r_r_bias=r_r_bias,
-                                                            output_attentions=output_attentions, name='dec_attn')
+                                                            output_attentions=output_attentions,
+                                                            layer_norm_epsilon=layer_norm_epsilon, name='dec_attn')
         self.pos_ff = TFPositionwiseFF(d_model, d_inner, dropout,
-                                       pre_lnorm=pre_lnorm, name='pos_ff')
+                                       pre_lnorm=pre_lnorm,
+                                       layer_norm_epsilon=layer_norm_epsilon,
+                                       name='pos_ff')
     def call(self, inputs, training=False):
         dec_inp, r, dec_attn_mask, mems, head_mask = inputs
@@ -300,7 +305,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 d_emb_i = self.d_embed // (self.div_val ** i)
                 self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj),
                                                       trainable=True,
-                                                      name='emb_projs._{}'.format(i)))
+                                                      name='emb_projs_._{}'.format(i)))
         super(TFAdaptiveEmbedding, self).build(input_shape)
     def call(self, inp):
@@ -368,6 +373,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                     r_w_bias=None if self.untie_r else self.r_w_bias,
                     r_r_bias=None if self.untie_r else self.r_r_bias,
                     output_attentions=self.output_attentions,
+                    layer_norm_epsilon=config.layer_norm_epsilon,
                     name='layers_._{}'.format(i))
                 )
         else:  # learnable embeddings and absolute embeddings
...
@@ -194,7 +194,7 @@ class PositionalEmbedding(nn.Module):
 class PositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
         super(PositionwiseFF, self).__init__()
         self.d_model = d_model
@@ -208,7 +208,7 @@ class PositionwiseFF(nn.Module):
             nn.Dropout(dropout),
         )
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
         self.pre_lnorm = pre_lnorm
@@ -232,7 +232,8 @@ class PositionwiseFF(nn.Module):
 class RelPartialLearnableMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 r_r_bias=None, r_w_bias=None, output_attentions=False):
+                 r_r_bias=None, r_w_bias=None, output_attentions=False,
+                 layer_norm_epsilon=1e-5):
         super(RelPartialLearnableMultiHeadAttn, self).__init__()
         self.output_attentions = output_attentions
@@ -247,7 +248,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         self.dropatt = nn.Dropout(dropatt)
         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
         self.scale = 1 / (d_head ** 0.5)
@@ -359,14 +360,15 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
 class RelPartialLearnableDecoderLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
+    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5,
                  **kwargs):
         super(RelPartialLearnableDecoderLayer, self).__init__()
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
-                                                         d_head, dropout, **kwargs)
+                                                         d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs)
         self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
-                                     pre_lnorm=kwargs.get('pre_lnorm'))
+                                     pre_lnorm=kwargs.get('pre_lnorm'),
+                                     layer_norm_epsilon=layer_norm_epsilon)
     def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
@@ -613,7 +615,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                     dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                     r_w_bias=None if config.untie_r else self.r_w_bias,
                     r_r_bias=None if config.untie_r else self.r_r_bias,
-                    output_attentions=self.output_attentions)
+                    output_attentions=self.output_attentions,
+                    layer_norm_epsilon=config.layer_norm_epsilon)
                 )
         else:  # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
             raise NotImplementedError  # Removed them to avoid maintaining dead code
...