Commit 4b956b2a authored by thomwolf

Add layer_norm_epsilon configuration for Transformer-XL

parent b97af8cc
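This commit threads a single layer_norm_epsilon value (default 1e-5) from TransfoXLConfig into every LayerNorm built by the PyTorch and TF2.0 Transformer-XL layers. A minimal usage sketch, not part of the diff, assuming the top-level imports the package exposes for these classes:

import torch
from transformers import TransfoXLConfig, TransfoXLModel  # import path assumed

# Config defaults (a large model) plus a non-default epsilon; default stays 1e-5.
config = TransfoXLConfig(layer_norm_epsilon=1e-6)
model = TransfoXLModel(config)

# Every LayerNorm created by PositionwiseFF / RelPartialLearnableMultiHeadAttn
# should now carry the configured eps instead of a hard-coded value.
assert all(m.eps == config.layer_norm_epsilon
           for m in model.modules() if isinstance(m, torch.nn.LayerNorm))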
@@ -95,10 +95,43 @@ class TransfoXLConfig(PretrainedConfig):
init_range=0.01,
proj_init_std=0.01,
init_std=0.02,
layer_norm_epsilon=1e-5,
**kwargs):
"""Constructs TransfoXLConfig.
"""
super(TransfoXLConfig, self).__init__(**kwargs)
self.n_token = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.cutoffs = []
self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight
if proj_share_all_but_first:
self.tie_projs = [False] + [True] * len(self.cutoffs)
else:
self.tie_projs = [False] + [False] * len(self.cutoffs)
self.d_model = d_model
self.d_embed = d_embed
self.d_head = d_head
self.d_inner = d_inner
self.div_val = div_val
self.pre_lnorm = pre_lnorm
self.n_layer = n_layer
self.n_head = n_head
self.tgt_len = tgt_len
self.ext_len = ext_len
self.mem_len = mem_len
self.same_length = same_length
self.attn_type = attn_type
self.clamp_len = clamp_len
self.sample_softmax = sample_softmax
self.adaptive = adaptive
self.dropout = dropout
self.dropatt = dropatt
self.untie_r = untie_r
self.init = init
self.init_range = init_range
self.proj_init_std = proj_init_std
self.init_std = init_std
self.layer_norm_epsilon = layer_norm_epsilon
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -106,39 +139,7 @@ class TransfoXLConfig(PretrainedConfig):
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file
self.cutoffs = []
self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight
if proj_share_all_but_first:
self.tie_projs = [False] + [True] * len(self.cutoffs)
else:
self.tie_projs = [False] + [False] * len(self.cutoffs)
self.d_model = d_model
self.d_embed = d_embed
self.d_head = d_head
self.d_inner = d_inner
self.div_val = div_val
self.pre_lnorm = pre_lnorm
self.n_layer = n_layer
self.n_head = n_head
self.tgt_len = tgt_len
self.ext_len = ext_len
self.mem_len = mem_len
self.same_length = same_length
self.attn_type = attn_type
self.clamp_len = clamp_len
self.sample_softmax = sample_softmax
self.adaptive = adaptive
self.dropout = dropout
self.dropatt = dropatt
self.untie_r = untie_r
self.init = init
self.init_range = init_range
self.proj_init_std = proj_init_std
self.init_std = init_std
else:
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
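Because the attribute assignments now run before the JSON branch, a config file written before this change (one without a layer_norm_epsilon key) still produces a config carrying the 1e-5 default; the JSON values simply override whichever attributes they name. A short sketch of that behaviour, not part of the diff, with an illustrative file and keys:

import json, tempfile
from transformers import TransfoXLConfig  # import path assumed

# An "old" JSON config that predates the new key.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"n_token": 267735, "d_model": 1024, "n_layer": 18}, f)

config = TransfoXLConfig(f.name)          # str argument -> values read from JSON
assert config.layer_norm_epsilon == 1e-5  # default kept, no missing-key error
assert config.d_model == 1024             # JSON values still win where present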
@@ -84,8 +84,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tfo = tf_model(tf_inputs, training=False) # build the network
pt_model = pt_model_class.from_pretrained(None,
config=config,
state_dict=torch.load(pytorch_checkpoint_path,
config=config,
state_dict=torch.load(pytorch_checkpoint_path,
map_location='cpu'))
pt_inputs = torch.tensor(inputs_list)
with torch.no_grad():
@@ -124,7 +124,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
print("-" * 100)
if 'finetuned' in shortcut_name:
print(" Skipping fintenued checkpoint ")
print(" Skipping finetuned checkpoint ")
continue
config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
@@ -91,8 +91,10 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
name = name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators
name = name[1:] # Remove level zero
# When should we transpose the weights
transpose = bool(name[-1] == 'kernel' or 'emb_projs' in name or 'out_projs' in name)
# Convert standard TF2.0 names in PyTorch names
transpose = bool(name[-1] == 'kernel')
if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
name[-1] = 'weight'
if name[-1] == 'beta':
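The widened transpose rule matches how the weights are laid out on each side: Keras Dense kernels are (in_features, out_features) while torch.nn.Linear stores (out_features, in_features), and the Transformer-XL emb_projs / out_projs projections are likewise stored with opposite orientation in the two frameworks. A standalone shape illustration, not the loader code itself:

import tensorflow as tf
import torch

dense = tf.keras.layers.Dense(8)
dense.build((None, 4))                   # kernel created as (in, out)
linear = torch.nn.Linear(4, 8)           # weight created as (out, in)

print(tuple(dense.kernel.shape))         # (4, 8)
print(tuple(linear.weight.shape))        # (8, 4)  -> hence the transpose when loading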
@@ -66,7 +66,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
class TFPositionwiseFF(tf.keras.layers.Layer):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, **kwargs):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5, **kwargs):
super(TFPositionwiseFF, self).__init__(**kwargs)
self.d_model = d_model
@@ -75,10 +75,10 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
self.layer_1 = tf.keras.layers.Dense(d_inner, activation=tf.nn.relu, name='CoreNet_._0')
self.drop_1 = tf.keras.layers.Dropout(dropout)
self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._2')
self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._3')
self.drop_2 = tf.keras.layers.Dropout(dropout)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')
self.pre_lnorm = pre_lnorm
@@ -109,7 +109,8 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None, output_attentions=False, **kwargs):
r_r_bias=None, r_w_bias=None, output_attentions=False,
layer_norm_epsilon=1e-5, **kwargs):
super(TFRelPartialLearnableMultiHeadAttn, self).__init__(**kwargs)
self.output_attentions = output_attentions
@@ -124,7 +125,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
self.dropatt = tf.keras.layers.Dropout(dropatt)
self.o_net = tf.keras.layers.Dense(d_model, use_bias=False, name='o_net')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')
self.scale = 1 / (d_head ** 0.5)
@@ -247,6 +248,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
r_w_bias=None,
r_r_bias=None,
output_attentions=False,
layer_norm_epsilon=1e-5,
**kwargs):
super(TFRelPartialLearnableDecoderLayer, self).__init__(**kwargs)
@@ -254,9 +256,12 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
d_head, dropout, tgt_len=tgt_len, ext_len=ext_len,
mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm,
r_w_bias=r_w_bias, r_r_bias=r_r_bias,
output_attentions=output_attentions, name='dec_attn')
output_attentions=output_attentions,
layer_norm_epsilon=layer_norm_epsilon, name='dec_attn')
self.pos_ff = TFPositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=pre_lnorm, name='pos_ff')
pre_lnorm=pre_lnorm,
layer_norm_epsilon=layer_norm_epsilon,
name='pos_ff')
def call(self, inputs, training=False):
dec_inp, r, dec_attn_mask, mems, head_mask = inputs
@@ -300,7 +305,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
d_emb_i = self.d_embed // (self.div_val ** i)
self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj),
trainable=True,
name='emb_projs._{}'.format(i)))
name='emb_projs_._{}'.format(i)))
super(TFAdaptiveEmbedding, self).build(input_shape)
def call(self, inp):
@@ -368,6 +373,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
r_w_bias=None if self.untie_r else self.r_w_bias,
r_r_bias=None if self.untie_r else self.r_r_bias,
output_attentions=self.output_attentions,
layer_norm_epsilon=config.layer_norm_epsilon,
name='layers_._{}'.format(i))
)
else: # learnable embeddings and absolute embeddings
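A quick check, assuming the module path below, that the TF2.0 layers now honour the configured value instead of the previously hard-coded epsilon=1e-12:

from transformers.modeling_tf_transfo_xl import TFPositionwiseFF  # path assumed

ff = TFPositionwiseFF(d_model=32, d_inner=64, dropout=0.1, layer_norm_epsilon=1e-6)
assert ff.layer_norm.epsilon == 1e-6     # LayerNormalization epsilon set at construction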
@@ -194,7 +194,7 @@ class PositionalEmbedding(nn.Module):
class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
super(PositionwiseFF, self).__init__()
self.d_model = d_model
@@ -208,7 +208,7 @@ class PositionwiseFF(nn.Module):
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model)
self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
self.pre_lnorm = pre_lnorm
@@ -232,7 +232,8 @@ class PositionwiseFF(nn.Module):
class RelPartialLearnableMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None, output_attentions=False):
r_r_bias=None, r_w_bias=None, output_attentions=False,
layer_norm_epsilon=1e-5):
super(RelPartialLearnableMultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
@@ -247,7 +248,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
self.scale = 1 / (d_head ** 0.5)
@@ -359,14 +360,15 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5,
**kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__()
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
pre_lnorm=kwargs.get('pre_lnorm'),
layer_norm_epsilon=layer_norm_epsilon)
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
@@ -613,7 +615,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
output_attentions=self.output_attentions,
layer_norm_epsilon=config.layer_norm_epsilon)
)
else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
raise NotImplementedError # Removed them to avoid maintaining dead code
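And the mirror check on the PyTorch side, where nn.LayerNorm previously fell back to its library default of 1e-5 regardless of the config (module path assumed):

from transformers.modeling_transfo_xl import PositionwiseFF  # path assumed

ff = PositionwiseFF(d_model=32, d_inner=64, dropout=0.1, layer_norm_epsilon=1e-6)
assert ff.layer_norm.eps == 1e-6         # forwarded to nn.LayerNorm(d_model, eps=...)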