Commit 2967de06 authored by thomwolf's avatar thomwolf
Browse files

adding intialization to bert

parent a6bcfb80
......@@ -28,7 +28,7 @@ import numpy as np
import tensorflow as tf
from .configuration_bert import BertConfig
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
......@@ -100,9 +100,16 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
super(TFBertEmbeddings, self).__init__(**kwargs)
self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size
self.initializer_range = config.initializer_range
self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings')
self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings')
self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings,
config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name='position_embeddings')
self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size,
config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name='token_type_embeddings')
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
......@@ -117,8 +124,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
self.word_embeddings = self.add_weight(
"weight",
shape=[self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer(
mean=0., stddev=self.hidden_size**-0.5))
initializer=get_initializer(self.initializer_range))
super(TFBertEmbeddings, self).build(input_shape)
def call(self, inputs, mode="embedding", training=False):
......@@ -192,9 +198,15 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(self.all_head_size, name='query')
self.key = tf.keras.layers.Dense(self.all_head_size, name='key')
self.value = tf.keras.layers.Dense(self.all_head_size, name='value')
self.query = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='query')
self.key = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='key')
self.value = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='value')
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
......@@ -247,7 +259,9 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertSelfOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
......@@ -281,7 +295,9 @@ class TFBertAttention(tf.keras.layers.Layer):
class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertIntermediate, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.intermediate_size, name='dense')
self.dense = tf.keras.layers.Dense(config.intermediate_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
......@@ -296,7 +312,9 @@ class TFBertIntermediate(tf.keras.layers.Layer):
class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
......@@ -364,7 +382,10 @@ class TFBertEncoder(tf.keras.layers.Layer):
class TFBertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertPooler, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
activation='tanh',
name='dense')
def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
......@@ -377,7 +398,9 @@ class TFBertPooler(tf.keras.layers.Layer):
class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertPredictionHeadTransform, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense')
self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
......@@ -428,7 +451,9 @@ class TFBertMLMHead(tf.keras.layers.Layer):
class TFBertNSPHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFBertNSPHead, self).__init__(**kwargs)
self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')
self.seq_relationship = tf.keras.layers.Dense(2,
kernel_initializer=get_initializer(self.config.initializer_range),
name='seq_relationship')
def call(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
......@@ -454,8 +479,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
"""
raise NotImplementedError
# def call(self, input_ids, attention_mask=None, token_type_ids=None,
# position_ids=None, head_mask=None, training=False):
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
......@@ -819,7 +842,9 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
......@@ -869,7 +894,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(1, name='classifier')
self.classifier = tf.keras.layers.Dense(1,
kernel_initializer=get_initializer(self.config.initializer_range),
name='classifier')
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)):
......@@ -946,7 +973,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
......@@ -996,7 +1025,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
name='qa_outputs')
def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs)
......
......@@ -474,3 +474,12 @@ def shape_list(x):
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
def get_initializer(initializer_range=0.02):
"""Creates a `tf.initializers.truncated_normal` with the given range.
Args:
initializer_range: float, initializer range for stddev.
Returns:
TruncatedNormal initializer with stddev = `initializer_range`.
"""
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment