Unverified Commit 701bd59b authored by Thomas Wolf, committed by GitHub

Merge pull request #585 from huntzhan/master

Make the epsilon of LayerNorm configurable.
parents 303b5e2b 101ab4dd
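
Context for the diff below: LayerNorm computes y = (x - mean) / sqrt(var + eps) * weight + bias, where eps guards the division when the variance is near zero. The value was previously hard-coded to 1e-12 in every module (torch.nn.LayerNorm, by comparison, defaults to 1e-5). This PR threads the value through BertConfig instead, so every hunk below boils down to the same one-line substitution:

    # before: epsilon hard-coded in each module
    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
    # after: epsilon read from the config, still defaulting to 1e-12
    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)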
@@ -145,7 +145,8 @@ class BertConfig(object):
                  attention_probs_dropout_prob=0.1,
                  max_position_embeddings=512,
                  type_vocab_size=2,
-                 initializer_range=0.02):
+                 initializer_range=0.02,
+                 layer_norm_eps=1e-12):
         """Constructs BertConfig.
         Args:
@@ -169,6 +170,7 @@ class BertConfig(object):
                 `BertModel`.
             initializer_range: The stddev of the truncated_normal_initializer for
                 initializing all weight matrices.
+            layer_norm_eps: The epsilon used by LayerNorm.
         """
         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -188,6 +190,7 @@ class BertConfig(object):
             self.max_position_embeddings = max_position_embeddings
             self.type_vocab_size = type_vocab_size
             self.initializer_range = initializer_range
+            self.layer_norm_eps = layer_norm_eps
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
@@ -254,7 +257,7 @@ class BertEmbeddings(nn.Module):
         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
         # any TensorFlow checkpoint file
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, input_ids, token_type_ids=None):
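
To show where the forwarded value actually lands, here is a schematic BERT-style LayerNorm with an explicit eps. This is a sketch written for illustration, not the repository's BertLayerNorm (which may be swapped for a fused CUDA implementation when one is available):

    import torch
    import torch.nn as nn

    class LayerNormSketch(nn.Module):
        # Schematic BERT-style LayerNorm; eps keeps the division stable
        # when the per-token variance is close to zero.
        def __init__(self, hidden_size, eps=1e-12):
            super(LayerNormSketch, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))   # gamma
            self.bias = nn.Parameter(torch.zeros(hidden_size))    # beta
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)                # per-token mean
            s = (x - u).pow(2).mean(-1, keepdim=True)   # per-token variance
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias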
@@ -329,7 +332,7 @@ class BertSelfOutput(nn.Module):
     def __init__(self, config):
         super(BertSelfOutput, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor):
@@ -370,7 +373,7 @@ class BertOutput(nn.Module):
     def __init__(self, config):
         super(BertOutput, self).__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, hidden_states, input_tensor):
@@ -434,7 +437,7 @@ class BertPredictionHeadTransform(nn.Module):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
-        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
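
A side effect worth noting, assuming BertConfig.from_dict still copies JSON keys onto a default-constructed config: config files written before this change simply lack the "layer_norm_eps" key and fall back to the 1e-12 default, so existing pretrained configs load unchanged, while new files can pin a different epsilon. A minimal sketch (the JSON keys shown are illustrative, not a complete BERT config):

    import json
    from pytorch_pretrained_bert.modeling import BertConfig

    # Hypothetical config file; older files without "layer_norm_eps"
    # would still load, keeping the constructor default of 1e-12.
    with open("bert_config.json", "w") as f:
        json.dump({"vocab_size": 30522, "hidden_size": 768,
                   "layer_norm_eps": 1e-5}, f)

    config = BertConfig.from_json_file("bert_config.json")
    assert config.layer_norm_eps == 1e-5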
...