Commit 2967de06 authored by thomwolf's avatar thomwolf
Browse files

adding initialization to bert

parent a6bcfb80
...@@ -28,7 +28,7 @@ import numpy as np ...@@ -28,7 +28,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from .configuration_bert import BertConfig from .configuration_bert import BertConfig
from .modeling_tf_utils import TFPreTrainedModel from .modeling_tf_utils import TFPreTrainedModel, get_initializer
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
...@@ -100,9 +100,16 @@ class TFBertEmbeddings(tf.keras.layers.Layer): ...@@ -100,9 +100,16 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
super(TFBertEmbeddings, self).__init__(**kwargs) super(TFBertEmbeddings, self).__init__(**kwargs)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.initializer_range = config.initializer_range
self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, name='position_embeddings') self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings,
self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size, config.hidden_size, name='token_type_embeddings') config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name='position_embeddings')
self.token_type_embeddings = tf.keras.layers.Embedding(config.type_vocab_size,
config.hidden_size,
embeddings_initializer=get_initializer(self.initializer_range),
name='token_type_embeddings')
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file # any TensorFlow checkpoint file
...@@ -117,8 +124,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer): ...@@ -117,8 +124,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
self.word_embeddings = self.add_weight( self.word_embeddings = self.add_weight(
"weight", "weight",
shape=[self.vocab_size, self.hidden_size], shape=[self.vocab_size, self.hidden_size],
initializer=tf.random_normal_initializer( initializer=get_initializer(self.initializer_range))
mean=0., stddev=self.hidden_size**-0.5))
super(TFBertEmbeddings, self).build(input_shape) super(TFBertEmbeddings, self).build(input_shape)
def call(self, inputs, mode="embedding", training=False): def call(self, inputs, mode="embedding", training=False):
...@@ -192,9 +198,15 @@ class TFBertSelfAttention(tf.keras.layers.Layer): ...@@ -192,9 +198,15 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = tf.keras.layers.Dense(self.all_head_size, name='query') self.query = tf.keras.layers.Dense(self.all_head_size,
self.key = tf.keras.layers.Dense(self.all_head_size, name='key') kernel_initializer=get_initializer(self.config.initializer_range),
self.value = tf.keras.layers.Dense(self.all_head_size, name='value') name='query')
self.key = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='key')
self.value = tf.keras.layers.Dense(self.all_head_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='value')
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
...@@ -247,7 +259,9 @@ class TFBertSelfAttention(tf.keras.layers.Layer): ...@@ -247,7 +259,9 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
class TFBertSelfOutput(tf.keras.layers.Layer): class TFBertSelfOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super(TFBertSelfOutput, self).__init__(**kwargs) super(TFBertSelfOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense') self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
...@@ -281,7 +295,9 @@ class TFBertAttention(tf.keras.layers.Layer): ...@@ -281,7 +295,9 @@ class TFBertAttention(tf.keras.layers.Layer):
class TFBertIntermediate(tf.keras.layers.Layer): class TFBertIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super(TFBertIntermediate, self).__init__(**kwargs) super(TFBertIntermediate, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.intermediate_size, name='dense') self.dense = tf.keras.layers.Dense(config.intermediate_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = ACT2FN[config.hidden_act] self.intermediate_act_fn = ACT2FN[config.hidden_act]
else: else:
...@@ -296,7 +312,9 @@ class TFBertIntermediate(tf.keras.layers.Layer): ...@@ -296,7 +312,9 @@ class TFBertIntermediate(tf.keras.layers.Layer):
class TFBertOutput(tf.keras.layers.Layer): class TFBertOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super(TFBertOutput, self).__init__(**kwargs) super(TFBertOutput, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense') self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
...@@ -364,7 +382,10 @@ class TFBertEncoder(tf.keras.layers.Layer): ...@@ -364,7 +382,10 @@ class TFBertEncoder(tf.keras.layers.Layer):
class TFBertPooler(tf.keras.layers.Layer): class TFBertPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
    """Create the pooler's single dense projection with a tanh activation.

    Args:
        config: model configuration; `hidden_size` and `initializer_range`
            are read from it.
        **kwargs: forwarded unchanged to the parent layer constructor.
    """
    super(TFBertPooler, self).__init__(**kwargs)
    # Read the range from the `config` argument: nothing in this constructor
    # assigns `self.config`, so `self.config.initializer_range` would raise
    # AttributeError on a plain Keras layer.
    self.dense = tf.keras.layers.Dense(config.hidden_size,
                                       kernel_initializer=get_initializer(config.initializer_range),
                                       activation='tanh',
                                       name='dense')
def call(self, hidden_states): def call(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding # We "pool" the model by simply taking the hidden state corresponding
...@@ -377,7 +398,9 @@ class TFBertPooler(tf.keras.layers.Layer): ...@@ -377,7 +398,9 @@ class TFBertPooler(tf.keras.layers.Layer):
class TFBertPredictionHeadTransform(tf.keras.layers.Layer): class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super(TFBertPredictionHeadTransform, self).__init__(**kwargs) super(TFBertPredictionHeadTransform, self).__init__(**kwargs)
self.dense = tf.keras.layers.Dense(config.hidden_size, name='dense') self.dense = tf.keras.layers.Dense(config.hidden_size,
kernel_initializer=get_initializer(self.config.initializer_range),
name='dense')
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.transform_act_fn = ACT2FN[config.hidden_act] self.transform_act_fn = ACT2FN[config.hidden_act]
else: else:
...@@ -428,7 +451,9 @@ class TFBertMLMHead(tf.keras.layers.Layer): ...@@ -428,7 +451,9 @@ class TFBertMLMHead(tf.keras.layers.Layer):
class TFBertNSPHead(tf.keras.layers.Layer): class TFBertNSPHead(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
    """Create the two-unit dense layer stored as `seq_relationship`.

    Args:
        config: model configuration; `initializer_range` is read from it.
        **kwargs: forwarded unchanged to the parent layer constructor.
    """
    super(TFBertNSPHead, self).__init__(**kwargs)
    # Read the range from the `config` argument: nothing in this constructor
    # assigns `self.config`, so `self.config.initializer_range` would raise
    # AttributeError on a plain Keras layer.
    self.seq_relationship = tf.keras.layers.Dense(2,
                                                  kernel_initializer=get_initializer(config.initializer_range),
                                                  name='seq_relationship')
def call(self, pooled_output): def call(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output) seq_relationship_score = self.seq_relationship(pooled_output)
...@@ -454,8 +479,6 @@ class TFBertMainLayer(tf.keras.layers.Layer): ...@@ -454,8 +479,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
""" """
raise NotImplementedError raise NotImplementedError
# def call(self, input_ids, attention_mask=None, token_type_ids=None,
# position_ids=None, head_mask=None, training=False):
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False): def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
...@@ -819,7 +842,9 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel): ...@@ -819,7 +842,9 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert') self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier') self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs) outputs = self.bert(inputs, **kwargs)
...@@ -869,7 +894,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel): ...@@ -869,7 +894,9 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert') self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(1, name='classifier') self.classifier = tf.keras.layers.Dense(1,
kernel_initializer=get_initializer(self.config.initializer_range),
name='classifier')
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False): def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -946,7 +973,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel): ...@@ -946,7 +973,9 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name='bert') self.bert = TFBertMainLayer(config, name='bert')
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier') self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs) outputs = self.bert(inputs, **kwargs)
...@@ -996,7 +1025,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel): ...@@ -996,7 +1025,9 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.bert = TFBertMainLayer(config, name='bert') self.bert = TFBertMainLayer(config, name='bert')
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs') self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(self.config.initializer_range),
name='qa_outputs')
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.bert(inputs, **kwargs) outputs = self.bert(inputs, **kwargs)
......
...@@ -474,3 +474,12 @@ def shape_list(x): ...@@ -474,3 +474,12 @@ def shape_list(x):
static = x.shape.as_list() static = x.shape.as_list()
dynamic = tf.shape(x) dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)] return [dynamic[i] if s is None else s for i, s in enumerate(static)]
def get_initializer(initializer_range=0.02):
    """Build a truncated-normal weight initializer.

    Args:
        initializer_range: float, standard deviation handed to the initializer.

    Returns:
        A `tf.keras.initializers.TruncatedNormal` instance whose `stddev`
        equals `initializer_range`.
    """
    truncated_normal = tf.keras.initializers.TruncatedNormal
    return truncated_normal(stddev=initializer_range)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment